Source code for eradiate.experiments._core

from __future__ import annotations

import logging
import typing as t
from abc import ABC, abstractmethod
from datetime import datetime, timezone

import attrs
import mitsuba as mi
import pinttr
import xarray as xr

import eradiate

from .. import pipelines, validators
from ..attrs import AUTO, documented, parse_docs
from ..contexts import KernelContext
from ..kernel import MitsubaObjectWrapper, mi_render, mi_traverse
from ..pipelines import Pipeline
from ..rng import SeedState
from ..scenes.core import Scene, SceneElement, get_factory, traverse
from ..scenes.illumination import (
    ConstantIllumination,
    DirectionalIllumination,
    illumination_factory,
)
from ..scenes.integrators import Integrator, PathIntegrator, integrator_factory
from ..scenes.measure import (
    DistantFluxMeasure,
    DistantMeasure,
    HemisphericalDistantMeasure,
    Measure,
    MultiDistantMeasure,
    measure_factory,
)
from ..scenes.spectra import InterpolatedSpectrum
from ..spectral.ckd import BinSet
from ..spectral.index import CKDSpectralIndex, MonoSpectralIndex, SpectralIndex
from ..spectral.mono import WavelengthSet
from ..util.misc import deduplicate_sorted, onedict_value

# Module-level logger, named after this module, used to report progress of
# scene initialization, simulation and post-processing.
logger = logging.getLogger(__name__)


def _convert_spectral_set(value):
    """
    Resolve the :data:`AUTO` placeholder for a spectral set.

    Any value other than :data:`AUTO` is returned unchanged. If the value is
    :data:`AUTO`, the default set matching the active mode is returned:
    :meth:`BinSet.default` in CKD mode, :meth:`WavelengthSet.default` in
    monochromatic mode.

    Raises
    ------
    NotImplementedError
        If the value is :data:`AUTO` and the active mode is neither CKD nor
        monochromatic.
    """
    if value is not AUTO:
        return value

    active_mode = eradiate.mode()
    if active_mode.is_ckd:
        return BinSet.default()
    if active_mode.is_mono:
        return WavelengthSet.default()
    raise NotImplementedError(f"unsupported mode: {active_mode.id}")


@parse_docs
@attrs.define
class Experiment(ABC):
    """
    Abstract base class for all Eradiate experiments. An experiment consists of
    a high-level scene specification parametrized by natural user input, a
    processing and post-processing pipeline, and a result storage data
    structure.
    """

    # Internal Mitsuba scene. This member is not set by the end-user, but rather
    # by the Experiment itself during initialization.
    mi_scene: MitsubaObjectWrapper | None = attrs.field(
        default=None,
        repr=False,
        init=False,
    )

    measures: list[Measure] = documented(
        attrs.field(
            factory=lambda: [MultiDistantMeasure()],
            # A single dict is one measure specification: wrap it into a
            # one-element list rather than iterating over it
            converter=lambda value: [
                measure_factory.convert(x)
                for x in pinttr.util.always_iterable(value)
            ]
            if not isinstance(value, dict)
            else [measure_factory.convert(value)],
            validator=attrs.validators.deep_iterable(
                member_validator=attrs.validators.instance_of(Measure)
            ),
        ),
        doc="List of measure specifications. The passed list may contain "
        "dictionaries, which will be interpreted by "
        ":data:`.measure_factory`. "
        "Optionally, a single :class:`.Measure` or dictionary specification "
        "may be passed and will automatically be wrapped into a list.",
        type="list of :class:`.Measure`",
        init_type="list of :class:`.Measure` or list of dict or "
        ":class:`.Measure` or dict",
        default=":class:`MultiDistantMeasure() <.MultiDistantMeasure>`",
    )

    _integrator: Integrator = documented(
        attrs.field(
            factory=PathIntegrator,
            converter=integrator_factory.convert,
            validator=attrs.validators.instance_of(Integrator),
        ),
        doc="Monte Carlo integration algorithm specification. "
        "This parameter can be specified as a dictionary which will be "
        "interpreted by :data:`.integrator_factory`.",
        type=":class:`.Integrator`",
        init_type=":class:`.Integrator` or dict",
        default=":class:`PathIntegrator() <.PathIntegrator>`",
    )

    @property
    def integrator(self) -> Integrator:
        """
        :class:`.Integrator`: Integrator used to solve the radiative transfer
        equation.
        """
        return self._integrator

    _results: dict[str, xr.Dataset] = attrs.field(factory=dict, repr=False)

    @property
    def results(self) -> dict[str, xr.Dataset]:
        """
        Post-processed simulation results.

        Returns
        -------
        dict[str, Dataset]
            Dictionary mapping measure IDs to xarray datasets.
        """
        return self._results

    default_spectral_set: BinSet | WavelengthSet = documented(
        attrs.field(
            default=AUTO,
            validator=validators.auto_or(
                attrs.validators.instance_of((BinSet, WavelengthSet))
            ),
            converter=_convert_spectral_set,
        ),
        # Note the trailing space after "spectral_set`.": without it, the
        # rendered documentation would glue the two sentences together.
        doc="Default spectral set. This attribute is used to set the "
        "default value for :attr:`spectral_set`. "
        "If the value is :data:`AUTO`, the default spectral set is selected "
        "based on the active mode. Otherwise, the value must be a "
        ":class:`.BinSet` or :class:`.WavelengthSet` instance.",
        type=":class:`.BinSet` or :class:`.WavelengthSet`",
        init_type=":class:`.BinSet` or :class:`.WavelengthSet` or :data:`AUTO`",
        default=":data:`AUTO`",
    )

    # Mapping of measure index to WavelengthSet or BinSet depending on active
    # mode. This attribute is set by the '_normalize_spectral()' method.
    _spectral_set = attrs.field(init=False)

    @property
    def spectral_set(self) -> dict[int, WavelengthSet | BinSet]:
        """
        Per-measure spectral sets.

        Returns
        -------
        dict[int, WavelengthSet or BinSet]
            Dictionary mapping each measure's index in :attr:`measures` to the
            spectral set selected by that measure's spectral response function.
        """
        return self._spectral_set

    def __attrs_post_init__(self):
        self._normalize_spectral()

    def _normalize_spectral(self) -> None:
        """
        Set up the per-measure spectral sets.

        The spectral set used as a starting point is
        :attr:`default_spectral_set`. If the experiment has a truthy
        ``atmosphere`` attribute whose ``spectral_set`` is not ``None``, that
        spectral set takes precedence. Each measure's spectral response
        function then selects its own subset through ``srf.select_in()``.
        """
        spectral_set = self.default_spectral_set

        # Not every experiment defines an atmosphere: look it up defensively
        atmosphere = getattr(self, "atmosphere", None)
        if atmosphere:
            atmosphere_spectral_set = atmosphere.spectral_set
            if atmosphere_spectral_set is not None:
                spectral_set = atmosphere_spectral_set

        self._spectral_set = {
            i: measure.srf.select_in(spectral_set)
            for i, measure in enumerate(self.measures)
        }

    def clear(self) -> None:
        """
        Clear previous experiment results.

        This empties the post-processed :attr:`results` dictionary and the
        raw ``mi_results`` store of every measure. The Mitsuba scene
        (:attr:`mi_scene`) is left as is.
        """
        self.results.clear()

        for measure in self.measures:
            measure.mi_results.clear()

    @abstractmethod
    def init(self) -> None:
        """
        Generate kernel dictionary and initialise Mitsuba scene.
        """
        pass

    @abstractmethod
    def process(
        self,
        spp: int = 0,
        seed_state: SeedState | None = None,
    ) -> None:
        """
        Run simulation and collect raw results.

        Parameters
        ----------
        spp : int, optional
            Sample count. If set to 0, the value set in the original scene
            definition takes precedence.

        seed_state : :class:`.SeedState`, optional
            Seed state used to generate seeds to initialize Mitsuba's RNG at
            every iteration of the parametric loop. If unset, Eradiate's
            :attr:`root seed state <.root_seed_state>` is used.
        """
        pass

    @abstractmethod
    def postprocess(self) -> None:
        """
        Post-process raw results and store them in :attr:`results`.
        """
        pass

    @abstractmethod
    def pipeline(self, measure: Measure) -> Pipeline:
        """
        Return the post-processing pipeline for a given measure.

        Parameters
        ----------
        measure : .Measure
            Measure for which the pipeline is to be generated.

        Returns
        -------
        .Pipeline
        """
        pass

    @property
    @abstractmethod
    def context_init(self) -> KernelContext:
        """
        Return a single context used for scene initialization.
        """
        pass

    @property
    @abstractmethod
    def contexts(self) -> list[KernelContext]:
        """
        Return a list of contexts used for processing.
        """
        pass
def _extra_objects_converter(value): result = {} for key, element_spec in value.items(): if isinstance(element_spec, dict): element_spec = element_spec.copy() element_type = element_spec.pop("factory") factory = get_factory(element_type) result[key] = factory.convert(element_spec) else: result[key] = element_spec return result
@parse_docs
@attrs.define
class EarthObservationExperiment(Experiment, ABC):
    """
    Abstract base class for experiments illuminated by a distant directional
    emitter (a constant illumination is also accepted).
    """

    extra_objects: dict[str, SceneElement] = documented(
        attrs.field(
            factory=dict,
            converter=_extra_objects_converter,
            validator=attrs.validators.deep_mapping(
                key_validator=attrs.validators.instance_of(str),
                value_validator=attrs.validators.instance_of(SceneElement),
            ),
        ),
        doc="Dictionary of extra objects to be added to the scene. "
        "The keys of this dictionary are used to identify the objects "
        "in the kernel dictionary.",
        type="dict",
        default="{}",
    )

    illumination: DirectionalIllumination | ConstantIllumination = documented(
        attrs.field(
            factory=DirectionalIllumination,
            converter=illumination_factory.convert,
            validator=attrs.validators.instance_of(
                (DirectionalIllumination, ConstantIllumination)
            ),
        ),
        doc="Illumination specification. "
        "This parameter can be specified as a dictionary which will be "
        "interpreted by :data:`.illumination_factory`.",
        type=":class:`.DirectionalIllumination` or "
        ":class:`.ConstantIllumination`",
        init_type=":class:`.DirectionalIllumination` or "
        ":class:`.ConstantIllumination` or dict",
        default=":class:`DirectionalIllumination() <.DirectionalIllumination>`",
    )

    def _dataset_metadata(self, measure: Measure) -> dict[str, str]:
        """
        Generate additional metadata applied to dataset after post-processing.

        Parameters
        ----------
        measure : :class:`.Measure`
            Measure for which the metadata is created.

        Returns
        -------
        dict[str, str]
            Metadata to be attached to the produced dataset.
        """
        # datetime.utcnow() is deprecated since Python 3.12: use an aware UTC
        # timestamp, then drop the offset so that the string format stays
        # "YYYY-MM-DDTHH:MM:SS" as before.
        timestamp = (
            datetime.now(timezone.utc)
            .replace(microsecond=0, tzinfo=None)
            .isoformat()
        )
        return {
            "convention": "CF-1.10",
            "source": f"eradiate, version {eradiate.__version__}",
            "history": f"{timestamp} - data creation - "
            f"{self.__class__.__name__}.postprocess()",
            "references": "",
        }

    def spectral_indices(self, measure_index: int) -> t.Generator[SpectralIndex]:
        """
        Generate spectral indices for a given measure.

        Parameters
        ----------
        measure_index : int
            Measure index for which spectral indices are generated.

        Yields
        ------
        :class:`.SpectralIndex`
            Spectral index.

        Raises
        ------
        RuntimeError
            If the active mode is neither monochromatic nor CKD.
        """
        if eradiate.mode().is_mono:
            generator = self.spectral_indices_mono
        elif eradiate.mode().is_ckd:
            generator = self.spectral_indices_ckd
        else:
            raise RuntimeError(f"unsupported mode '{eradiate.mode().id}'")

        yield from generator(measure_index)

    def spectral_indices_mono(
        self, measure_index: int
    ) -> t.Generator[MonoSpectralIndex]:
        yield from self.spectral_set[measure_index].spectral_indices()

    def spectral_indices_ckd(self, measure_index: int) -> t.Generator[CKDSpectralIndex]:
        yield from self.spectral_set[measure_index].spectral_indices()

    @property
    @abstractmethod
    def _context_kwargs(self) -> dict[str, t.Any]:
        pass

    @property
    def context_init(self) -> KernelContext:
        return KernelContext(si=self.contexts[0].si, kwargs=self._context_kwargs)

    @property
    def contexts(self) -> list[KernelContext]:
        # Inherit docstring

        # Collect spectral indices from all measures
        sis = []
        for measure_index, _ in enumerate(self.measures):
            sis.extend(self.spectral_indices(measure_index))

        # No spectral index to process (e.g. no measure defined): the sort key
        # lookup below would otherwise fail with an opaque IndexError
        if not sis:
            return []

        # Sort and remove duplicates
        key = {
            MonoSpectralIndex: lambda si: si.w.m,
            CKDSpectralIndex: lambda si: (si.w.m, si.g),
        }[type(sis[0])]

        sis = deduplicate_sorted(
            sorted(sis, key=key), cmp=lambda x, y: key(x) == key(y)
        )
        kwargs = self._context_kwargs

        return [KernelContext(si, kwargs=kwargs) for si in sis]

    @property
    @abstractmethod
    def scene_objects(self) -> dict[str, SceneElement]:
        pass

    @property
    def scene(self) -> Scene:
        """
        Return a scene object used for kernel dictionary template and parameter
        table generation.
        """
        return Scene(objects={**self.scene_objects, **self.extra_objects})

    def init(self):
        # Inherit docstring

        logger.info("Initializing kernel scene")

        kdict_template, umap_template = traverse(self.scene)
        try:
            self.mi_scene = mi_traverse(
                mi.load_dict(kdict_template.render(ctx=self.context_init)),
                umap_template=umap_template,
            )
        except RuntimeError as e:
            # Separate the context prefix from the original message
            raise RuntimeError(f"(while loading kernel scene dictionary) {e}") from e

        # Remove unused elements from Mitsuba scene parameter table
        self.mi_scene.drop_parameters()

    def process(
        self,
        spp: int = 0,
        seed_state: SeedState | None = None,
    ) -> None:
        # Inherit docstring

        # Set up Mitsuba scene
        if self.mi_scene is None:
            self.init()

        # Run Mitsuba for each context
        logger.info("Launching simulation")

        mi_results = mi_render(
            self.mi_scene,
            self.contexts,
            seed_state=seed_state,
            spp=spp,
        )

        # Assign collected results to the appropriate measure
        sensor_to_measure: dict[str, Measure] = {
            measure.sensor_id: measure for measure in self.measures
        }

        for ctx_index, spectral_group_dict in mi_results.items():
            for sensor_id, mi_bitmap in spectral_group_dict.items():
                measure = sensor_to_measure[sensor_id]
                measure.mi_results[ctx_index] = {
                    "bitmap": mi_bitmap,
                    # spp == 0 means "use the sample count set on the measure"
                    "spp": spp if spp > 0 else measure.spp,
                }

    def postprocess(self, pipeline_kwargs: dict | None = None) -> None:
        # Inherit docstring
        logger.info("Post-processing results")
        measures = self.measures

        if pipeline_kwargs is None:
            pipeline_kwargs = {}

        # Apply pipelines
        for measure in measures:
            pipeline = self.pipeline(measure)

            # Collect measure results
            self._results[measure.id] = pipeline.transform(
                measure.mi_results, **pipeline_kwargs
            )

            # Apply additional metadata
            self._results[measure.id].attrs.update(self._dataset_metadata(measure))

    def pipeline(self, measure: Measure) -> Pipeline:
        # Inherit docstring
        measure_index = self.measures.index(measure)
        pipeline = pipelines.Pipeline()

        # Gather
        pipeline.add(
            "gather",
            pipelines.Gather(var=measure.var),
        )

        # Aggregate
        if eradiate.mode().is_ckd:
            pipeline.add(
                "aggregate_ckd_quad",
                pipelines.AggregateCKDQuad(
                    measure=measure,
                    binset=self.spectral_set[measure_index],
                    var=measure.var[0],
                ),
            )

        if isinstance(measure, (DistantFluxMeasure,)):
            pipeline.add(
                "aggregate_radiosity",
                pipelines.AggregateRadiosity(
                    sector_radiosity_var=measure.var[0],
                    radiosity_var="radiosity",
                ),
            )

        # Assemble
        pipeline.add(
            "add_illumination",
            pipelines.AddIllumination(
                illumination=self.illumination,
                measure=measure,
                irradiance_var="irradiance",
            ),
        )

        if isinstance(measure, DistantMeasure):
            pipeline.add(
                "add_viewing_angles", pipelines.AddViewingAngles(measure=measure)
            )

        if isinstance(measure.srf, InterpolatedSpectrum):
            pipeline.add(
                "add_srf",
                pipelines.AddSpectralResponseFunction(measure=measure),
            )

        # Compute
        if isinstance(measure, (MultiDistantMeasure, HemisphericalDistantMeasure)):
            pipeline.add(
                "compute_reflectance",
                pipelines.ComputeReflectance(
                    radiance_var="radiance",
                    irradiance_var="irradiance",
                    brdf_var="brdf",
                    brf_var="brf",
                ),
            )

            if eradiate.mode().is_ckd and isinstance(measure.srf, InterpolatedSpectrum):
                pipeline.add(
                    "apply_srf",
                    pipelines.ApplySpectralResponseFunction(
                        measure=measure,
                        vars=["radiance", "irradiance"],
                    ),
                )
                pipeline.add(
                    "compute_reflectance_srf",
                    pipelines.ComputeReflectance(
                        radiance_var="radiance_srf",
                        irradiance_var="irradiance_srf",
                        brdf_var="brdf_srf",
                        brf_var="brf_srf",
                    ),
                )

        elif isinstance(measure, (DistantFluxMeasure,)):
            pipeline.add(
                "compute_albedo",
                pipelines.ComputeAlbedo(
                    radiosity_var="radiosity",
                    irradiance_var="irradiance",
                    albedo_var="albedo",
                ),
            )

            if eradiate.mode().is_ckd and isinstance(measure.srf, InterpolatedSpectrum):
                pipeline.add(
                    "apply_srf",
                    pipelines.ApplySpectralResponseFunction(
                        measure=measure,
                        vars=["radiosity", "irradiance"],
                    ),
                )
                pipeline.add(
                    "compute_albedo_srf",
                    pipelines.ComputeAlbedo(
                        radiosity_var="radiosity_srf",
                        irradiance_var="irradiance_srf",
                        albedo_var="albedo_srf",
                    ),
                )

        return pipeline
# ------------------------------------------------------------------------------
#                               Experiment runner
# ------------------------------------------------------------------------------


def run(
    exp: Experiment,
    spp: int = 0,
    seed_state: SeedState | None = None,
) -> xr.Dataset | dict[str, xr.Dataset]:
    """
    Run an Eradiate experiment end to end.

    The experiment is first processed (kernel scene assembly, if needed, and
    computation), then its raw results are post-processed. The output consists
    of one or several xarray datasets.

    Parameters
    ----------
    exp : Experiment
        Experiment to process.

    spp : int, optional, default: 0
        Number of samples per pixel overriding the value configured on every
        measure. If set to 0, each measure's own configured value is used.

    seed_state : :class:`.SeedState`, optional
        Seed state used to generate seeds to initialize Mitsuba's RNG at
        every iteration of the parametric loop. If unset, Eradiate's
        :attr:`root seed state <.root_seed_state>` is used.

    Returns
    -------
    Dataset or dict[str, Dataset]
        The single result dataset when exactly one measure produced results;
        otherwise, a dictionary mapping measure IDs to their result datasets.
    """
    exp.process(spp=spp, seed_state=seed_state)
    exp.postprocess()

    results = exp.results
    if len(results) > 1:
        return results
    return onedict_value(results)