Source code for eradiate.pipelines._gather

from __future__ import annotations

import logging

import attrs
import numpy as np
import xarray as xr
from pinttr.util import always_iterable

import eradiate

from ._core import PipelineStep
from ..attrs import documented, parse_docs
from ..cfconventions import ATTRIBUTES
from ..exceptions import UnsupportedModeError
from ..kernel import bitmap_to_dataset
from ..spectral.ckd import BinSet
from ..units import symbol
from ..units import unit_context_config as ucc

logger = logging.getLogger(__name__)


def _spectral_dims():
    """
    Return the spectral dimensions associated with the active mode.

    Returns
    -------
    tuple of tuple
        A sequence of ``(dimension name, metadata)`` pairs: ``w`` only in
        monochromatic modes, ``w`` followed by ``g`` in CKD modes. Metadata
        are taken from the ``ATTRIBUTES`` table.

    Raises
    ------
    UnsupportedModeError
        If the active mode is neither monochromatic nor CKD.
    """
    active_mode = eradiate.mode()

    if active_mode.is_mono:
        return (("w", ATTRIBUTES["radiation_wavelength"]),)

    if active_mode.is_ckd:
        names_and_keys = (("w", "radiation_wavelength"), ("g", "quantile"))
        return tuple((name, ATTRIBUTES[key]) for name, key in names_and_keys)

    raise UnsupportedModeError


@parse_docs
@attrs.define
class GatherMono(PipelineStep):
    """
    Gather raw kernel results (output as nested dictionaries) into an xarray
    dataset.

    This pipeline step takes a nested dictionary produced by the parametric
    loop of an :class:`.Experiment` and repackages it as a
    :class:`~xarray.Dataset`. The top-level spectral index is mapped to
    mode-dependent spectral coordinates. Film dimensions are left unmodified
    and retain their metadata.

    An ``img`` variable holds sensor values (it is renamed to the name
    configured with ``var`` before being returned). An ``spp`` variable holds
    the sample count.
    """

    var: str | tuple[str, dict] = documented(
        attrs.field(default="img"),
        default='"img"',
        type="str or tuple[str, dict]",
        init_type="str or tuple[str, dict], optional",
        doc="Name of the variable containing sensor data. Optionally, a "
        "(name, metadata) pair can be passed.",
    )

    def transform(self, x: dict) -> xr.Dataset:
        """
        Repackage a dictionary of raw results as a dataset.

        Parameters
        ----------
        x : dict
            Mapping from spectral index to a dictionary holding a ``"bitmap"``
            entry (converted with :func:`bitmap_to_dataset`) and an ``"spp"``
            entry (sample count).

        Returns
        -------
        Dataset
            Dataset holding the sensor data variable (named after ``var``)
            and an ``spp`` variable, indexed by the spectral dimensions of
            the active mode.

        Raises
        ------
        UnsupportedModeError
            If the active mode is neither monochromatic nor CKD (raised while
            resolving the spectral dimensions, before any data is processed).
        """
        # Split the mode-dependent spectral dimensions into names and
        # metadata. Entries are (name, metadata) pairs; bare names are also
        # accepted and get empty metadata.
        spectral_dims = []
        spectral_dim_metadata = {}

        for entry in _spectral_dims():
            if isinstance(entry, str):
                dim_name, dim_metadata = entry, {}
            else:
                dim_name, dim_metadata = entry[0], entry[1]
            spectral_dims.append(dim_name)
            spectral_dim_metadata[dim_name] = dim_metadata

        sensor_datasets = []

        # Loop on spectral indexes. The dictionary key is used as the
        # spectral index in every supported mode (unsupported modes were
        # already rejected by _spectral_dims() above), so no per-mode
        # branching is needed here.
        for spectral_index, result_dict in x.items():
            ds = bitmap_to_dataset(result_dict["bitmap"])
            spp = result_dict["spp"]

            # Set spectral coordinates: one length-1 coordinate per spectral
            # dimension (zip() stops at the shorter of the two sequences)
            all_coords = {
                spectral_dim: [spectral_coord]
                for spectral_dim, spectral_coord in zip(
                    spectral_dims, always_iterable(spectral_index)
                )
            }

            # Add spectral dimensions to img array
            ds["img"] = ds.img.expand_dims(dim=all_coords)

            # Package spp in a data array with one length-1 axis per
            # spectral dimension
            all_dims = list(all_coords)
            ds["spp"] = (all_dims, np.reshape(spp, [1] * len(all_dims)))

            sensor_datasets.append(ds)

        # Combine all the data
        with xr.set_options(keep_attrs=True):
            result = xr.combine_by_coords(sensor_datasets)

        # Drop "channel" dimension when using a monochromatic Mitsuba variant
        if eradiate.mode().check(mi_color_mode="mono"):
            result = result.squeeze("channel", drop=True)

        # Attach metadata to spectral coordinates
        for spectral_dim in spectral_dims:
            result[spectral_dim].attrs = spectral_dim_metadata[spectral_dim]

        # Apply metadata to data variables: 'var' is either a bare name or a
        # (name, metadata) pair
        if isinstance(self.var, str):
            var = self.var
            var_metadata = {}
        else:
            var = self.var[0]
            var_metadata = self.var[1]

        result = result.rename({"img": var})
        result[var].attrs.update(var_metadata)

        return result
@parse_docs
@attrs.define
class GatherCKD(PipelineStep):
    """
    Gather raw kernel results into an xarray dataset.

    For each bin of the bin set, the results of all its quadrature points are
    collected and reduced to a single value per bin with a weighted sum along
    the ``g`` dimension. The per-bin datasets are then concatenated along the
    ``w`` dimension. Each bin also contributes an ``spp`` variable and a
    ``wbounds`` variable holding its lower and upper wavelength bounds.
    """

    binset: BinSet = documented(
        attrs.field(validator=attrs.validators.instance_of(BinSet)),
        doc="Bin set.",
        type=":class:`.BinSet`",
    )

    var: str | tuple[str, dict] = documented(
        attrs.field(default="img"),
        default='"img"',
        type="str or tuple[str, dict]",
        init_type="str or tuple[str, dict], optional",
        doc="Name of the variable containing sensor data. Optionally, a "
        "(name, metadata) pair can be passed.",
    )

    def transform(self, x: dict) -> xr.Dataset:
        """
        Repackage a dictionary of raw CKD results as a dataset.

        Parameters
        ----------
        x : dict
            Dictionary whose keys are spectral indexes as hashable
            ``(w, g)`` tuples and whose values are dictionaries with keys
            ``"bitmap"`` and ``"spp"``.

        Returns
        -------
        Dataset
            Dataset holding the quadrature-reduced sensor data variable
            (named after ``var``), an ``spp`` variable and a ``wbounds``
            variable, concatenated along the ``w`` dimension.

        Notes
        -----
        Entries of ``x`` are consumed positionally, in iteration order, while
        looping over the bins sorted by central wavelength and over the
        quadrature points of each bin. ``x`` is therefore expected to be
        ordered accordingly and to hold at least one entry per quadrature
        point; an :class:`IndexError` is raised if it runs short.
        """
        logger.debug("gather_ckd: begin")

        # 'var' is either a bare name or a (name, metadata) pair. The default
        # value is a bare string, which must not be unpacked as a pair.
        if isinstance(self.var, str):
            name, metadata = self.var, {}
        else:
            name, metadata = self.var

        # The active mode does not change while the loops below run: query
        # it once.
        drop_channel = eradiate.mode().check(mi_color_mode="mono")

        # x is a dictionary whose keys are spectral indexes as hashable
        # tuples and whose values are dictionaries with keys "bitmap" and
        # "spp"
        items = list(x.items())
        ix = 0  # position of the next entry of 'x' to consume

        # One dataset per bin
        datasets = []

        # sort bins by wavelength
        for _bin in sorted(self.binset.bins, key=lambda b: b.wcenter):  # w loop
            ng = _bin.quad.weights.size
            _datasets = []

            for _ in range(ng):  # g loop
                siah, result_dict = items[ix]
                w, g = siah  # wavelength, quantile pair
                spp = result_dict["spp"]
                dataset = bitmap_to_dataset(result_dict["bitmap"])

                # expand dimensions of 'img' data variable to include 'w'
                # and 'g'
                dataset["img"] = dataset.img.expand_dims(dim={"w": [w], "g": [g]})

                # Drop "channel" dimension when using a monochromatic
                # Mitsuba variant
                if drop_channel:
                    dataset = dataset.squeeze("channel", drop=True)

                dataset = dataset.rename({"img": name})
                dataset[name].attrs.update(metadata)

                _datasets.append(dataset)
                ix += 1

            # concatenate along 'g'
            ds = xr.concat(_datasets, dim="g")

            # compute quadrature
            # this is a weighted sum array reduction:
            # https://docs.xarray.dev/en/stable/user-guide/computation.html#weighted-array-reductions

            # normalise weights to the [0, 1] g-interval
            weights_values = 0.5 * _bin.quad.weights
            weights = xr.DataArray(weights_values, dims=["g"], coords={"g": ds.g})

            with xr.set_options(keep_attrs=True):
                weighted_sum = ds[name].weighted(weights).sum(dim="g")

            # create dataset for current bin (w)
            dataset = xr.Dataset()
            dataset[name] = weighted_sum

            # add 'spp' data variable, broadcast over all dimensions of the
            # bin dataset; the sample count of the last processed quadrature
            # point is used. 'sizes' is the mapping-style accessor (mapping
            # access on 'Dataset.dims' is deprecated in recent xarray).
            sizes = dataset.sizes
            dataset["spp"] = (tuple(sizes), spp * np.ones(tuple(sizes.values())))

            # add 'wbounds' data variable, and 'wbv' coordinate
            wbounds = np.stack([_bin.wmin, _bin.wmax])
            wunits = ucc.get("wavelength")
            dataset["wbounds"] = (
                ["wbv", "w"],
                wbounds.m_as(wunits).reshape((2, 1)),
                {
                    "standard_name": "radiation_wavelength_bound",
                    "long_name": "wavelength bound",
                    "units": symbol(wunits),
                },
            )
            dataset["wbv"] = (["wbv"], ["lower", "upper"])

            datasets.append(dataset)

        with xr.set_options(keep_attrs=True):
            result = xr.concat(datasets, dim="w")

        # add metadata for 'w'
        result["w"].attrs.update(ATTRIBUTES["radiation_wavelength"])

        logger.debug("gather: end")

        return result