Source code for eradiate.pipelines._gather
from __future__ import annotations
import logging
import attrs
import numpy as np
import xarray as xr
from pinttr.util import always_iterable
import eradiate
from ._core import PipelineStep
from ..attrs import documented, parse_docs
from ..exceptions import UnsupportedModeError
from ..kernel import bitmap_to_dataset
from ..units import symbol
from ..units import unit_context_config as ucc
logger = logging.getLogger(__name__)
def _spectral_dims():
    """
    Return the spectral dimensions matching the active Eradiate mode.

    Returns
    -------
    tuple of tuple
        A sequence of ``(name, metadata)`` pairs, one per spectral dimension:
        a single ``"w"`` entry in monochromatic modes, ``"bin"`` and
        ``"index"`` entries in CKD modes.

    Raises
    ------
    UnsupportedModeError
        If the active mode is neither monochromatic nor CKD.
    """
    mode = eradiate.mode()

    if mode.is_mono:
        wavelength_metadata = {
            "standard_name": "radiation_wavelength",
            "long_name": "wavelength",
            # Unit symbol follows the currently configured wavelength unit
            "units": symbol(ucc.get("wavelength")),
        }
        return (("w", wavelength_metadata),)

    if mode.is_ckd:
        bin_dim = ("bin", {"standard_name": "ckd_bin", "long_name": "CKD bin"})
        index_dim = ("index", {"standard_name": "ckd_index", "long_name": "CKD index"})
        return (bin_dim, index_dim)

    raise UnsupportedModeError
@parse_docs
@attrs.define
class Gather(PipelineStep):
    """
    Assemble raw kernel results (nested dictionaries) into an xarray dataset.

    This pipeline step takes the nested dictionary produced by the parametric
    loop of an :class:`.Experiment` and repackages it as a single
    :class:`~xarray.Dataset`. Each top-level key is a spectral index; each
    value is a dictionary providing a ``"bitmap"`` and an ``"spp"`` entry.

    The top-level spectral index is mapped to mode-dependent spectral
    coordinates. Film dimensions are left unmodified and retain their
    metadata.

    Sensor values are stored in the variable named by ``var`` (``img`` by
    default). An ``spp`` variable holds the sample count.
    """

    var: str | tuple[str, dict] = documented(
        attrs.field(default="img"),
        default='"img"',
        type="str or tuple[str, dict]",
        init_type="str or tuple[str, dict], optional",
        doc="Name of the variable containing sensor data. Optionally, a "
        "(name, metadata) pair can be passed.",
    )

    def transform(self, x: dict) -> xr.Dataset:
        """
        Merge per-spectral-index kernel results into one dataset.

        Parameters
        ----------
        x : dict
            Mapping from spectral index to a result dictionary. Each result
            dictionary must hold a ``"bitmap"`` entry (converted with
            ``bitmap_to_dataset``) and an ``"spp"`` entry.

        Returns
        -------
        Dataset
            Combined dataset with one coordinate per spectral dimension, the
            sensor data stored under the name given by ``var`` and the sample
            count stored under ``spp``.
        """
        logger.debug("gather: begin")

        # Split spectral dimension specs into names and per-dimension metadata;
        # a spec is either a bare name or a (name, metadata) pair
        dim_names = []
        dim_attrs = {}
        for spec in _spectral_dims():
            name, metadata = (spec, {}) if isinstance(spec, str) else (spec[0], spec[1])
            dim_names.append(name)
            dim_attrs[name] = metadata

        # The active mode does not change while this step runs: look it up once
        mode = eradiate.mode()

        # Build one single-point dataset per spectral index
        per_index_datasets = []
        for key, entry in x.items():
            if mode.is_mono:
                spectral_index = key
            elif mode.is_ckd:
                # TODO: PR#311 hack
                spectral_index = (str(int(key[0])), key[1])

            ds = bitmap_to_dataset(entry["bitmap"])
            spp = entry["spp"]

            # Length-1 coordinates locating this dataset on the spectral axes
            coords = {
                name: [value]
                for name, value in zip(dim_names, always_iterable(spectral_index))
            }

            # Prepend the spectral dimensions to the image array
            ds["img"] = ds.img.expand_dims(dim=coords)

            # Store the sample count with one length-1 axis per spectral dimension
            dims = list(coords)
            ds["spp"] = (dims, np.reshape(spp, [1] * len(dims)))

            per_index_datasets.append(ds)

        # Merge everything along the spectral coordinates, keeping attributes
        with xr.set_options(keep_attrs=True):
            result = xr.combine_by_coords(per_index_datasets)

        # Drop "channel" dimension when using a monochromatic Mitsuba variant
        if mode.check(mi_color_mode="mono"):
            result = result.squeeze("channel", drop=True)

        # Attach metadata to spectral coordinates
        for name, metadata in dim_attrs.items():
            result[name].attrs = metadata

        # Rename the sensor data variable and apply its metadata
        if isinstance(self.var, str):
            var_name, var_attrs = self.var, {}
        else:
            var_name, var_attrs = self.var[0], self.var[1]

        result = result.rename({"img": var_name})
        result[var_name].attrs.update(var_attrs)

        logger.debug("gather: end")
        return result