# Source code for xcube.core.maskset

# Copyright (c) 2018-2024 by xcube team and contributors
# Permissions are hereby granted under the terms of the MIT License:
# https://opensource.org/licenses/MIT.

from typing import Dict, Any, Iterable

import dask.array as da
import numpy as np
import xarray as xr

# TODO: this would be useful to have in xarray:
#       >>> ds = xr.open_dataset("my/path/to/cf/", decode_flags=True)
#       >>> ds.flag_mask_sets['quality_flags']

class MaskSet:
    """A set of mask variables derived from a variable *flag_var* with
    the following CF attributes:

    - One or both of `flag_masks` and `flag_values`
    - `flag_meanings` (always required)

    See the CF Conventions, section 3.5 "Flags", for details on the use
    of these attributes.

    Each mask is represented by an `xarray.DataArray`, has the name of
    the flag, is of type `numpy.uint8`, and has the dimensions of the
    given *flag_var*.

    Masks are accessible as attributes (``mask_set.flag_name``), by name
    (``mask_set["flag_name"]``), by position (``mask_set[0]``), or via
    :meth:`get_mask`. Each mask is computed on first access and cached.

    Args:
        flag_var: an `xarray.DataArray` that defines flag values.
            The CF attributes `flag_meanings` and one or both of
            `flag_masks` and `flag_values` are expected to exist
            and be valid.
    """

    def __init__(self, flag_var: xr.DataArray):
        flag_masks = flag_var.attrs.get("flag_masks")
        flag_values = flag_var.attrs.get("flag_values")

        if flag_masks is None and flag_values is None:
            raise ValueError(
                "One or both of the attributes "
                "'flag_masks' or 'flag_values' "
                "must be present and non-null in flag_var"
            )

        if flag_masks is not None:
            flag_masks = _convert_flag_var_attribute_value(flag_masks, "flag_masks")
        if flag_values is not None:
            flag_values = _convert_flag_var_attribute_value(flag_values, "flag_values")

        if "flag_meanings" not in flag_var.attrs:
            raise ValueError("flag_var must have the attribute 'flag_meanings'")

        flag_meanings = flag_var.attrs.get("flag_meanings")
        if not isinstance(flag_meanings, str):
            raise TypeError("attribute 'flag_meanings' of flag_var must be a string")

        # CF defines flag_meanings as a blank-separated list. split() (unlike
        # split(" ")) ignores repeated, leading/trailing and newline whitespace,
        # which would otherwise produce spurious empty flag names.
        flag_names = flag_meanings.split()

        if flag_masks is not None and len(flag_names) != len(flag_masks):
            raise ValueError(
                "attributes 'flag_meanings' and 'flag_masks' are not corresponding"
            )
        if flag_values is not None and len(flag_names) != len(flag_values):
            raise ValueError(
                "attributes 'flag_meanings' and 'flag_values' are not corresponding"
            )

        # Pad the missing attribute with None so each flag gets a (mask, value) pair.
        if flag_masks is None:
            flag_masks = [None] * len(flag_names)
        if flag_values is None:
            flag_values = [None] * len(flag_names)

        self._flag_var = flag_var
        self._flag_names = flag_names
        # flag name -> (mask or None, value or None)
        self._flags = dict(zip(flag_names, zip(flag_masks, flag_values)))
        # flag name -> already computed mask DataArray
        self._masks = {}

    @classmethod
    def is_flag_var(cls, var: xr.DataArray) -> bool:
        """Return True if *var* has `flag_meanings` and at least one of
        `flag_masks` / `flag_values` among its attributes."""
        return "flag_meanings" in var.attrs and (
            "flag_masks" in var.attrs or "flag_values" in var.attrs
        )

    @classmethod
    def get_mask_sets(cls, dataset: xr.Dataset) -> Dict[str, "MaskSet"]:
        """For each "flag" variable in given *dataset*, turn it into a
        ``MaskSet`` and store it in a dictionary.

        Args:
            dataset: The dataset

        Returns:
            A mapping of flag variable names to ``MaskSet``.
            Will be empty if there are no flag variables in *dataset*.
        """
        mask_sets = {}
        for var_name in dataset.variables:
            # dataset[...] yields a DataArray (with coords), which MaskSet needs.
            var = dataset[var_name]
            if cls.is_flag_var(var):
                mask_sets[var_name] = cls(var)
        return mask_sets

    def _repr_html_(self):
        """Render the flags as an HTML table of name, mask and value."""
        import html

        lines = [
            "<html>",
            "<table>",
            "<tr><th>Flag name</th><th>Mask</th><th>Value</th></tr>",
        ]
        for name, (mask, value) in self._flags.items():
            # Cell texts originate from dataset attributes; escape them.
            cells = "".join(
                f"<td>{html.escape(str(cell))}</td>" for cell in (name, mask, value)
            )
            lines.append(f"<tr>{cells}</tr>")
        lines.extend(["</table>", "</html>"])
        return "\n".join(lines)

    def __str__(self):
        return "%s(%s)" % (
            self._flag_var.name,
            ", ".join(["%s=%s" % (n, v) for n, v in self._flags.items()]),
        )

    def __dir__(self) -> Iterable[str]:
        return self._flag_names

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails. Read _flags from
        # __dict__ directly: on an instance lacking _flags (e.g. created via
        # __new__ by copy or pickle), "self._flags" would re-enter
        # __getattr__ and recurse until RecursionError.
        flags = self.__dict__.get("_flags")
        if flags is None or name not in flags:
            raise AttributeError(name)
        return self.get_mask(name)

    def __getitem__(self, item):
        """Return the mask for *item*, given as flag index or flag name.

        Raises:
            IndexError: If an integer index is out of range.
            KeyError: If *item* is not the name of a flag in this set.
        """
        try:
            name = self._flag_names[item]
            if name not in self._flags:
                raise IndexError(item)
        except TypeError:
            # item cannot index a list, so treat it as a flag name.
            name = item
            if name not in self._flags:
                raise KeyError(item)
        return self.get_mask(name)

    def __contains__(self, item):
        return item in self._flags

    def get_mask(self, flag_name: str):
        """Return the mask for flag *flag_name*, computing it on first use.

        The mask is built from a `numpy.uint8` array of ones that has the dims
        and coords of the flag variable. It is dask-backed, using the flag
        variable's chunks, if the flag variable is chunked. An element is 1
        when the flag matches and 0 otherwise:

        - ``(flag_var & mask) == value`` if both mask and value are given,
        - ``(flag_var & mask) != 0`` if only a mask is given,
        - ``flag_var == value`` if only a value is given.

        Raises:
            ValueError: If *flag_name* is not a flag of this set.
        """
        if flag_name not in self._flags:
            raise ValueError('invalid flag name "%s"' % flag_name)

        if flag_name in self._masks:
            return self._masks[flag_name]

        flag_var = self._flag_var
        flag_mask, flag_value = self._flags[flag_name]

        if flag_var.chunks is not None:
            ones_array = da.ones(flag_var.shape, dtype=np.uint8, chunks=flag_var.chunks)
        else:
            ones_array = np.ones(flag_var.shape, dtype=np.uint8)

        mask_var = xr.DataArray(
            ones_array, dims=flag_var.dims, name=flag_name, coords=flag_var.coords
        )

        if flag_mask is not None:
            # Cast so the bitwise AND is done in the mask's integer type.
            if flag_var.dtype != flag_mask.dtype:
                flag_var = flag_var.astype(flag_mask.dtype)
            if flag_value is not None:
                mask_var = mask_var.where((flag_var & flag_mask) == flag_value, 0)
            else:
                mask_var = mask_var.where((flag_var & flag_mask) != 0, 0)
        else:
            if flag_var.dtype != flag_value.dtype:
                flag_var = flag_var.astype(flag_value.dtype)
            mask_var = mask_var.where(flag_var == flag_value, 0)

        self._masks[flag_name] = mask_var
        return mask_var
# Unsigned integer dtypes tried in order for masks parsed from bit
# expressions, each paired with its exclusive upper bound (2**bits).
_MASK_DTYPES = (
    (2**8, np.uint8),
    (2**16, np.uint16),
    (2**32, np.uint32),
    (2**64, np.uint64),
)


def _convert_flag_var_attribute_value(attr_value, attr_name):
    """Normalize a `flag_masks` / `flag_values` attribute value.

    A string value is parsed as comma-separated items. Each item is
    either a plain integer (``"12"``), a single bit (``"3b"`` means
    ``1 << 3``), or an inclusive bit range (``"2b-4b"`` sets bits 2, 3
    and 4). The result is a 1-D array of the smallest unsigned integer
    dtype that can hold every parsed mask.

    Array-like values (anything with ``dtype`` and ``shape``) are returned
    unchanged. The exception is 0-d values such as numpy scalars, which are
    promoted to 1-D arrays so that callers can take ``len()`` of them.

    Args:
        attr_value: The raw attribute value.
        attr_name: The attribute name, used in error messages.

    Returns:
        A 1-D numpy array, or *attr_value* itself if it is already an
        array of rank >= 1.

    Raises:
        ValueError: If a string value is not a valid bit expression or a
            mask does not fit into 64 bits.
        TypeError: If *attr_value* is neither a string nor array-like.
    """
    if isinstance(attr_value, str):
        err_msg = f'Invalid bit expression in value for {attr_name}: "{attr_value}"'
        masks = []
        for item in attr_value.split(","):
            item = item.strip()
            parts = [part.strip() for part in item.split("-")]
            if len(parts) == 1:
                try:
                    mask = (1 << int(item[:-1])) if item.endswith("b") else int(item)
                except ValueError as e:
                    raise ValueError(err_msg) from e
            elif len(parts) == 2:
                s1, s2 = parts
                if not (s1.endswith("b") and s2.endswith("b")):
                    raise ValueError(err_msg)
                try:
                    b1 = int(s1[:-1])
                    b2 = int(s2[:-1])
                except ValueError as e:
                    raise ValueError(err_msg) from e
                if b1 > b2:
                    raise ValueError(err_msg)
                # All bits from b1 to b2 (inclusive) set.
                mask = ((1 << (b2 - b1 + 1)) - 1) << b1
            else:
                raise ValueError(err_msg)
            masks.append(mask)

        # split(",") always yields at least one item, and every parsed mask
        # is non-negative ("-" is consumed as the range separator).
        max_mask = max(masks)
        for limit, dtype in _MASK_DTYPES:
            # limit is exclusive: e.g. 2**8 == 256 does not fit into uint8.
            if max_mask < limit:
                return np.array(masks, dtype)
        raise ValueError(err_msg)

    if not (hasattr(attr_value, "dtype") and hasattr(attr_value, "shape")):
        raise TypeError(f"attribute {attr_name!r} must be an integer array")
    if len(attr_value.shape) == 0:
        # Single-element attributes may arrive as numpy scalars (e.g. from
        # netCDF readers); promote them to 1-D so len()/zip() work.
        return np.atleast_1d(attr_value)
    return attr_value