Source code for eradiate.experiments._core

from __future__ import annotations

import datetime
import logging
import typing as t
import warnings
from abc import ABC, abstractmethod

import attrs
import mitsuba as mi
import numpy as np
import pinttr
import xarray as xr
from hamilton.driver import Driver

import eradiate

from .. import converters, validators
from .. import pipelines as pl
from ..attrs import AUTO, define, documented
from ..contexts import KernelContext
from ..exceptions import UnsupportedModeError
from ..kernel import (
    KernelDictTemplate,
    MitsubaObjectWrapper,
    UpdateMapTemplate,
    mi_render,
    mi_traverse,
)
from ..quad import Quad
from ..rng import SeedState
from ..scenes.core import Scene, SceneElement, get_factory, traverse
from ..scenes.illumination import (
    AbstractDirectionalIllumination,
    ConstantIllumination,
    DirectionalIllumination,
    illumination_factory,
)
from ..scenes.integrators import Integrator, integrator_factory
from ..scenes.measure import (
    Measure,
    MultiDistantMeasure,
    measure_factory,
)
from ..spectral.ckd_quad import CKDQuadConfig
from ..spectral.grid import CKDSpectralGrid, MonoSpectralGrid, SpectralGrid
from ..spectral.index import CKDSpectralIndex, MonoSpectralIndex, SpectralIndex
from ..units import unit_registry as ureg
from ..util.misc import deduplicate_sorted, onedict_value

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


@define
class Experiment(ABC):
    """
    Abstract base class for all Eradiate experiments. An experiment consists
    of a high-level scene specification parametrized by natural user input,
    a processing and post-processing pipeline, and a result storage data
    structure.
    """

    # Internal Mitsuba scene. This member is not set by the end-user, but rather
    # by the Experiment itself during initialization.
    mi_scene: MitsubaObjectWrapper | None = attrs.field(
        default=None,
        repr=False,
        init=False,
    )

    measures: list[Measure] = documented(
        attrs.field(
            factory=lambda: [MultiDistantMeasure()],
            # A single dict is special-cased (iterating it would presumably
            # yield its keys); anything else goes through always_iterable() so
            # that a lone specification is wrapped into a list
            converter=lambda value: (
                [measure_factory.convert(x) for x in pinttr.util.always_iterable(value)]
                if not isinstance(value, dict)
                else [measure_factory.convert(value)]
            ),
            validator=attrs.validators.deep_iterable(
                member_validator=attrs.validators.instance_of(Measure)
            ),
        ),
        doc="List of measure specifications. The passed list may contain "
        "dictionaries, which will be interpreted by "
        ":data:`.measure_factory`. "
        "Optionally, a single :class:`.Measure` or dictionary specification "
        "may be passed and will automatically be wrapped into a list.",
        type="list of :class:`.Measure`",
        init_type="list of :class:`.Measure` or list of dict or "
        ":class:`.Measure` or dict",
        default=":class:`MultiDistantMeasure() <.MultiDistantMeasure>`",
    )

    integrator: Integrator = documented(
        attrs.field(
            default=AUTO,
            converter=converters.auto_or(integrator_factory.convert),
            validator=validators.auto_or(
                attrs.validators.instance_of(Integrator),
            ),
        ),
        doc="Monte Carlo integration algorithm specification. "
        "This parameter can be specified as a dictionary which will be "
        "interpreted by :data:`.integrator_factory`. "
        "The integrator defaults to :data:`.AUTO`, which will choose the "
        "appropriate integrator depending on the experiment's configuration.",
        type=":class:`.Integrator` or AUTO",
        init_type=":class:`.Integrator` or dict or AUTO",
        default="AUTO",
    )

    # Storage for results, for each computed measure
    _results: dict[str, xr.Dataset] = attrs.field(factory=dict, repr=False)

    @property
    def results(self) -> dict[str, xr.Dataset]:
        """
        Post-processed simulation results.

        Returns
        -------
        dict[str, Dataset]
            Dictionary mapping measure IDs to xarray datasets.
        """
        return self._results

    _background_spectral_grid: SpectralGrid = documented(
        attrs.field(
            default=AUTO,
            validator=validators.auto_or(attrs.validators.instance_of(SpectralGrid)),
            repr=False,
        ),
        doc="Background spectral grid. "
        "If the value is :data:`.AUTO`, the background spectral grid is "
        "automatically generated depending on the active mode and internal "
        "experiment constraints. Otherwise, the value must be "
        "a :class:`.SpectralGrid` instance.",
        type=".SpectralGrid or AUTO",
        init_type=".SpectralGrid or AUTO",
        default="AUTO",
    )

    @property
    def background_spectral_grid(self) -> SpectralGrid:
        """
        Background spectral grid (resolved during initialization, see
        ``_normalize_spectral()``).
        """
        return self._background_spectral_grid

    # Grid used to walk the spectral dimension for each measure.
    # Set upon initialization by the '_normalize_spectral()' method.
    _spectral_grids: dict[int, SpectralGrid] = attrs.field(
        factory=dict, init=False, repr=False
    )

    @property
    def spectral_grids(self) -> dict[int, SpectralGrid]:
        """
        A dictionary mapping measure index to the associated spectral grid.
        """
        return self._spectral_grids

    ckd_quad_config: CKDQuadConfig = documented(
        attrs.field(
            factory=lambda: CKDQuadConfig(ng_max=16),
            converter=CKDQuadConfig.convert,
            validator=attrs.validators.instance_of(CKDQuadConfig),
        ),
        doc="CKD quadrature rule generation configuration.",
        type=".CKDQuadConfig",
        init_type=".CKDQuadConfig or dict",
    )

    # CKD quadrature configuration for each bin.
    # Set upon initialization by the '_normalize_spectral()' method.
    _ckd_quads: dict[int, list[Quad]] = attrs.field(
        factory=dict, init=False, repr=False
    )

    @property
    def ckd_quads(self) -> dict[int, list[Quad]]:
        """
        A dictionary mapping measure index to the associated CKD quadrature
        rule (if relevant).
        """
        return self._ckd_quads

    def __attrs_post_init__(self):
        self._normalize_spectral()

    def _normalize_spectral(self) -> None:
        """
        Assemble a spectral grid based on the various elements in the scene.

        This resolves the background spectral grid, selects for each measure
        the part of it covered by the measure's SRF and, in CKD modes,
        computes the quadrature rules of every bin of those grids.
        """
        # Collect atmosphere-based grid if relevant: only experiments exposing
        # an 'atmosphere' attribute with absorption data are affected
        atmosphere = getattr(self, "atmosphere", None)
        abs_db = None
        if atmosphere is not None:
            abs_db = getattr(atmosphere, "absorption_data", None)

        if abs_db is not None:
            if self._background_spectral_grid is not AUTO:
                warnings.warn(
                    "User-specified a background spectral grid is overridden by "
                    "atmosphere spectral grid."
                )
            self._background_spectral_grid = SpectralGrid.from_absorption_database(
                abs_db
            )

        # If needed, set the background grid
        if self._background_spectral_grid is AUTO:
            self._background_spectral_grid = SpectralGrid.default()

        # Select subparts of the grid that are covered by the SRF
        self._spectral_grids = {
            i: self._background_spectral_grid.select(measure.srf)
            for i, measure in enumerate(self.measures)
        }

        # Get quadrature rules for all bins; the active mode does not change
        # during the loop, so it is queried only once
        is_ckd = eradiate.mode().is_ckd
        ckd_quads = {}
        for i in range(len(self.measures)):
            if is_ckd:
                spectral_grid: CKDSpectralGrid = self._spectral_grids[i]
                # walk_quads() yields items whose second entry is the quadrature
                ckd_quads[i] = [
                    x[1] for x in spectral_grid.walk_quads(self.ckd_quad_config, abs_db)
                ]
            else:
                ckd_quads[i] = []
        self._ckd_quads = ckd_quads

    def clear(self) -> None:
        """
        Clear previous experiment results and reset internal state.
        """
        self.results.clear()

        for measure in self.measures:
            measure.mi_results.clear()

    @abstractmethod
    def init(self) -> None:
        """
        Generate kernel dictionary and initialize Mitsuba scene.
        """
        pass

    @abstractmethod
    def process(
        self,
        spp: int = 0,
        seed_state: SeedState | None = None,
    ) -> None:
        """
        Run simulation and collect raw results.

        Parameters
        ----------
        spp : int, optional
            Sample count. If set to 0, the value set in the original scene
            definition takes precedence.

        seed_state : :class:`.SeedState`, optional
            Seed state used to generate seeds to initialize Mitsuba's RNG at
            every iteration of the parametric loop. If unset, Eradiate's
            :attr:`root seed state <.root_seed_state>` is used.
        """
        pass

    @abstractmethod
    def postprocess(self) -> None:
        """
        Post-process raw results and store them in :attr:`results`.
        """
        pass

    @abstractmethod
    def pipeline(self, measure: Measure | int) -> Driver:
        """
        Return the post-processing pipeline for a given measure.

        Parameters
        ----------
        measure : .Measure or int
            Measure for which the pipeline is generated.

        Returns
        -------
        hamilton.driver.Driver
        """
        pass

    @property
    @abstractmethod
    def context_init(self) -> KernelContext:
        """
        Return a single context used for scene initialization.
        """
        pass

    @property
    @abstractmethod
    def contexts(self) -> list[KernelContext]:
        """
        Return a list of contexts used for processing.
        """
        pass
def _extra_objects_converter(value: dict | None) -> dict: if not value: return {} result = {} for key, element_spec in value.items(): if isinstance(element_spec, dict): element_spec = element_spec.copy() element_type = element_spec.pop("factory") factory = get_factory(element_type) result[key] = factory.convert(element_spec) else: result[key] = element_spec return result
@define
class EarthObservationExperiment(Experiment, ABC):
    """
    Abstract base class for experiments illuminated by a distant directional
    emitter.
    """

    extra_objects: dict[str, SceneElement] = documented(
        attrs.field(
            default=None,
            converter=_extra_objects_converter,
            validator=attrs.validators.deep_mapping(
                key_validator=attrs.validators.instance_of(str),
                value_validator=attrs.validators.instance_of(SceneElement),
            ),
        ),
        doc="Dictionary of extra objects to be added to the scene. "
        "The keys of this dictionary are used to identify the objects "
        "in the kernel dictionary.",
        type="dict",
        init_type="dict or None",
        default="None",
    )

    illumination: AbstractDirectionalIllumination | ConstantIllumination = documented(
        attrs.field(
            factory=DirectionalIllumination,
            converter=illumination_factory.convert,
            validator=attrs.validators.instance_of(
                (AbstractDirectionalIllumination, ConstantIllumination)
            ),
        ),
        doc="Illumination specification. "
        "This parameter can be specified as a dictionary which will be "
        "interpreted by :data:`.illumination_factory`.",
        type=":class:`.AbstractDirectionalIllumination` or "
        ":class:`.ConstantIllumination`",
        init_type=":class:`.DirectionalIllumination` or "
        ":class:`.ConstantIllumination` or dict",
        default=":class:`DirectionalIllumination() <.DirectionalIllumination>`",
    )

    kdict: KernelDictTemplate = documented(
        attrs.field(factory=KernelDictTemplate, converter=KernelDictTemplate),
        doc="Additional kernel dictionary template appended to the "
        "experiment-controlled template.",
        type=".KernelDictTemplate",
        init_type="mapping",
        default="{}",
    )

    kpmap: UpdateMapTemplate = documented(
        attrs.field(factory=UpdateMapTemplate, converter=UpdateMapTemplate),
        doc="Additional scene parameter update map template appended to the "
        "experiment-controlled template.",
        type=".UpdateMapTemplate",
        init_type="mapping",
        default="{}",
    )

    def kdict_base(self) -> KernelDictTemplate:
        # This is inefficient and exists at the moment only for debugging purposes
        return traverse(self.scene)[0]

    def kdict_full(self) -> KernelDictTemplate:
        # Return the scene-generated kdict template updated with the
        # user-defined kdict template
        kdict = self.kdict_base()
        kdict.update(self.kdict)
        return kdict

    def kpmap_base(self) -> UpdateMapTemplate:
        # This is inefficient and exists at the moment only for debugging purposes
        return traverse(self.scene)[1]

    def kpmap_full(self) -> UpdateMapTemplate:
        # Return the scene-generated kpmap template updated with the
        # user-defined kpmap template
        kpmap = self.kpmap_base()
        kpmap.update(self.kpmap)
        return kpmap

    def _dataset_metadata(self, measure: Measure) -> dict[str, str]:
        """
        Generate additional metadata applied to dataset after post-processing.

        Parameters
        ----------
        measure : :class:`.Measure`
            Measure for which the metadata is created.

        Returns
        -------
        dict[str, str]
            Metadata to be attached to the produced dataset.
        """
        # datetime.timezone.utc is identical to datetime.UTC but, unlike the
        # latter (added in Python 3.11), is available on all Python 3 versions
        timestamp = (
            datetime.datetime.now(datetime.timezone.utc)
            .replace(microsecond=0)
            .isoformat()
        )
        return {
            "convention": "CF-1.10",
            "source": f"eradiate, version {eradiate.__version__}",
            "history": f"{timestamp}"
            f" - data creation - {self.__class__.__name__}.postprocess()",
            "references": "",
        }

    def spectral_indices(self, measure_index: int) -> t.Generator[SpectralIndex]:
        """
        Generate spectral indices for a given measure.

        Parameters
        ----------
        measure_index : int
            Measure index for which spectral indices are generated.

        Yields
        ------
        :class:`.SpectralIndex`
            Spectral index.

        Raises
        ------
        UnsupportedModeError
            If the active mode is neither monochromatic nor CKD (raised when
            the generator is first advanced).
        """
        mode = eradiate.mode()

        if mode.is_mono:
            spectral_grid: MonoSpectralGrid = self.spectral_grids[measure_index]
            yield from spectral_grid.walk_indices()

        elif mode.is_ckd:
            spectral_grid: CKDSpectralGrid = self.spectral_grids[measure_index]
            # Look up absorption data exactly as _normalize_spectral() does when
            # building self.ckd_quads, so that the g-points walked here match
            # the quadrature rules used during post-processing. Experiments
            # without an atmosphere (or without absorption data) get None.
            atmosphere = getattr(self, "atmosphere", None)
            abs_db = None
            if atmosphere is not None:
                abs_db = getattr(atmosphere, "absorption_data", None)
            yield from spectral_grid.walk_indices(self.ckd_quad_config, abs_db)

        else:
            raise UnsupportedModeError

    @property
    def context_init(self):
        # Inherit docstring
        # The first spectral index of the first measure is used for scene setup
        return KernelContext(
            si=next(self.spectral_indices(0)), kwargs=self._context_kwargs
        )

    @property
    @abstractmethod
    def _context_kwargs(self) -> dict[str, t.Any]:
        pass

    @property
    def contexts(self) -> list[KernelContext]:
        # Inherit docstring

        # Collect spectral indices from all measures
        sis = [
            si
            for measure_index in range(len(self.measures))
            for si in self.spectral_indices(measure_index)
        ]

        # Sort and remove duplicates (measures may share spectral indices)
        key = {
            MonoSpectralIndex: lambda si: si.w.m,
            CKDSpectralIndex: lambda si: (si.w.m, si.g),
        }[SpectralIndex.subtypes.resolve()]

        sis = deduplicate_sorted(
            sorted(sis, key=key), cmp=lambda x, y: key(x) == key(y)
        )

        return [KernelContext(si, kwargs=self._context_kwargs) for si in sis]

    @property
    @abstractmethod
    def scene_objects(self) -> dict[str, SceneElement]:
        pass

    @property
    def scene(self) -> Scene:
        """
        Return a scene object used for kernel dictionary template and parameter
        table generation.
        """
        # Extra objects are merged last, so they win on key collisions
        return Scene(objects={**self.scene_objects, **self.extra_objects})

    def init(self):
        # Inherit docstring
        logger.info("Initializing kernel scene")

        kdict_template, umap_template = traverse(self.scene)
        kdict_template.update(self.kdict)
        umap_template.update(self.kpmap)

        try:
            self.mi_scene = mi_traverse(
                mi.load_dict(kdict_template.render(ctx=self.context_init)),
                umap_template=umap_template,
            )
        except RuntimeError as e:
            raise RuntimeError(f"(while loading kernel scene dictionary){e}") from e

        # Remove unused elements from Mitsuba scene parameter table
        self.mi_scene.drop_parameters()

    def process(self, spp: int = 0, seed_state: SeedState | None = None) -> None:
        # Inherit docstring

        # Set up Mitsuba scene
        if self.mi_scene is None:
            self.init()

        # Run Mitsuba for each context
        logger.info("Launching simulation")
        mi_results = mi_render(
            self.mi_scene, self.contexts, seed_state=seed_state, spp=spp
        )

        # Assign collected results to the appropriate measure
        sensor_to_measure: dict[str, Measure] = {
            measure.sensor_id: measure for measure in self.measures
        }

        def convert_to_y_format(img):
            # Keep only the first channel, as a single-channel (Y) bitmap.
            # np.asarray() replaces np.array(..., copy=False), which raises
            # under NumPy >= 2 whenever a copy would be required; the [0]
            # fancy index produces a new 3-D array whose last axis has length 1.
            img_np = np.asarray(img)[:, :, [0]]
            return mi.Bitmap(img_np, mi.Bitmap.PixelFormat.Y)

        # Create a mapping from bitmap channel names to result names
        if self.integrator.stokes:
            stokes = ["S0", "S1", "S2", "S3"]
            iquv = ["I", "Q", "U", "V"]
            if self.integrator.moment:
                stokes = ["nested." + s for s in stokes]
                stokes += ["m2_" + s for s in stokes]
                iquv += ["m2_" + s for s in iquv]
            mapping = dict(zip(stokes, iquv))
        else:
            mapping = {"<root>": "bitmap"}
            if self.integrator.moment:
                mapping["m2_nested"] = "m2"

        # Gather results and info from measures
        for ctx_index, spectral_group_dict in mi_results.items():
            for sensor_id, mi_bitmap in spectral_group_dict.items():
                measure = sensor_to_measure[sensor_id]
                result_imgs = {"spp": spp if spp > 0 else measure.spp}

                for channel_name, img in mi_bitmap.split():
                    if channel_name not in mapping:
                        continue
                    # Convert any result that has more than one channel
                    if img.pixel_format() != mi.Bitmap.PixelFormat.Y:
                        img = convert_to_y_format(img)
                    result_imgs[mapping[channel_name]] = img

                measure.mi_results[ctx_index] = result_imgs

    def postprocess(self) -> None:
        # Inherit docstring
        logger.info("Post-processing results")

        # Run pipelines
        for i, measure in enumerate(self.measures):
            drv: Driver = self.pipeline(measure)
            inputs = self._pipeline_inputs(i)
            outputs = pl.outputs(drv)
            result = drv.execute(final_vars=outputs, inputs=inputs)
            self.results[measure.id] = xr.Dataset({var: result[var] for var in outputs})

    def pipeline(self, measure: Measure | int) -> Driver:
        # Inherit docstring
        if isinstance(measure, int):
            measure = self.measures[measure]
        config = pl.config(measure, integrator=self.integrator)
        return pl.driver(config)

    def _pipeline_inputs(self, i_measure: int):
        # This convenience function collects pipeline inputs for a specific measure
        measure = self.measures[i_measure]

        result = {
            "bitmaps": measure.mi_results,
            "spectral_grid": self.spectral_grids[i_measure],
            "ckd_quads": self.ckd_quads[i_measure],
            "illumination": self.illumination,
            "srf": measure.srf,
        }

        # When viewing angles are not requested, the 'viewing_angles' entry is
        # supplied directly as None; otherwise the raw angles (in degrees) are
        # supplied under 'angles'
        config = pl.config(measure)
        if config.get("add_viewing_angles", False):
            result["angles"] = measure.viewing_angles.m_as(ureg.deg)
        else:
            result["viewing_angles"] = None

        return result
# ------------------------------------------------------------------------------ # Experiment runner # ------------------------------------------------------------------------------
def run(
    exp: Experiment,
    spp: int = 0,
    seed_state: SeedState | None = None,
) -> xr.Dataset | dict[str, xr.Dataset]:
    """
    Run an Eradiate experiment.

    The experiment is first processed (kernel scene assembly and computation),
    then post-processed; the outcome is one or several xarray datasets.

    Parameters
    ----------
    exp : Experiment
        Reference to the experiment object which will be processed.

    spp : int, optional, default: 0
        Optional parameter to override the number of samples per pixel for all
        computed measures. If set to 0, the configured value for each measure
        takes precedence.

    seed_state : :class:`.SeedState`, optional
        Seed state used to generate seeds to initialize Mitsuba's RNG at
        every iteration of the parametric loop. If unset, Eradiate's
        :attr:`root seed state <.root_seed_state>` is used.

    Returns
    -------
    Dataset or dict[str, Dataset]
        If several measures are defined, a dictionary mapping measure IDs to
        the corresponding result dataset is returned. Otherwise, the single
        result dataset is unwrapped and returned on its own.
    """
    exp.process(spp=spp, seed_state=seed_state)
    exp.postprocess()

    results = exp.results
    if len(results) > 1:
        return results
    return onedict_value(results)