Source code for eradiate.ckd

from __future__ import annotations

import functools
import typing as t
from collections import abc

import attrs
import pint
import pinttr
import xarray as xr

from . import data
from .attrs import documented, parse_docs
from .quad import Quad
from .units import unit_context_config as ucc
from .units import unit_registry as ureg

# ------------------------------------------------------------------------------
#                              CKD bin data classes
# ------------------------------------------------------------------------------


[docs]@parse_docs @attrs.frozen class Bin: """ A data class representing a spectral bin in CKD modes. """ id: str = documented( attrs.field(converter=str), doc="Bin identifier.", type="str", ) wmin: pint.Quantity = documented( pinttr.field( units=ucc.deferred("wavelength"), on_setattr=None, # frozen instance: on_setattr must be disabled ), doc='Bin lower spectral bound.\n\nUnit-enabled field (default: ucc["wavelength"]).', type="quantity", init_type="quantity or float", ) wmax: pint.Quantity = documented( pinttr.field( units=ucc.deferred("wavelength"), on_setattr=None, # frozen instance: on_setattr must be disabled ), doc='Bin upper spectral bound.\n\nUnit-enabled field (default: ucc["wavelength"]).', type="quantity", init_type="quantity or float", ) @wmin.validator @wmax.validator def _wbounds_validator(self, attribute, value): if not self.wmin < self.wmax: raise ValueError( f"while validating {attribute.name}: wmin must be lower than wmax" ) quad: Quad = documented( attrs.field( repr=lambda x: x.str_summary, validator=attrs.validators.instance_of(Quad) ), doc="Quadrature rule attached to the CKD bin.", type=":class:`.Quad`", ) @property def width(self) -> pint.Quantity: """quantity : Bin spectral width.""" return self.wmax - self.wmin @property def wcenter(self) -> pint.Quantity: """quantity : Bin central wavelength.""" return 0.5 * (self.wmin + self.wmax) @property def bindexes(self) -> t.List[Bindex]: """list of :class:`.Bindex` : List of associated bindexes.""" return [Bindex(bin=self, index=i) for i, _ in enumerate(self.quad.nodes)]
[docs] @classmethod def convert(cls, value: t.Any) -> t.Any: """ If ``value`` is a tuple or a dictionary, try to construct a :class:`.Bin` instance from it. Otherwise, return ``value`` unchanged. """ if isinstance(value, tuple): return cls(*value) if isinstance(value, dict): return cls(**value) return value
[docs]@parse_docs @attrs.frozen class Bindex: """ A data class representing a CKD (bin, index) pair. """ bin: Bin = documented( attrs.field(converter=Bin.convert), doc="CKD bin.", type=":class:`.Bin`", ) index: int = documented( attrs.field(), doc="Quadrature point index.", type="int", )
[docs] @classmethod def convert(cls, value) -> t.Any: """ If ``value`` is a tuple or a dictionary, try to construct a :class:`.Bindex` instance from it. Otherwise, return ``value`` unchanged. """ if isinstance(value, tuple): return cls(*value) if isinstance(value, dict): return cls(**value) return value
# ------------------------------------------------------------------------------ # Bin filtering functions # ------------------------------------------------------------------------------ # The following functions implement a set of bin selection routines used to # filter and select bins from a bin set.
[docs]def bin_filter_ids(ids: t.Sequence[str]) -> t.Callable[[Bin], bool]: """ Select bins based on identifiers. Parameters ---------- ids : list of str A sequence of bin identifiers which the defined filter will let through. Returns ------- callable A callable which returns ``True`` *iff* a bin has a valid identifier. """ def filter(value): return value.id in ids return filter
[docs]def bin_filter_interval( wmin: pint.Quantity, wmax: pint.Quantity, endpoints: bool = True ) -> t.Callable[[Bin], bool]: """ Select bins in a wavelength interval. Parameters ---------- wmin : :class:`pint.Quantity` Lower bound of the spectral interval defining the filter. wmax : :class:`pint.Quantity` Upper bound of the spectral interval defining the filter. Must be equal to or greater than ``wmin``. endpoints : bool If ``True``, bins must have at least one of their bounds within the interval to be selected. If ``False``, bins must have both bounds in the interval to be selected. Returns ------- callable A callable which returns ``True`` *iff* a bin is in the specified interval. Raises ------ ValueError If ``wmin > wmax``. """ if wmin > wmax: raise ValueError("wmin must be lower or equal to wmax") if wmin == wmax: def filter(value): return value.wmin < wmin < value.wmax else: if endpoints: def filter(value): return ( (wmin <= value.wmin and value.wmax <= wmax) or (value.wmin < wmin < value.wmax) or (value.wmin < wmax < value.wmax) ) else: def filter(value): return wmin <= value.wmin and value.wmax <= wmax return filter
[docs]def bin_filter(type: str, filter_kwargs: t.Dict[str, t.Any]): """ Create a bin filter function dynamically. Parameters ---------- type : str Filter type. filter_kwargs : dict Keyword arguments passed to the filter generator. Returns ------- callable Generated bin filter function. Notes ----- Valid filter types are: * ``all`` (``lambda x: True``); * ``ids`` (:func:`.bin_filter_ids`); * ``interval`` (:func:`.bin_filter_interval`). See the corresponding API entry for expected arguments. """ if type == "interval": return bin_filter_interval(**filter_kwargs) if type == "ids": return bin_filter_ids(**filter_kwargs) if type == "all": return lambda x: True else: raise ValueError(f"unknown bin filter type {type}")
# ------------------------------------------------------------------------------ # Bin set definitions # ------------------------------------------------------------------------------
[docs]@parse_docs @attrs.frozen class BinSet: """ A data class representing a quadrature definition used in CKD mode. """ id: str = documented( attrs.field(converter=str), doc="Bin set identifier.", type="str", ) quad: Quad = documented( attrs.field(repr=lambda x: x.str_summary), doc="Quadrature rule associated with this spectral bin set.", type=":class:`.Quad`", ) bins: t.Tuple[Bin, ...] = documented( attrs.field( converter=lambda x: tuple( sorted( x, key=lambda y: (y.wmin, y.wmax, y.id) ) # Specify field priority for sorting and exclude quad (which is irrelevant) ), validator=attrs.validators.deep_iterable( member_validator=attrs.validators.instance_of(Bin) ), repr=lambda x: f"tuple<{len(x)}>(" f"Bin({repr(x[0].id)}, ... ), " f"Bin({repr(x[1].id)}, ... ), ... , " f"Bin({repr(x[-2].id)}, ... ), " f"Bin({repr(x[-1].id)}, ... ))" if len(x) >= 5 else f"{tuple(y.id for y in x)}", ), doc="Sequence of bins (automatically ordered upon setting).", type="tuple of :class:`.Bin`", init_type="sequence of :class:`.Bin`", ) @bins.validator def _bins_validator(self, attribute, value): if any(bin.quad is not self.quad for bin in value): raise ValueError( f"while validating {attribute.name}: all defined bins must " "share the same quadrature as their parent bin set" ) @property def bin_ids(self) -> t.List[str]: """ list of str : Return the identifiers of defined spectral bins. """ return [bin.id for bin in self.bins] @property def bin_wmins(self) -> pint.Quantity: """ quantity : Return the lower bounds of defined spectral bins. """ units = ucc.get("wavelength") return ureg.Quantity([bin.wmin.m_as(units) for bin in self.bins], units) @property def bin_wmaxs(self) -> pint.Quantity: """ quantity : Return the upper bounds of defined spectral bins. """ units = ucc.get("wavelength") return ureg.Quantity([bin.wmax.m_as(units) for bin in self.bins], units)
[docs] def filter_bins(self, *filters: t.Callable[[Bin], bool]) -> t.Tuple[Bin, ...]: """ Filter bins based on callables. Parameters ---------- filters : callable One or several callables with signature ``filter(x: Bin) -> bool``. Returns ------- dict[str, :class:`.Bin`] Only bins for which ``filter(bin)`` is ``True`` are returned, ordered by their lower bound. """ selected_bins = [] for bin in self.bins: for filter in filters: if filter(bin): selected_bins.append(bin) break return tuple(sorted(selected_bins, key=lambda x: (x.wmin, x.wmax, x.id)))
[docs] def select_bins( self, *filter_specs: t.Union[ str, t.Callable[[Bin], bool], t.Tuple[str, t.Dict[str, t.Any]], t.Dict, ], ) -> t.Tuple[Bin, ...]: """ Select a subset of CKD bins. This method is a high-level wrapper for :meth:`.filter_bins`. Parameters ---------- filter_specs : sequence of {str, callable, sequence, dict} One or several bin filter specifications. The following are supported: * a string will select a bin with matching ID (internally, this results in a call to :func:`.bin_filter_ids`); * a callable will be directly used; * a (str, dict) will be interpreted as the ``type`` and ``filter_kwargs`` arguments of :func:`.bin_filter`; * a dict with keys ``type`` and ``filter_kwargs`` will be directly forwarded as keyword arguments to :func:`.bin_filter`. Returns ------- dict[str, :class:`.Bin`] Selected bins. """ filters = [] for filter_spec in filter_specs: if isinstance(filter_spec, str): filters.append(bin_filter_ids(ids=(filter_spec,))) elif isinstance(filter_spec, t.Sequence): filters.append(bin_filter(*filter_spec)) elif isinstance(filter_spec, abc.Mapping): filters.append(bin_filter(**filter_spec)) elif callable(filter_spec): filters.append(filter_spec) else: raise ValueError(f"unhandled CKD bin selector {filter_spec}") return self.filter_bins(*filters)
[docs] @classmethod def convert(cls, value) -> t.Any: """ If ``value`` is a string, query quadrature definition database with it. Otherwise, return ``value`` unchanged. """ if isinstance(value, str): return cls.from_db(value) return value
[docs] @staticmethod def from_dataset(id: str, ds: xr.Dataset) -> BinSet: """ Convert a dataset-based bin set definition to a :class:`BinSet` instance. Parameters ---------- id : str Data set identifier. ds : :class:`~xarray.Dataset` Dataset from which bin set definition information is to be extracted. Returns ------- :class:`.BinSet` Bin set definition. """ # Collect quadrature data quad_type = ds.attrs["quadrature_type"] quad_n = ds.attrs["quadrature_n"] quad = Quad.new(quad_type, quad_n) # Collect bin set data bin_ids = ds.bin.values bin_wmin = ureg.Quantity(ds.wmin.values, ds.wmin.attrs["units"]) bin_wmax = ureg.Quantity(ds.wmax.values, ds.wmax.attrs["units"]) # Assemble the data return BinSet( id=id, quad=quad, bins=tuple( Bin(id=id, wmin=wmin, wmax=wmax, quad=quad) for id, wmin, wmax in zip(bin_ids, bin_wmin, bin_wmax) ), )
[docs] @staticmethod @functools.lru_cache(maxsize=128) def from_db(id: str) -> BinSet: """ Get a bin set definition from Eradiate's database. .. note:: This static function is cached using :func:`functools.lru_cache` for optimal performance. Parameters ---------- id : str Data set identifier. The :func:`eradiate.data.open` function will be used to load the requested data set. Returns ------- :class:`.BinSet` Bin set definition. """ with data.open_dataset(f"ckd/bin_sets/{id}.nc") as ds: result = BinSet.from_dataset(id=id, ds=ds) return result
[docs] @staticmethod def from_node_dataset(ds: xr.Dataset): """ Get bin set from node data. Parameters ---------- ds : :class:`~xarray.Dataset` Node data to get the bin set for. Data set attributes must have a ``bin_set`` field referencing a registered bin set definition in the Eradiate database. Returns ------- :class:`.BinSet` Bin set definition associated with the passed CKD node data. """ return BinSet.from_db(ds.attrs["bin_set"])