Source code for xcube.core.level

import os
import warnings
from typing import List, Callable, Sequence, Optional, Tuple

import xarray as xr

from xcube.core.dsio import open_dataset

PyramidLevelCallback = Callable[[xr.Dataset, int, int], Optional[xr.Dataset]]
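
For illustration, a callback matching this alias might look as follows (a sketch; only the signature is prescribed by the alias above, the body is an assumption):

def print_progress(level_dataset, index, num_levels):
    # PyramidLevelCallback arguments: the level dataset, the zero-based
    # level index, and the total number of levels. A callback may return
    # a (possibly modified) dataset or None.
    print(f"level {index + 1} of {num_levels} processed")
    return level_dataset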


def compute_levels(dataset: xr.Dataset,
                   spatial_dims: Tuple[str, str] = None,
                   spatial_shape: Tuple[int, int] = None,
                   spatial_tile_shape: Tuple[int, int] = None,
                   var_names: Sequence[str] = None,
                   num_levels_max: int = None,
                   post_process_level: PyramidLevelCallback = None,
                   progress_monitor: PyramidLevelCallback = None) -> List[xr.Dataset]:
    """
    Transform the given *dataset* into the levels of a multi-level pyramid
    with spatial resolution decreasing by a factor of two
    in both spatial dimensions.

    It is assumed that the spatial dimensions of each variable are the
    inner-most, that is, the last two elements of a variable's shape
    provide the spatial dimension sizes.

    :param dataset: The input dataset to be turned into a multi-level pyramid.
    :param spatial_dims: If given, only variables are considered whose last
        two dimension elements match the given *spatial_dims*.
    :param spatial_shape: If given, only variables are considered whose last
        two shape elements match the given *spatial_shape*.
    :param spatial_tile_shape: If given, chunking will match the provided
        *spatial_tile_shape*.
    :param var_names: Variables to consider. If None, all variables with
        at least two dimensions are considered.
    :param num_levels_max: If given, the maximum number of pyramid levels.
    :param post_process_level: If given, the function will be called for
        each level and must return a dataset.
    :param progress_monitor: If given, the function will be called for
        each level.
    :return: A list of dataset instances representing the multi-level pyramid.
    """
    dropped_vars, spatial_shape, spatial_tile_shape = \
        _filter_level_source_dataset(dataset,
                                     var_names,
                                     spatial_dims,
                                     spatial_shape,
                                     spatial_tile_shape)
    if dropped_vars:
        dataset = dataset.drop(dropped_vars)

    if not tuple(dataset.data_vars):
        raise ValueError("cannot compute pyramid levels"
                         " because no suitable data variables were found")

    if spatial_tile_shape is None:
        spatial_tile_shape = min(spatial_shape[0], 512), min(spatial_shape[1], 512)

    # Count num_levels
    level_shapes = _compute_level_shapes(spatial_shape,
                                         spatial_tile_shape,
                                         num_levels_max=num_levels_max)
    num_levels = len(level_shapes)

    # Compute levels
    level_dataset = dataset
    level_datasets = []
    for level in range(num_levels):
        if level > 0:
            # Down-sample levels
            downsampled_vars = {}
            for var_name in level_dataset.data_vars:
                var = level_dataset.data_vars[var_name]
                # For the time being, use the simplest and likely fastest
                # down-sampling: plain striding by two in both spatial dimensions.
                downsampled_var = var[..., ::2, ::2]
                if downsampled_var.shape[-2:] != level_shapes[level]:
                    warnings.warn(f"unexpected spatial shape for down-sampled variable {var_name!r}:"
                                  f" expected {level_shapes[level]},"
                                  f" but found {downsampled_var.shape[-2:]}")
                downsampled_vars[var_name] = downsampled_var
            level_dataset = xr.Dataset(downsampled_vars, attrs=level_dataset.attrs)

        level_dataset = _tile_level_dataset(level_dataset, spatial_tile_shape)

        # Apply post-processor, if any
        if post_process_level is not None:
            level_dataset = post_process_level(level_dataset, len(level_datasets), num_levels)

        # Inform progress monitor, if any
        if progress_monitor is not None:
            progress_monitor(level_dataset, len(level_datasets), num_levels)

        # Collect level dataset
        level_datasets.append(level_dataset)

    return level_datasets
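
A minimal usage sketch for ``compute_levels()``, assuming an in-memory dataset whose variable has inner-most spatial dimensions (the variable and dimension names below are illustrative):

import numpy as np
import xarray as xr

from xcube.core.level import compute_levels

# Illustrative dataset: variable "sst" with shape (time, lat, lon).
dataset = xr.Dataset(
    {"sst": (("time", "lat", "lon"),
             np.zeros((1, 1024, 1024), dtype="float32"))}
)

# Three levels result here: 1024x1024, 512x512, and 256x256,
# because halving stops once a dimension would drop below the tile size.
levels = compute_levels(dataset, spatial_tile_shape=(256, 256))
for index, level in enumerate(levels):
    print(index, level.sst.shape)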
def write_levels(output_path: str,
                 dataset: xr.Dataset = None,
                 input_path: str = None,
                 link_input: bool = False,
                 progress_monitor: PyramidLevelCallback = None,
                 **kwargs) -> List[xr.Dataset]:
    """
    Transform the dataset given by a *dataset* instance or an *input_path*
    string into the levels of a multi-level pyramid with spatial resolution
    decreasing by a factor of two in both spatial dimensions
    and write them to *output_path*.

    One of *dataset* and *input_path* must be given.

    :param output_path: Output path.
    :param dataset: Dataset to be converted and written as levels.
    :param input_path: Input path to a dataset to be transformed and
        written as levels.
    :param link_input: Just link the dataset at level zero instead of
        writing it.
    :param progress_monitor: An optional progress monitor.
    :param kwargs: Keyword arguments accepted by the ``compute_levels()``
        function.
    :return: A list of dataset instances representing the multi-level pyramid.
    """
    if dataset is None and input_path is None:
        raise ValueError("at least one of dataset or input_path must be given")

    if link_input and input_path is None:
        raise ValueError("input_path must be provided to link input")

    _post_process_level = kwargs.pop("post_process_level", None)

    def post_process_level(level_dataset, index, num_levels):
        if _post_process_level is not None:
            level_dataset = _post_process_level(level_dataset, index, num_levels)
        if index == 0 and link_input:
            with open(os.path.join(output_path, f"{index}.link"), "w") as fp:
                fp.write(input_path)
        else:
            path = os.path.join(output_path, f"{index}.zarr")
            level_dataset.to_zarr(path)
            level_dataset.close()
            level_dataset = xr.open_zarr(path)
        if progress_monitor is not None:
            progress_monitor(level_dataset, index, num_levels)
        return level_dataset

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    if dataset is None:
        dataset = open_dataset(input_path)

    return compute_levels(dataset, post_process_level=post_process_level, **kwargs)
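
A sketch of writing a pyramid to disk; ``input.zarr`` and ``output.levels`` are placeholder paths. With ``link_input=True``, level zero becomes a ``0.link`` file containing the input path, while coarser levels are written as ``1.zarr``, ``2.zarr``, and so on:

from xcube.core.level import write_levels

levels = write_levels("output.levels",
                      input_path="input.zarr",
                      link_input=True,
                      spatial_tile_shape=(512, 512))
print(f"wrote {len(levels)} levels to output.levels")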
def read_levels(dir_path: str,
                progress_monitor: PyramidLevelCallback = None) -> List[xr.Dataset]:
    """
    Read the levels of a multi-level pyramid with spatial resolution
    decreasing by a factor of two in both spatial dimensions.

    :param dir_path: The directory path.
    :param progress_monitor: An optional progress monitor.
    :return: A list of dataset instances representing the multi-level pyramid.
    """
    file_paths = os.listdir(dir_path)
    level_paths = {}
    num_levels = -1
    for filename in file_paths:
        file_path = os.path.join(dir_path, filename)
        basename, ext = os.path.splitext(filename)
        if basename.isdigit():
            index = int(basename)
            num_levels = max(num_levels, index + 1)
            if os.path.isfile(file_path) and ext == ".link":
                level_paths[index] = (ext, file_path)
            elif os.path.isdir(file_path) and ext == ".zarr":
                level_paths[index] = (ext, file_path)

    if num_levels != len(level_paths):
        raise ValueError(f"Inconsistent pyramid directory:"
                         f" expected {num_levels} but found {len(level_paths)} entries:"
                         f" {dir_path}")

    levels = []
    for index in range(num_levels):
        ext, file_path = level_paths[index]
        if ext == ".link":
            with open(file_path, "r") as fp:
                link_file_path = fp.read()
            dataset = xr.open_zarr(link_file_path)
        else:
            dataset = xr.open_zarr(file_path)
        if progress_monitor is not None:
            progress_monitor(dataset, index, num_levels)
        levels.append(dataset)

    return levels
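
Reading the pyramid back is symmetric; the directory is expected to hold ``<index>.zarr`` directories and/or ``<index>.link`` files as produced by ``write_levels()`` (the path below is the placeholder from the previous sketch):

from xcube.core.level import read_levels

levels = read_levels("output.levels")
print(f"read {len(levels)} levels;"
      f" full-resolution sizes: {dict(levels[0].sizes)}")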
def _tile_chunk(size, tile_size):
    last_tile_size = size % tile_size
    if last_tile_size != 0:
        return (tile_size,) * (size // tile_size) + (last_tile_size,)
    return tile_size


def _tile_level_dataset(level_dataset, spatial_tile_shape):
    tile_height, tile_width = spatial_tile_shape

    # Chunk variables in level dataset according to spatial_tile_shape
    chunked_vars = {}

    # Chunk data variables according to tile size
    for var_name in level_dataset.data_vars:
        var = level_dataset.data_vars[var_name]
        height, width = var.shape[-2:]
        zarr_chunks = (1,) * (var.ndim - 2) + (tile_height,) + (tile_width,)
        dask_chunks = (1,) * (var.ndim - 2) \
                      + (_tile_chunk(height, tile_height),) \
                      + (_tile_chunk(width, tile_width),)
        dask_chunks = {var.dims[i]: dask_chunks[i] for i in range(var.ndim)}
        chunked_var = var.chunk(chunks=dask_chunks)
        chunked_var.encoding.update(chunks=zarr_chunks)
        chunked_vars[var_name] = chunked_var

    # Make coordinate variable chunks equal to their shape
    # TODO (forman): find out if chunking the spatial coordinates
    #                according to tile size improves performance
    for var_name in level_dataset.coords:
        var = level_dataset.coords[var_name]
        zarr_chunks = var.shape
        dask_chunks = {var.dims[i]: var.shape[i] for i in range(var.ndim)}
        chunked_var = var.chunk(chunks=dask_chunks)
        chunked_var.encoding.update(chunks=zarr_chunks)
        chunked_vars[var_name] = chunked_var

    return level_dataset.assign(variables=chunked_vars)


def _compute_level_shapes(spatial_shape,
                          spatial_tile_shape,
                          num_levels_max=None) -> List[Tuple[int, int]]:
    height, width = spatial_shape
    tile_height, tile_width = spatial_tile_shape
    num_levels_max = num_levels_max or -1
    level_shapes = [(height, width)]
    while True:
        width = (width + 1) // 2
        height = (height + 1) // 2
        if width < tile_width or height < tile_height \
                or num_levels_max == len(level_shapes):
            break
        level_shapes.append((height, width))
    return level_shapes


def _filter_level_source_dataset(dataset,
                                 var_names=None,
                                 spatial_dims=None,
                                 spatial_shape=None,
                                 spatial_tile_shape=None):
    if var_names:
        var_names = set(var_names)
        dropped_vars = list(set(dataset.data_vars).difference(var_names))
    else:
        var_names = set(dataset.data_vars)
        dropped_vars = []

    # Collect data variables to be dropped,
    # derive missing information from spatial data variables
    for var_name in var_names:
        if var_name not in dataset.data_vars:
            raise ValueError(f"variable {var_name} not found")
        var = dataset[var_name]
        if var.ndim < 2:
            # Must have at least the two spatial dimensions
            dropped_vars.append(var_name)
            continue
        if spatial_dims is None:
            spatial_dims = var.dims[-2:]
        elif spatial_dims != var.dims[-2:]:
            # Spatial dimensions don't fit
            dropped_vars.append(var_name)
            continue
        if spatial_shape is None:
            spatial_shape = var.shape[-2:]
        elif spatial_shape != var.shape[-2:]:
            # Spatial dimension sizes don't fit
            dropped_vars.append(var_name)
            continue
        if spatial_tile_shape is None and var.chunks is not None:
            def chunk_to_int(chunk):
                return chunk if isinstance(chunk, int) else max(chunk)

            spatial_tile_shape = tuple(map(chunk_to_int, var.chunks[-2:]))

    return dropped_vars, spatial_shape, spatial_tile_shape
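
To make the level arithmetic above concrete: each level halves both spatial dimensions (rounding up), and the pyramid stops once a dimension would fall below the tile size. A quick check against the private helper (for illustration only; leading-underscore functions are not part of the public API):

from xcube.core.level import _compute_level_shapes

shapes = _compute_level_shapes((2048, 4096), (512, 512))
print(shapes)
# [(2048, 4096), (1024, 2048), (512, 1024)]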