# Source code for eradiate.scenes.phase._blend

from __future__ import annotations

import typing as t
from collections import abc as cabc

import attrs
import mitsuba as mi
import numpy as np

from ._core import PhaseFunction, phase_function_factory
from ..core import traverse
from ..geometry import PlaneParallelGeometry, SceneGeometry, SphericalShellGeometry
from ...attrs import documented
from ...contexts import KernelContext
from ...kernel import InitParameter, UpdateParameter
from ...spectral.index import SpectralIndex
from ...util.misc import cache_by_id


def _convert_weights(value):
    """
    Convert user-supplied blending weights.

    A non-empty sequence whose first element is callable is stored as a plain
    :class:`list` (so that e.g. a tuple of callables is handled exactly like a
    list of callables downstream). Anything else is converted to a float64
    NumPy array; malformed inputs (empty, scalar, wrong shape) are then
    reported by the field validator.
    """
    if (
        isinstance(value, cabc.Sequence)
        and len(value) > 0
        and callable(value[0])
    ):
        return list(value)
    return np.array(value, dtype=np.float64)


@attrs.define(eq=False, slots=False)
class BlendPhaseFunction(PhaseFunction):
    """
    Blended phase function [``blend_phase``].

    This phase function aggregates two or more sub-phase functions
    (*components*) and blends them based on its `weights` parameter. Weights
    are usually based on the associated medium's scattering coefficient.

    Internally, *n* components are expanded into a chain of *n - 1* nested
    ``blendphase`` nodes: node *i* holds component *i* as its first branch
    (``phase_0``) and the remainder of the chain as its second branch
    (``phase_1``), weighted by a conditional weight (see
    :meth:`eval_conditional_weights`).
    """

    components: list[PhaseFunction] = documented(
        attrs.field(
            converter=lambda x: [phase_function_factory.convert(y) for y in x],
            validator=attrs.validators.deep_iterable(
                attrs.validators.instance_of(PhaseFunction)
            ),
            kw_only=True,
        ),
        type="list of :class:`.PhaseFunction`",
        init_type="list of :class:`.PhaseFunction` or list of dict",
        doc="List of components (at least two). This parameter has no default.",
    )

    @components.validator
    def _components_validator(self, attribute, value):
        if len(value) < 2:
            raise ValueError(
                f"while validating {attribute.name}: BlendPhaseFunction must "
                "have at least two components"
            )

    weights: np.ndarray | list[t.Callable[[SpectralIndex], np.ndarray]] = documented(
        attrs.field(
            converter=_convert_weights,
            kw_only=True,
        ),
        type="ndarray or list of callables",
        init_type="array-like or list of callables",
        doc="List of weights associated with each component. Weights may be "
        "numerical values; in that case, they must be of shape (n,) or (n, m), "
        "where n is the number of components and m the number of cells along "
        "the atmosphere's vertical axis. Alternatively, weights may be "
        "a list of n callables (one per component) that take a "
        ":class:`.SpectralIndex` as argument and return an array of shape "
        "(m,) (or a scalar). "
        "This parameter is required and has no default.",
    )

    @weights.validator
    def _weights_validator(self, attribute, value):
        if isinstance(value, np.ndarray):
            if value.ndim == 0 or value.ndim > 2:
                raise ValueError(
                    f"while validating '{attribute.name}': array must have 1 or 2 "
                    f"dimensions, got {value.ndim}"
                )
            if value.shape[0] != len(self.components):
                raise ValueError(
                    f"while validating '{attribute.name}': array must have shape "
                    "(n,) or (n, m) where n is the number of components; got "
                    f"{value.shape}"
                )
        elif isinstance(value, cabc.Sequence):
            if len(value) != len(self.components):
                raise ValueError(
                    f"while validating '{attribute.name}': weight and component "
                    "lists must have the same length"
                )
            # A mixed list would otherwise only fail at evaluation time
            if not all(callable(w) for w in value):
                raise ValueError(
                    f"while validating '{attribute.name}': a weight list must "
                    "contain only callables"
                )

    geometry: SceneGeometry | None = documented(
        attrs.field(
            default=None,
            converter=attrs.converters.optional(SceneGeometry.convert),
            validator=attrs.validators.optional(
                attrs.validators.instance_of(SceneGeometry)
            ),
        ),
        doc="Parameters defining the basic geometry of the scene. If unset, "
        "the volume textures defining component weights will be assigned "
        "defaults likely unsuitable for atmosphere construction.",
        type=".SceneGeometry or None",
        init_type=".SceneGeometry or dict or str, optional",
        default="None",
    )

    def update(self) -> None:
        """
        Update this element and all its components.

        Nested :class:`BlendPhaseFunction` components receive this object's
        geometry; every component is then updated.
        """
        super().update()

        # Synchronize geometries. The geometry is assigned *before* the
        # component's update so that a nested blend forwards the up-to-date
        # geometry to its own nested blends (assigning afterwards would leave
        # deeper levels with a stale geometry).
        for component in self.components:
            if isinstance(component, BlendPhaseFunction):
                component.geometry = self.geometry
            component.update()

    @cache_by_id
    def _eval_conditional_weights_impl(self, si: SpectralIndex) -> np.ndarray:
        """
        Memoised evaluation of the conditional weights of all blend levels.

        Weights are obtained by calling each callable with ``si`` if they are
        defined as a list of callables, or taken from the stored array
        otherwise. For blend level ``i``, the conditional weight is the share
        of components ``i+1, ..., n-1`` in the total weight of components
        ``i, ..., n-1`` (i.e. the weight given to the second branch of the
        ``i``-th nested blend). Where all remaining weights are zero, the
        conditional weight is 0.

        Returns
        -------
        ndarray
            Array of shape (n - 1, m), where n is the number of components;
            1D weights are treated as having m = 1.
        """
        n_comp = len(self.components)

        if isinstance(self.weights, list):
            weights = np.array([w(si) for w in self.weights], dtype=np.float64)
        else:  # self.weights is an ndarray
            weights = np.array(self.weights, dtype=np.float64)

        if weights.ndim < 2:
            weights = weights.reshape((-1, 1))

        result = np.zeros((n_comp - 1, *weights.shape[1:]), dtype=np.float64)

        for i in range(n_comp - 1):
            # Normalize the weights of the components not yet consumed by
            # the outer blend levels; zero-sum cells stay at zero
            remaining = weights[i:, ...]
            remaining_sum = remaining.sum(axis=0, keepdims=True)
            normalized = np.divide(
                remaining,
                remaining_sum,
                where=remaining_sum != 0.0,
                out=np.zeros_like(remaining),
            )
            # Aggregate weights of all remaining components except the first one
            result[i] = normalized[1:, ...].sum(axis=0)

        return result

    def eval_conditional_weights(
        self,
        si: SpectralIndex,
        n_component: int | list[int] | None = None,
    ) -> np.ndarray:
        """
        Evaluate the conditional weights of specified Mitsuba phase function
        components.

        Parameters
        ----------
        si : :class:`.SpectralIndex`
            Spectral context.

        n_component : int or list of int, optional
            The index (or indices) of the blend level(s) for which the
            conditional weight should be evaluated. Valid indices range from
            0 to n - 2, where n is the number of components. If ``None``, the
            conditional weights of all n - 1 blend levels are evaluated.

        Returns
        -------
        ndarray
            Conditional weights of the selected levels as an array of shape
            (k, m), where k is the number of selected levels and m the number
            of cells along the atmosphere's vertical axis (1 for 1D weights).
        """
        if n_component is None:
            n_component = list(range(len(self.components) - 1))
        elif isinstance(n_component, int):
            n_component = [n_component]

        # Compute normalized component weights (memoised by cache_by_id)
        weights = self._eval_conditional_weights_impl(si)

        # Fancy indexing returns a copy: callers cannot alter the cached array
        return weights[n_component, ...]

    def _reshaped_conditional_weights(
        self, si: SpectralIndex, n_component: int, shape: tuple[int, ...]
    ) -> np.ndarray:
        """
        Conditional weight of one blend level, reshaped to ``shape`` and cast
        to float32 for use as Mitsuba volume data.
        """
        return np.reshape(
            self.eval_conditional_weights(si, n_component), shape
        ).astype(np.float32)

    @property
    def template(self) -> dict:
        n_comp = len(self.components)
        result = {"type": "blendphase"}

        for i in range(n_comp - 1):
            prefix = "phase_1." * i

            # Component i is the first branch of the i-th nested blend; the
            # second branch is declared as another blend (for the last level,
            # this type is overwritten by the last component below)
            comp_template, _ = traverse(self.components[i])
            result.update(
                {f"{prefix}phase_0.{k}": v for k, v in comp_template.items()}
            )
            result[f"{prefix}phase_1.type"] = "blendphase"

            # Assign conditional weight to second component. Passing i as the
            # kwarg default value is essential to force the dereferencing of
            # the loop variable (closures bind late).
            if self.geometry is None or isinstance(
                self.geometry, PlaneParallelGeometry
            ):

                def eval_weight_grid(ctx: KernelContext, n_component=i):
                    # Mind dim ordering! (C-style, i.e. zyx)
                    return mi.VolumeGrid(
                        self._reshaped_conditional_weights(
                            ctx.si, n_component, (-1, 1, 1)
                        )
                    )

                result[f"{prefix}weight.type"] = "gridvolume"
                result[f"{prefix}weight.grid"] = InitParameter(eval_weight_grid)
                if self.geometry is not None:
                    result[
                        f"{prefix}weight.to_world"
                    ] = self.geometry.atmosphere_volume_to_world

            elif isinstance(self.geometry, SphericalShellGeometry):

                def eval_weight_grid(ctx: KernelContext, n_component=i):
                    # Mind dim ordering! (C-style, i.e. zyx)
                    return mi.VolumeGrid(
                        self._reshaped_conditional_weights(
                            ctx.si, n_component, (1, 1, -1)
                        )
                    )

                result[f"{prefix}weight.type"] = "sphericalcoordsvolume"
                result[f"{prefix}weight.volume.type"] = "gridvolume"
                result[f"{prefix}weight.volume.grid"] = InitParameter(
                    eval_weight_grid
                )
                result[
                    f"{prefix}weight.to_world"
                ] = self.geometry.atmosphere_volume_to_world
                result[f"{prefix}weight.rmin"] = self.geometry.atmosphere_volume_rmin

            else:
                raise ValueError(
                    f"unhandled scene geometry type '{type(self.geometry).__name__}'"
                )

        # The last component fills the second branch of the innermost blend
        last_prefix = "phase_1." * (n_comp - 1)
        comp_template, _ = traverse(self.components[-1])
        result.update({f"{last_prefix}{k}": v for k, v in comp_template.items()})

        return result

    @property
    def params(self) -> dict[str, UpdateParameter]:
        n_comp = len(self.components)
        result = {}

        for i in range(n_comp - 1):
            prefix = "phase_1." * i

            # Add components
            _, comp_params = traverse(self.components[i])
            result.update(
                {f"{prefix}phase_0.{k}": v for k, v in comp_params.items()}
            )

            # Assign conditional weight to second component (i is bound as a
            # kwarg default to avoid late binding of the loop variable)
            if self.geometry is None or isinstance(
                self.geometry, PlaneParallelGeometry
            ):

                def eval_weight_data(ctx: KernelContext, n_component=i):
                    # Mind dim ordering! (C-style, i.e. zyxc)
                    return self._reshaped_conditional_weights(
                        ctx.si, n_component, (-1, 1, 1, 1)
                    )

                result[f"{prefix}weight.data"] = UpdateParameter(
                    eval_weight_data,
                    UpdateParameter.Flags.SPECTRAL,
                )

            elif isinstance(self.geometry, SphericalShellGeometry):

                def eval_weight_data(ctx: KernelContext, n_component=i):
                    # Mind dim ordering! (C-style, i.e. zyxc)
                    return self._reshaped_conditional_weights(
                        ctx.si, n_component, (1, 1, -1, 1)
                    )

                result[f"{prefix}weight.volume.data"] = UpdateParameter(
                    eval_weight_data,
                    UpdateParameter.Flags.SPECTRAL,
                )

            else:
                raise NotImplementedError(
                    f"unhandled scene geometry type '{type(self.geometry).__name__}'"
                )

        # The last component fills the second branch of the innermost blend
        last_prefix = "phase_1." * (n_comp - 1)
        _, comp_params = traverse(self.components[-1])
        result.update({f"{last_prefix}{k}": v for k, v in comp_params.items()})

        return result