# The MIT License (MIT)
# Copyright (c) 2019-2021 by the xcube development team and contributors
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Dict, Any, Iterable
import dask.array as da
import numpy as np
import xarray as xr
# TODO: this would be useful to have in xarray:
# >>> ds = xr.open_dataset("my/path/to/cf/netcdf.nc", decode_flags=True)
# >>> ds.flag_mask_sets['quality_flags']
class MaskSet:
    """
    A set of mask variables derived from a variable *flag_var* with the
    following CF attributes:

      - One or both of `flag_masks` and `flag_values`
      - `flag_meanings` (always required)

    See https://cfconventions.org/Data/cf-conventions/cf-conventions-1.9/cf-conventions.html#flags
    for details on the use of these attributes.

    Each mask is represented by an `xarray.DataArray`, has the name of the
    flag, is of type `numpy.uint8`, and has the dimensions of the given
    *flag_var*. Masks are computed on first access and cached afterwards.

    A mask can be obtained by attribute access (``mask_set.<flag_name>``),
    by item access using either a flag name or a flag index
    (``mask_set['<flag_name>']``, ``mask_set[0]``), or by calling
    :meth:`get_mask`.

    :param flag_var: an `xarray.DataArray` that defines flag values.
        The CF attributes `flag_meanings` and one or both of
        `flag_masks` and `flag_values` are expected to exist and be valid.
    :raises ValueError: if neither `flag_masks` nor `flag_values` is
        given, if `flag_meanings` is missing, or if the number of flag
        names does not match the number of masks or values.
    :raises TypeError: if `flag_meanings` is not a string.
    """

    def __init__(self, flag_var: xr.DataArray):
        flag_masks = flag_var.attrs.get('flag_masks')
        flag_values = flag_var.attrs.get('flag_values')
        if flag_masks is None and flag_values is None:
            raise ValueError("One or both of the attributes "
                             "'flag_masks' or 'flag_values' "
                             "must be present and non-null in flag_var")
        if flag_masks is not None:
            flag_masks = _convert_flag_var_attribute_value(flag_masks,
                                                           'flag_masks')
        if flag_values is not None:
            flag_values = _convert_flag_var_attribute_value(flag_values,
                                                            'flag_values')

        if 'flag_meanings' not in flag_var.attrs:
            raise ValueError(
                "flag_var must have the attribute 'flag_meanings'")
        flag_meanings = flag_var.attrs.get('flag_meanings')
        if not isinstance(flag_meanings, str):
            raise TypeError("attribute 'flag_meanings' of flag_var "
                            "must be a string")

        # CF defines flag_meanings as a blank-separated list. Splitting on
        # arbitrary whitespace avoids empty flag names caused by repeated,
        # leading or trailing blanks.
        flag_names = flag_meanings.split()
        if flag_masks is not None and len(flag_names) != len(flag_masks):
            raise ValueError("attributes 'flag_meanings' and 'flag_masks' "
                             "are not corresponding")
        if flag_values is not None and len(flag_names) != len(flag_values):
            raise ValueError("attributes 'flag_meanings' and 'flag_values' "
                             "are not corresponding")

        # Pad the missing attribute with None so that every flag name maps
        # to a (mask, value) pair.
        if flag_masks is None:
            flag_masks = [None] * len(flag_names)
        if flag_values is None:
            flag_values = [None] * len(flag_names)

        self._flag_var = flag_var
        self._flag_names = flag_names
        self._flags = dict(zip(flag_names, zip(flag_masks, flag_values)))
        # Cache of already computed masks, keyed by flag name.
        self._masks = {}

    @classmethod
    def is_flag_var(cls, var: xr.DataArray) -> bool:
        """
        Test whether *var* carries the attributes of a flag variable.

        :param var: The variable to test.
        :return: True if *var* has the attribute `flag_meanings` and at
            least one of `flag_masks` and `flag_values`.
        """
        return 'flag_meanings' in var.attrs \
               and ('flag_masks' in var.attrs or 'flag_values' in var.attrs)

    @classmethod
    def get_mask_sets(cls, dataset: xr.Dataset) -> Dict[str, 'MaskSet']:
        """
        For each "flag" variable in given *dataset*, turn it into a
        ``MaskSet``, store it in a dictionary.

        :param dataset: The dataset
        :return: A mapping of flag names to ``MaskSet``. Will be empty if
            there are no flag variables in *dataset*.
        """
        masks = {}
        for var_name in dataset.variables:
            var = dataset[var_name]
            if cls.is_flag_var(var):
                # Instantiate via cls so that subclasses yield instances
                # of their own type.
                masks[var_name] = cls(var)
        return masks

    def _repr_html_(self):
        """Return an HTML table listing each flag's name, mask and value."""
        import html

        lines = ['<html>',
                 '<table>',
                 '<tr><th>Flag name</th><th>Mask</th><th>Value</th></tr>']
        for name, (mask, value) in self._flags.items():
            # Flag names originate from dataset attributes, so escape them
            # before embedding them into markup.
            lines.append(f'<tr><td>{html.escape(str(name))}</td>'
                         f'<td>{mask}</td><td>{value}</td></tr>')
        lines.extend(['</table>',
                      '</html>'])
        return '\n'.join(lines)

    def __str__(self):
        flags = ', '.join(f"{name}={mask_and_value}"
                          for name, mask_and_value in self._flags.items())
        return f"{self._flag_var.name}({flags})"

    def __dir__(self) -> Iterable[str]:
        return self._flag_names

    def __getattr__(self, name: str) -> Any:
        # __getattr__ is only called if normal attribute lookup fails.
        # Read _flags from the instance dict directly: going through
        # self._flags would re-enter __getattr__ and recurse infinitely
        # if the instance is not (yet) initialised, e.g. while it is
        # being copied or unpickled.
        flags = self.__dict__.get('_flags')
        if flags is None or name not in flags:
            raise AttributeError(name)
        return self.get_mask(name)

    def __getitem__(self, item):
        try:
            # Positional access: *item* is an index into the flag names.
            name = self._flag_names[item]
            if name not in self._flags:
                raise IndexError(item)
        except TypeError:
            # Not usable as a list index: treat *item* as a flag name.
            name = item
            if name not in self._flags:
                raise KeyError(item)
        return self.get_mask(name)

    def __contains__(self, item):
        return item in self._flags

    def get_mask(self, flag_name: str):
        """
        Get the mask for the flag named *flag_name*.

        The mask is 1 where the flag is set in the flag variable and 0
        elsewhere. If the flag variable is chunked, the mask is created
        with the same chunking. Computed masks are cached.

        :param flag_name: The name of the flag.
        :return: The mask as an `xarray.DataArray` named *flag_name*.
        :raises ValueError: if *flag_name* is not a known flag name.
        """
        if flag_name not in self._flags:
            raise ValueError('invalid flag name "%s"' % flag_name)

        if flag_name in self._masks:
            return self._masks[flag_name]

        flag_var = self._flag_var
        flag_mask, flag_value = self._flags[flag_name]

        # Start from an all-ones array and zero out the cells in which
        # the flag is not set.
        if flag_var.chunks is not None:
            ones_array = da.ones(flag_var.shape,
                                 dtype=np.uint8,
                                 chunks=flag_var.chunks)
        else:
            ones_array = np.ones(flag_var.shape,
                                 dtype=np.uint8)

        mask_var = xr.DataArray(ones_array,
                                dims=flag_var.dims,
                                name=flag_name,
                                coords=flag_var.coords)

        if flag_mask is not None:
            # Align dtypes so that the bitwise AND is well-defined.
            if flag_var.dtype != flag_mask.dtype:
                flag_var = flag_var.astype(flag_mask.dtype)
            if flag_value is not None:
                # Both given: the masked bits must equal the flag value.
                mask_var = mask_var.where(
                    (flag_var & flag_mask) == flag_value, 0
                )
            else:
                # Mask only: any of the masked bits must be set.
                mask_var = mask_var.where((flag_var & flag_mask) != 0, 0)
        else:
            # Value only: the variable must equal the flag value.
            if flag_var.dtype != flag_value.dtype:
                flag_var = flag_var.astype(flag_value.dtype)
            mask_var = mask_var.where(flag_var == flag_value, 0)

        self._masks[flag_name] = mask_var
        return mask_var
_MASK_DTYPES = (
(2 ** 8, np.uint8), (2 ** 16, np.uint16), (2 ** 32, np.uint32), (2 ** 64, np.uint64)
)
def _convert_flag_var_attribute_value(attr_value, attr_name):
if isinstance(attr_value, str):
err_msg = f'Invalid bit expression in value for {attr_name}: "{attr_value}"'
masks = []
max_mask = 0
for s in attr_value.split(","):
s = s.strip()
pair = s.split('-')
if len(pair) == 1:
try:
mask = (1 << int(s[0:-1])) if s.endswith("b") else int(s)
except ValueError as e:
raise ValueError(err_msg) from e
elif len(pair) == 2:
s1, s2 = pair
if not s1.endswith("b") or not s2.endswith("b"):
raise ValueError(err_msg)
try:
b1 = int(s1[0:-1])
b2 = int(s2[0:-1])
except ValueError as e:
raise ValueError(err_msg) from e
if b1 > b2:
raise ValueError(err_msg)
mask = 0
for b in range(b1, b2 + 1):
mask |= 1 << b
else:
raise ValueError(err_msg)
masks.append(mask)
max_mask = max(max_mask, mask)
for limit, dtype in _MASK_DTYPES:
if max_mask <= limit:
return np.array(masks, dtype)
raise ValueError(err_msg)
if not (hasattr(attr_value, 'dtype') and hasattr(attr_value, 'shape')):
raise TypeError(f'attribute {attr_name!r} must be an integer array')
return attr_value