# Copyright (c) 2018-2024 by xcube team and contributors
# Permissions are hereby granted under the terms of the MIT License:
# https://opensource.org/licenses/MIT.
import itertools
import os
import re
import uuid
import warnings
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import dask.array as da
import dask.array.core as dac
import distributed
import numpy as np
# Type aliases used in the signatures of the chunking helpers in this module.
# A tuple of ints, e.g. an array shape or the chunk sizes along one dimension.
IntTuple = Tuple[int, ...]
# A tuple of slice objects, one per array dimension.
SliceTuple = Tuple[slice, ...]
IntIterable = Iterable[int]
IntTupleIterable = Iterable[IntTuple]
SliceTupleIterable = Iterable[SliceTuple]
# Name of the environment variable that may hold cloud resource tags for
# new clusters, in the format "tag_1=value_1:tag_2=value_2:...".
_CLUSTER_TAGS_ENV_VAR_NAME = "XCUBE_DASK_CLUSTER_TAGS"
# Name of the environment variable that may hold the cluster provider
# account name used when no account is passed explicitly.
_CLUSTER_ACCOUNT_ENV_VAR_NAME = "XCUBE_DASK_CLUSTER_ACCOUNT"
def compute_array_from_func(
    func: Callable[..., np.ndarray],
    shape: IntTuple,
    chunks: IntTuple,
    dtype: Any,
    name: Optional[str] = None,
    ctx_arg_names: Optional[Sequence[str]] = None,
    args: Optional[Sequence[Any]] = None,
    kwargs: Optional[Mapping[str, Any]] = None,
) -> da.Array:
    """Compute a dask array using the provided user function
    *func*, *shape*, and chunking *chunks*.

    The user function is expected to output the array's data
    blocks using arguments specified by *ctx_arg_names*, *args*,
    and *kwargs* and is expected to return a numpy array.

    You can request array and current block context information
    by specifying the optional *ctx_arg_names* keyword argument
    that is a sequence of names of special arguments passed to
    *func*. The following are available:

    * ``shape``: The array's shape. A tuple of ints.
    * ``chunks``: The array's chunks. A tuple of tuple of ints.
    * ``dtype``: The array's numpy data type.
    * ``name``: The array's name. A string or ``None``.
    * ``block_id``: The block's unique ID. An integer number
      ranging from zero to number of blocks minus one.
    * ``block_index``: The block's index as a tuple of ints.
    * ``block_shape``: The block's shape as a tuple of ints.
    * ``block_slices``: The block's slices as a tuple of
      ``(start, stop)`` int pairs, one pair per dimension.

    Args:
        func: User function that is called for each block of the
            array using arguments specified by *ctx_arg_names*,
            *args*, and *kwargs*. It must return a numpy array of
            shape "block_shape" and type *dtype*.
        shape: The array's shape. A tuple of sizes for each
            array dimension.
        chunks: The array's chunking. A tuple of chunk sizes for
            each array dimension. Must be of same length as *shape*.
        dtype: The array's numpy data type.
        name: The array's name.
        ctx_arg_names: Sequence of names of context arguments that
            are passed before *args* and *kwargs* to the user function.
        args: Arguments passed to the user function.
        kwargs: Keyword-arguments passed to the user function.

    Returns:
        A chunked dask array.

    Raises:
        ValueError: if *shape* and *chunks* differ in length.
        KeyError: if *ctx_arg_names* contains an unknown name.
    """
    shape = tuple(shape)
    chunks = tuple(chunks)
    if len(shape) != len(chunks):
        # Without this check, zip() in the chunk helpers would silently
        # truncate to the shorter of the two and produce an array with
        # the wrong number of dimensions.
        raise ValueError(
            f"shape and chunks must have the same length,"
            f" got {len(shape)} and {len(chunks)}"
        )

    ctx_arg_names = ctx_arg_names or []
    args = args or []
    kwargs = kwargs or {}

    chunk_sizes = tuple(get_chunk_sizes(shape, chunks))
    chunk_counts = tuple(get_chunk_counts(shape, chunks))
    block_indexes, block_shapes, block_slices = get_block_iterators(chunk_sizes)

    # Context values that are the same for every block; the per-block
    # entries are added/overwritten inside the loop.
    ctx_values = dict(
        shape=shape,
        chunks=chunk_sizes,
        dtype=dtype,
        name=name,
    )

    blocks = _NestedList(shape=chunk_counts)
    for block_id, (block_index, block_shape, slices) in enumerate(
        zip(block_indexes, block_shapes, block_slices)
    ):
        ctx_values.update(
            block_id=block_id,
            block_index=tuple(block_index),
            block_shape=tuple(block_shape),
            block_slices=tuple((s.start, s.stop) for s in slices),
        )
        # Snapshot the requested context values for this block, since
        # ctx_values is mutated again in the next iteration.
        ctx_args = [ctx_values[ctx_arg_name] for ctx_arg_name in ctx_arg_names]

        # We use our own name here, because dac.from_func() tokenizes args
        # which for some reason takes forever
        block = dac.from_func(
            func,
            shape=block_shape,
            dtype=dtype,
            name=f"rectify_{name}-{uuid.uuid4()}",
            args=(*ctx_args, *args),
            kwargs=kwargs,
        )
        blocks[block_index] = block

    return da.block(blocks.data)
def get_block_iterators(
    chunk_sizes: IntTupleIterable,
) -> Tuple[IntTupleIterable, IntTupleIterable, SliceTupleIterable]:
    """Create iterators over all blocks described by *chunk_sizes*.

    Args:
        chunk_sizes: For each array dimension, the tuple of chunk sizes
            along that dimension.

    Returns:
        A 3-tuple of iterators yielding, for each block and in the same
        order, the block's index, the block's shape, and the block's
        slices. Each item is a tuple with one entry per dimension.
    """
    sizes_per_dim = tuple(chunk_sizes)
    # Every iterator is the Cartesian product of per-dimension sequences,
    # so all three enumerate the blocks in the same order.
    return (
        itertools.product(*get_chunk_ranges(sizes_per_dim)),
        itertools.product(*sizes_per_dim),
        itertools.product(*get_chunk_slice_tuples(sizes_per_dim)),
    )
def get_chunk_sizes(shape: IntTuple, chunks: IntTuple) -> IntTupleIterable:
    """Yield, for each dimension, the tuple of actual chunk sizes.

    Each dimension of size ``s`` with nominal chunk size ``c`` is split
    into as many full chunks of size ``c`` as fit, followed by one
    smaller trailing chunk if ``s`` is not a multiple of ``c``.

    Args:
        shape: Size of each array dimension.
        chunks: Nominal chunk size of each array dimension.

    Returns:
        An iterable of int tuples, one per (shape, chunks) pair.
    """
    for dim_size, chunk_size in zip(shape, chunks):
        full_count, remainder = divmod(dim_size, chunk_size)
        sizes = (chunk_size,) * full_count
        if remainder > 0:
            # Append the smaller trailing chunk.
            sizes += (remainder,)
        yield sizes
def get_chunk_counts(shape: IntTuple, chunks: IntTuple) -> Iterable[int]:
    """Yield the number of chunks along each dimension.

    Args:
        shape: Size of each array dimension.
        chunks: Nominal chunk size of each array dimension.

    Returns:
        An iterable of ints, one per (shape, chunks) pair.
    """
    # Adding (chunk_size - 1) before the floor division rounds the
    # quotient up, so a partial trailing chunk is counted too.
    yield from (
        (dim_size + chunk_size - 1) // chunk_size
        for dim_size, chunk_size in zip(shape, chunks)
    )
def get_chunk_ranges(chunk_size_tuples: IntTupleIterable) -> Iterable[range]:
    """Return, for each dimension, the range of its chunk indexes.

    Args:
        chunk_size_tuples: For each dimension, the tuple of chunk sizes.

    Returns:
        A lazy iterable with one ``range(number_of_chunks)`` per dimension.
    """
    chunk_counts = map(len, chunk_size_tuples)
    return map(range, chunk_counts)
def get_chunk_slice_tuples(chunk_size_tuples: IntTupleIterable) -> SliceTupleIterable:
    """Return, for each dimension, the tuple of slices of its chunks.

    Args:
        chunk_size_tuples: For each dimension, the tuple of chunk sizes.

    Returns:
        A lazy iterable with one tuple of slices per dimension, as
        produced by ``get_chunk_slices()`` for that dimension.
    """
    slices_per_dim = map(get_chunk_slices, chunk_size_tuples)
    return map(tuple, slices_per_dim)
def get_chunk_slices(chunk_sizes: Iterable[int]) -> Iterable[slice]:
    """Yield consecutive slices for the given chunk sizes.

    The first slice starts at zero; every following slice starts where
    the previous one stopped.

    Args:
        chunk_sizes: The chunk sizes along one dimension. Any iterable
            of ints is accepted, not only indexable sequences.

    Returns:
        An iterable of ``slice(start, stop)`` objects, one per chunk size.
    """
    start = 0
    for chunk_size in chunk_sizes:
        stop = start + chunk_size
        yield slice(start, stop)
        start = stop
# Dask cluster creation.
def new_cluster(
    provider: str = "coiled",
    name: Optional[str] = None,
    software: Optional[str] = None,
    n_workers: int = 4,
    resource_tags: Optional[Dict[str, str]] = None,
    account: Optional[str] = None,
    region: str = "eu-central-1",
    **kwargs,
) -> distributed.deploy.Cluster:
    """Create a new Dask cluster.

    Cloud resource tags can be specified in an environment variable
    XCUBE_DASK_CLUSTER_TAGS in the format
    ``tag_1=value_1:tag_2=value_2:...:tag_n=value_n``. In case of
    conflicts, tags specified in ``resource_tags`` will override tags
    specified by the environment variable.

    The cluster provider account name can be specified in an environment
    variable ``XCUBE_DASK_CLUSTER_ACCOUNT``. If the ``account`` argument is
    given to ``new_cluster``, it will override the value from the environment
    variable. If neither is given, a warning is issued and a hard-coded
    default account name is used.

    Args:
        provider: identifier of the provider to use. Currently, only
            'coiled' is supported.
        name: name to use as an identifier for the cluster
        software: identifier for the software environment to be used.
        n_workers: number of workers in the cluster
        resource_tags: tags to apply to the cloud resources forming the
            cluster
        account: cluster provider account name
        region: region where the workers of the cluster will be
            deployed. Defaults to eu-central-1.
        **kwargs: further named arguments will be passed on to the
            cluster creation function. They override the parameters
            computed by this function.

    Returns:
        The newly created cluster.

    Raises:
        ImportError: if *provider* is 'coiled' but the package 'coiled'
            is not installed.
        NotImplementedError: if *provider* is not supported.
    """
    if resource_tags is None:
        resource_tags = {}

    # Resolve the account: explicit argument > environment variable >
    # hard-coded default. The warning is only relevant if the environment
    # variable is actually needed, i.e. no account was passed explicitly.
    if account is not None:
        cluster_account = account
    elif _CLUSTER_ACCOUNT_ENV_VAR_NAME in os.environ:
        cluster_account = os.environ[_CLUSTER_ACCOUNT_ENV_VAR_NAME]
    else:
        warnings.warn(
            f"Environment variable {_CLUSTER_ACCOUNT_ENV_VAR_NAME}"
            f" not set; cluster account name may be incorrect."
        )
        cluster_account = "bc"

    if provider == "coiled":
        try:
            import coiled
        except ImportError as e:
            raise ImportError(
                "provider 'coiled' requires package 'coiled' to be installed"
            ) from e

        if software is None and "JUPYTER_IMAGE" in os.environ:
            # If the JUPYTER_IMAGE environment variable is set, we're
            # presumably in a Z2JH deployment and can base a
            # Coiled environment on the same image.

            # First we construct an identifier from the user image specifier:
            # the last path component with ':' and '.' replaced by '-'.
            current_image = os.environ["JUPYTER_IMAGE"]
            image_name_match = re.search(r"/([^/]+)$", current_image)
            # An image specifier without any '/' has no match; use it as is
            # rather than failing on a None match object.
            image_name = (
                image_name_match.group(1)
                if image_name_match is not None
                else current_image
            )
            software = re.sub("[:.]", "-", image_name)

            # If the referenced software environment doesn't exist yet as a
            # Coiled environment, create it from the currently used image.
            # Look it up and create it in the same account the cluster
            # will be created in.
            available_environments = coiled.list_software_environments(
                account=cluster_account
            ).keys()
            if software not in available_environments:
                coiled.create_software_environment(
                    name=software,
                    account=cluster_account,
                    container=current_image,
                )

        # If software is (still) None, Coiled will try to mirror the current
        # environment automagically.

        coiled_params = dict(
            n_workers=n_workers,
            environ=None,
            tags=_collate_cluster_resource_tags(resource_tags),
            account=cluster_account,
            name=name,
            software=software,
            use_best_zone=True,
            compute_purchase_option="spot_with_fallback",
            shutdown_on_close=True,
            region=region,
        )
        # Caller-supplied keyword arguments take precedence.
        coiled_params.update(kwargs)
        return coiled.Cluster(**coiled_params)

    raise NotImplementedError(f"Unknown provider {provider!r}")
def _collate_cluster_resource_tags(extra_tags: Dict[str, str]) -> Dict[str, str]:
    """Merge cluster resource tags from all sources.

    Tags are taken, in order of increasing precedence, from built-in
    fallback values, from the environment variable named by
    ``_CLUSTER_TAGS_ENV_VAR_NAME`` (format ``key_1=value_1:key_2=value_2``),
    and from *extra_tags*. A warning is issued if the environment
    variable is not set, or if one of its entries has no ``=``.

    Args:
        extra_tags: Tags that override all other tags.

    Returns:
        A new dictionary with the merged tags.
    """
    user = (
        os.environ.get("JUPYTERHUB_USER")  # JupyterHub
        or os.environ.get("USER")  # Unixes
        or os.environ.get("USERNAME")  # Windows
    )
    if not user:
        try:
            user = os.getlogin()
        except OSError:
            # os.getlogin() fails if the process has no controlling
            # terminal; fall back to an empty user name instead.
            user = ""

    fallback_tags = {
        "cost-center": "unknown",
        "environment": "dev",
        "creator": "auto",
        "purpose": "xcube dask cluster",
        "user": user or "",
    }

    env_var_tags: Dict[str, str] = {}
    if _CLUSTER_TAGS_ENV_VAR_NAME in os.environ:
        for kvp in os.environ[_CLUSTER_TAGS_ENV_VAR_NAME].split(":"):
            key, separator, value = kvp.partition("=")
            if not separator:
                # Empty entries (e.g. from a trailing ':') are skipped
                # silently, other entries without '=' with a warning.
                if kvp:
                    warnings.warn(
                        f"Ignoring malformed entry {kvp!r} in environment"
                        f" variable {_CLUSTER_TAGS_ENV_VAR_NAME};"
                        f" expected format is 'key=value'."
                    )
                continue
            env_var_tags[key] = value
    else:
        warnings.warn(
            f"Environment variable {_CLUSTER_TAGS_ENV_VAR_NAME}"
            f" not set; cluster resource tags may be missing."
        )

    return fallback_tags | env_var_tags | extra_tags
class _NestedList:
    """Nested Python lists of a given shape.

    Instances are used to collect the blocks that are passed
    as input to dask.block(). Items can be read and written using
    either a plain list index or a tuple with one index per nesting level.
    """

    def __init__(self, shape: Sequence[int], fill_value: Any = None):
        """Create nested lists of the given *shape*.

        Args:
            shape: Length of the lists at each nesting level.
            fill_value: Initial value of all innermost items.
        """
        self._shape = tuple(shape)
        self._data = self._new_data(shape, len(shape), fill_value, 0)

    @classmethod
    def _new_data(
        cls, shape: Sequence[int], ndim: int, fill_value: Any, dim: int
    ) -> Union[List[List], List[Any]]:
        """Recursively build the nested lists for level *dim* and below."""
        length = shape[dim]
        if dim < ndim - 1:
            # Not yet the innermost level: every item is a new nested list.
            return [
                cls._new_data(shape, ndim, fill_value, dim + 1)
                for _ in range(length)
            ]
        return [fill_value for _ in range(length)]

    @property
    def shape(self) -> Tuple[int, ...]:
        """The shape as a tuple of list lengths per nesting level."""
        return self._shape

    @property
    def data(self) -> Union[List[List], List[Any]]:
        """The outermost list."""
        return self._data

    def __len__(self) -> int:
        return len(self._data)

    def _locate(self, index: Union[int, slice, tuple]) -> Tuple[List[Any], Any]:
        """Resolve *index* into the list to access and the key to use on it."""
        if not isinstance(index, tuple):
            return self._data, index
        container = self._data
        # All but the last tuple entry select the nested list to descend into.
        for key in index[:-1]:
            container = container[key]
        return container, index[-1]

    def __setitem__(self, index: Union[int, slice, tuple], value: Any):
        container, key = self._locate(index)
        container[key] = value

    def __getitem__(self, index: Union[int, slice, tuple]) -> Any:
        container, key = self._locate(index)
        return container[key]