Source code for xcube.core.mldataset.computed

# Copyright (c) 2018-2024 by xcube team and contributors
# Permissions are hereby granted under the terms of the MIT License:
# https://opensource.org/licenses/MIT.

import os.path
import uuid
from typing import Sequence, Any, Dict, Callable, Mapping, Optional, Tuple

import xarray as xr

from xcube.core.byoa import CodeConfig
from xcube.core.byoa import FileSet
from xcube.core.gridmapping import GridMapping
from xcube.util.assertions import assert_given
from xcube.util.assertions import assert_instance
from xcube.util.assertions import assert_true
from xcube.util.perf import measure_time
from .abc import MultiLevelDataset
from .lazy import LazyMultiLevelDataset

MultiLevelDatasetGetter = Callable[[str], MultiLevelDataset]
MultiLevelDatasetSetter = Callable[[MultiLevelDataset], None]
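
# A minimal sketch, illustrative and not part of this module, of how the
# two aliases above are typically satisfied: a plain dict-backed registry.
# "_example_registry", "get_ml_dataset", and "set_ml_dataset" are
# hypothetical names.

_example_registry: Dict[str, MultiLevelDataset] = {}


def get_ml_dataset(ds_id: str) -> MultiLevelDataset:
    # A MultiLevelDatasetGetter: resolve a dataset identifier to a dataset.
    return _example_registry[ds_id]


def set_ml_dataset(ml_dataset: MultiLevelDataset) -> None:
    # A MultiLevelDatasetSetter: register a dataset under its own identifier.
    _example_registry[ml_dataset.ds_id] = ml_dataset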


class ComputedMultiLevelDataset(LazyMultiLevelDataset):
    """A multi-level dataset whose level datasets are computed
    by a user function.

    The script can import other Python modules located in the same
    directory as *script_path*.
    """

    def __init__(
        self,
        script_path: str,
        callable_name: str,
        input_ml_dataset_ids: Sequence[str],
        input_ml_dataset_getter: MultiLevelDatasetGetter,
        input_parameters: Optional[Mapping[str, Any]] = None,
        ds_id: str = "",
        exception_type: type = ValueError,
    ):
        callable_ref, callable_obj = self.get_callable(
            script_path,
            callable_name,
            input_ml_dataset_ids,
            input_ml_dataset_getter,
            input_parameters=input_parameters,
            ds_id=ds_id,
            exception_type=exception_type,
        )
        super().__init__(ds_id=ds_id, parameters=input_parameters)
        self._callable_ref = callable_ref
        self._callable_obj = callable_obj
        self._input_ml_dataset_ids = input_ml_dataset_ids
        self._input_ml_dataset_getter = input_ml_dataset_getter
        self._exception_type = exception_type

    @classmethod
    def get_callable(
        cls,
        script_path: str,
        callable_name: str,
        input_ml_dataset_ids: Sequence[str],
        input_ml_dataset_getter: MultiLevelDatasetGetter,
        input_parameters: Optional[Mapping[str, Any]] = None,
        ds_id: str = "",
        exception_type: type = ValueError,
    ) -> Tuple[str, Callable]:
        assert_instance(script_path, str, name="script_path")
        assert_given(script_path, name="script_path")
        assert_true(
            callable(input_ml_dataset_getter),
            message="input_ml_dataset_getter must be a callable",
        )
        assert_given(input_ml_dataset_getter, name="input_ml_dataset_getter")
        assert_instance(ds_id, str, name="ds_id")
        assert_given(ds_id, name="ds_id")
        module_name = None
        basename = os.path.basename(script_path)
        basename, ext = os.path.splitext(basename)
        if ext == ".py":
            script_path = os.path.dirname(script_path)
            module_name = basename
        if ":" in callable_name:
            callable_ref = callable_name
        else:
            if not module_name:
                raise exception_type(
                    f"Invalid in-memory dataset descriptor {ds_id!r}:"
                    f" Missing module name in {callable_name!r}"
                )
            callable_ref = f"{module_name}:{callable_name}"
        if not input_ml_dataset_ids:
            raise exception_type(
                f"Invalid in-memory dataset descriptor {ds_id!r}:"
                f" Input dataset(s) missing for callable {callable_name!r}"
            )
        for input_param_name in (input_parameters or {}).keys():
            if not input_param_name or not input_param_name.isidentifier():
                raise exception_type(
                    f"Invalid in-memory dataset descriptor {ds_id!r}:"
                    f" Input parameter {input_param_name!r}"
                    f" for callable {callable_name!r}"
                    f" is not a valid Python identifier"
                )
        try:
            callable_obj = CodeConfig.from_file_set(
                FileSet(path=script_path),
                callable_ref=callable_ref,
                install_required=False,
            ).get_callable()
        except (TypeError, ValueError, ImportError) as e:
            raise exception_type(f"Invalid dataset descriptor {ds_id!r}: {e}") from e
        return callable_ref, callable_obj

    @property
    def num_inputs(self) -> int:
        return len(self._input_ml_dataset_ids)

    def get_input_dataset(self, index: int) -> MultiLevelDataset:
        return self._input_ml_dataset_getter(self._input_ml_dataset_ids[index])

    def _get_num_levels_lazily(self) -> int:
        return self.get_input_dataset(0).num_levels

    def _get_grid_mapping_lazily(self) -> GridMapping:
        return self.get_input_dataset(0).grid_mapping

    def _get_dataset_lazily(
        self, index: int, parameters: Dict[str, Any]
    ) -> xr.Dataset:
        input_datasets = [
            self._input_ml_dataset_getter(ds_id).get_dataset(index)
            for ds_id in self._input_ml_dataset_ids
        ]
        try:
            with measure_time(
                tag=f"Computed in-memory dataset {self.ds_id!r} at level {index}"
            ):
                computed_value = self._callable_obj(*input_datasets, **parameters)
        except Exception as e:
            raise self._exception_type(
                f"Failed to compute in-memory dataset {self.ds_id!r}"
                f" at level {index}"
                f" from function {self._callable_ref!r}(): {e}"
            ) from e
        if not isinstance(computed_value, xr.Dataset):
            raise self._exception_type(
                f"Failed to compute in-memory dataset {self.ds_id!r}"
                f" at level {index}"
                f" from function {self._callable_ref!r}():"
                f" expected an xarray.Dataset but got {type(computed_value)}"
            )
        return computed_value
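
# Illustrative example, not part of this module: a user callable as it
# would appear in its own script file, e.g. a hypothetical "my_algo.py"
# placed next to any helper modules it imports. It may be referenced as
# plain "compute_dataset" (the module name is then inferred from the
# script's file name) or fully qualified as "my_algo:compute_dataset".
# ComputedMultiLevelDataset invokes it once per pyramid level with one
# xr.Dataset per input dataset ID plus the input parameters, and it must
# return an xr.Dataset. "compute_dataset" and "factor" are hypothetical.


def compute_dataset(ds: xr.Dataset, factor: float = 2.0) -> xr.Dataset:
    # Scale all data variables of the given level dataset.
    return ds * factor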
def augment_ml_dataset(
    ml_dataset: MultiLevelDataset,
    script_path: str,
    callable_name: str,
    input_ml_dataset_getter: MultiLevelDatasetGetter,
    input_ml_dataset_setter: MultiLevelDatasetSetter,
    input_parameters: Optional[Mapping[str, Any]] = None,
    is_factory: bool = False,
    exception_type: type = ValueError,
):
    from .identity import IdentityMultiLevelDataset
    from .combined import CombinedMultiLevelDataset

    with measure_time(tag=f"Added augmentation from {script_path}"):
        orig_id = ml_dataset.ds_id
        aug_id = uuid.uuid4()
        aug_inp_id = f"aug-input-{aug_id}"
        aug_inp_ds = IdentityMultiLevelDataset(ml_dataset, ds_id=aug_inp_id)
        input_ml_dataset_setter(aug_inp_ds)
        aug_ds = _open_ml_dataset_from_python_code(
            script_path,
            callable_name,
            [aug_inp_id],
            input_ml_dataset_getter,
            input_parameters=input_parameters,
            is_factory=is_factory,
            ds_id=f"aug-{aug_id}",
            exception_type=exception_type,
        )
        return CombinedMultiLevelDataset([ml_dataset, aug_ds], ds_id=orig_id)


def open_ml_dataset_from_python_code(
    script_path: str,
    callable_name: str,
    input_ml_dataset_ids: Sequence[str],
    input_ml_dataset_getter: MultiLevelDatasetGetter,
    input_parameters: Optional[Mapping[str, Any]] = None,
    is_factory: bool = False,
    ds_id: str = "",
    exception_type: type = ValueError,
) -> MultiLevelDataset:
    with measure_time(tag=f"Opened memory dataset {script_path}"):
        return _open_ml_dataset_from_python_code(
            script_path,
            callable_name,
            input_ml_dataset_ids,
            input_ml_dataset_getter,
            input_parameters=input_parameters,
            is_factory=is_factory,
            ds_id=ds_id,
            exception_type=exception_type,
        )


def _open_ml_dataset_from_python_code(
    script_path: str,
    callable_name: str,
    input_ml_dataset_ids: Sequence[str],
    input_ml_dataset_getter: MultiLevelDatasetGetter,
    input_parameters: Optional[Mapping[str, Any]] = None,
    is_factory: bool = False,
    ds_id: str = "",
    exception_type: type = ValueError,
) -> MultiLevelDataset:
    if is_factory:
        callable_ref, callable_obj = ComputedMultiLevelDataset.get_callable(
            script_path,
            callable_name,
            input_ml_dataset_ids,
            input_ml_dataset_getter,
            input_parameters=input_parameters,
            ds_id=ds_id,
            exception_type=exception_type,
        )
        input_datasets = [
            input_ml_dataset_getter(input_ds_id)
            for input_ds_id in input_ml_dataset_ids
        ]
        try:
            ml_dataset = callable_obj(*input_datasets, **(input_parameters or {}))
            if not isinstance(ml_dataset, MultiLevelDataset):
                raise TypeError(
                    f"{callable_ref!r} must return instance of"
                    f" xcube.core.mldataset.MultiLevelDataset,"
                    f" but was {type(ml_dataset)}"
                )
            ml_dataset.ds_id = ds_id
            return ml_dataset
        except BaseException as e:
            raise exception_type(
                f"Invalid in-memory dataset descriptor {ds_id!r}: {e}"
            ) from e
    else:
        return ComputedMultiLevelDataset(
            script_path,
            callable_name,
            input_ml_dataset_ids,
            input_ml_dataset_getter,
            input_parameters=input_parameters,
            ds_id=ds_id,
            exception_type=exception_type,
        )
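
# Illustrative usage sketch, assuming a hypothetical script "my_algo.py"
# containing the compute_dataset callable shown above, an input dataset
# already registered under the ID "base", and the dict-backed
# get_ml_dataset getter sketched earlier (all hypothetical names):


def _example_usage() -> xr.Dataset:
    ml_ds = open_ml_dataset_from_python_code(
        "my_algo.py",
        "compute_dataset",
        ["base"],
        get_ml_dataset,
        input_parameters={"factor": 2.0},
        ds_id="computed",
    )
    # Level datasets are computed lazily on first access.
    return ml_ds.get_dataset(0)


# With is_factory=True, the callable would instead receive the input
# MultiLevelDataset objects themselves and must itself return a
# MultiLevelDataset, as implemented in _open_ml_dataset_from_python_code
# above.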