Source code for eradiate_disort._backend

# SPDX-FileCopyrightText: 2026 Rayference
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import logging
from collections.abc import Generator, Hashable
from typing import TYPE_CHECKING

import attrs
import eradiate
import nanodisort as nd
import numpy as np
import tqdm.auto as tqdm
import xarray as xr
from eradiate import KernelContext, config
from eradiate.exceptions import UnsupportedModeError
from eradiate.experiments import AtmosphereExperiment
from eradiate.scenes.atmosphere import (
    HeterogeneousAtmosphere,
    HomogeneousAtmosphere,
    MolecularAtmosphere,
    ParticleLayer,
)
from eradiate.scenes.bsdfs import LambertianBSDF
from eradiate.scenes.illumination import DirectionalIllumination
from eradiate.units import unit_registry as ureg

from ._measurements import DisortMeasure
from ._phase import get_phase, get_pmom
from ._pipeline import build_disort_pipeline, compute_measures_info

if TYPE_CHECKING:
    from eradiate.spectral import CKDSpectralGrid, MonoSpectralGrid, SpectralIndex

logger = logging.getLogger(__name__)


@attrs.define
class DisortBackend:
    """
    Eradiate DISORT backend.

    This class implements an experimental Eradiate radiometric backend that
    uses the CDISORT implementation of the DISORT algorithm. It supports 1D
    scenes with atmospheres featuring an arbitrary number of components and can
    generally be used as a fast alternative to the Monte Carlo ray tracing
    backend on plane-parallel geometries.

    Typical usage is a single call to :meth:`run`, which chains
    :meth:`validate`, :meth:`process` and :meth:`postprocess`.

    Parameters
    ----------
    nstr : int, default: 16
        Number of streams (angular discretization).

    nmom : int, default: 16
        Number of Legendre moments used to represent scattering distributions
        (phase functions and BRDFs).

    intensity_correction : {"nakajima_tanaka", "buras_emde"}, default: "buras_emde"
        Intensity correction method. ``"nakajima_tanaka"`` uses only Legendre
        moments and is always available. ``"buras_emde"`` additionally
        requires the actual phase function values and is more accurate for
        sharply peaked phase functions.

    verbose : bool, default: False
        If ``False``, silence CDISORT terminal output.

    See Also
    --------
    nanodisort.DisortState
    """

    # Public configuration (field order defines the positional __init__ order)
    nstr: int = attrs.field(default=16, repr=False)
    nmom: int = attrs.field(default=16, repr=False)
    intensity_correction: str = attrs.field(
        default="buras_emde",
        validator=attrs.validators.in_(["nakajima_tanaka", "buras_emde"]),
        repr=False,
    )
    verbose: bool = attrs.field(default=False, repr=False)

    # Solver state object, configured and reused across the spectral loop
    _state: nd.DisortState = attrs.field(factory=nd.DisortState, repr=False)
    # Raw solver outputs of the last process() call, keyed by spectral index
    _results: dict[Hashable, dict] = attrs.field(factory=dict, repr=False)
    # Populated by process() ("measures_info" is stored during the first
    # spectral iteration, the rest at the end); consumed by postprocess()
    _postprocess_ctx: dict = attrs.field(factory=dict, repr=False)
    # Label displayed in the spectral loop progress bar
    _name: str = "CDISORT"

    def validate(self, exp: AtmosphereExperiment) -> None:
        """
        Check internal state consistency and compatibility with the passed
        Experiment configuration.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        Raises
        ------
        TypeError
            If validation fails.
        """
        # Illumination: only directional illumination is supported
        if not isinstance(exp.illumination, DirectionalIllumination):
            raise TypeError(
                "EradiateDisortBackend requires a DirectionalIllumination, "
                f"got {type(exp.illumination).__name__}"
            )

        # Measures: only DisortMeasure instances
        for measure in exp.measures:
            if not isinstance(measure, DisortMeasure):
                raise TypeError(
                    "EradiateDisortBackend requires DisortMeasure instances, "
                    f"got {type(measure).__name__}"
                )

        # At most one measure with a direction layout (DISORT has a single
        # umu/phi grid)
        radiance_measures = [m for m in exp.measures if m.direction_layout is not None]
        if len(radiance_measures) > 1:
            raise TypeError(
                "EradiateDisortBackend supports at most one radiance-mode "
                f"DisortMeasure per run (found {len(radiance_measures)})"
            )

        # Surface: only Lambertian (diffuse) BSDF is supported
        if not isinstance(exp.surface.bsdf, LambertianBSDF):
            raise TypeError(
                "EradiateDisortBackend requires a Lambertian surface BSDF, "
                f"got {type(exp.surface.bsdf).__name__}"
            )

        # Atmosphere: either no atmosphere at all, or one of the types below
        allowed = (
            MolecularAtmosphere,
            ParticleLayer,
            HeterogeneousAtmosphere,
            HomogeneousAtmosphere,
        )
        if exp.atmosphere is not None and not isinstance(exp.atmosphere, allowed):
            allowed_names = ", ".join(cls.__name__ for cls in allowed)
            raise TypeError(
                "EradiateDisortBackend requires one of "
                f"[{allowed_names}], "
                f"got {type(exp.atmosphere).__name__}"
            )

    def _setup_global(
        self, exp: AtmosphereExperiment, ref_ctx: KernelContext | None = None
    ) -> dict:
        """
        Perform global setup that does not depend on the spectral dimension.
        Called once at the beginning of :meth:`.process`.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        ref_ctx : KernelContext, optional
            Reference spectral context used for initialization. Required only
            when using the Buras-Emde intensity correction.

        Returns
        -------
        dict
            Run context with transient processing state. Keys:

            - ``has_radiance`` : bool
            - ``active_measures`` : list
            - ``mes_mu`` : ndarray — sorted unique cosines for DISORT
            - ``mes_phi`` : ndarray — azimuth angles [deg] for DISORT
            - ``ill_mu`` : float — illumination cosine
            - ``ill_phi`` : float — illumination azimuth [deg]

        Raises
        ------
        RuntimeError
            If the Buras-Emde correction is selected and ``ref_ctx`` is
            ``None``.
        """
        logger.debug("EradiateDisortBackend: Global setup")
        ds = self._state

        # Classify active measures
        measures = list(exp.measures)
        has_radiance = any(m.direction_layout is not None for m in measures)

        # Illumination angles
        illumination = exp.illumination
        ill_mu = np.cos(illumination.zenith.m_as("rad"))
        # DISORT phi0 is the azimuth of the beam's travel direction; Eradiate's
        # illumination.azimuth is the sun's source direction — opposite by 180°.
        ill_phi = (illumination.azimuth.m_as("deg") + 180.0) % 360.0

        # Control flags
        ds.quiet = not self.verbose
        ds.usrtau = True
        ds.lamber = True
        ds.planck = False
        ds.usrang = has_radiance
        ds.onlyfl = not has_radiance

        # Intensity correction method
        if self.intensity_correction == "buras_emde":
            if ref_ctx is None:
                raise RuntimeError(
                    "Buras-Emde correction requires a reference spectral context "
                    "to size the phase angle grid before allocation."
                )
            # NOTE(review): exp.atmosphere may be None here, whereas
            # _setup_spectral only evaluates get_phase() when an atmosphere is
            # present — confirm that get_phase() accepts None.
            mu_grid, _ = get_phase(exp.atmosphere, self.nstr, ref_ctx)
            # +2 accounts for sentinel points added at both ends in _setup_spectral
            ds.nphase = len(mu_grid) + 2
            ds.intensity_correction = True
            ds.old_intensity_correction = False
        else:  # "nakajima_tanaka"
            ds.nphase = 0
            ds.intensity_correction = True
            ds.old_intensity_correction = True

        # Atmosphere layer count. The explicit None test mirrors the check used
        # in _setup_spectral, so that nlyr always matches the arrays built there.
        ds.nlyr = (
            exp.atmosphere.geometry.zgrid.n_layers
            if exp.atmosphere is not None
            else 1
        )
        ds.nstr = self.nstr
        ds.nmom = self.nmom

        # Viewing angle setup (for radiance measure)
        if has_radiance:
            rad_measure = next(m for m in measures if m.direction_layout is not None)
            mes_angles = rad_measure.direction_layout.angles
            # Fold negative zenith angles onto positive ones by flipping the
            # azimuth by 180°; work on a copy to leave the layout untouched.
            mask = mes_angles[:, 0] < 0
            mes_angles = mes_angles.copy()
            mes_angles[mask, 0] *= -1.0
            mes_angles[mask, 1] = mes_angles[mask, 1] + 180.0 * ureg.deg
            mes_angles[:, 1] %= 360.0 * ureg.deg
            # np.unique() already returns its values sorted in ascending order
            mes_mu = np.unique(np.cos(mes_angles[:, 0].m_as("rad")))
            mes_phi = np.unique(mes_angles[:, 1].m_as("deg"))
            ds.numu = len(mes_mu)
            ds.nphi = len(mes_phi)
        else:
            # DISORT needs at least one umu/phi even with onlyfl=True
            ds.numu = 1
            ds.nphi = 1
            mes_mu = np.array([1.0])  # nadir
            mes_phi = np.array([0.0])

        return {
            "has_radiance": has_radiance,
            "active_measures": measures,
            "mes_mu": mes_mu,
            "mes_phi": mes_phi,
            "ill_mu": ill_mu,
            "ill_phi": ill_phi,
        }

    def _setup_spectral(
        self,
        exp: AtmosphereExperiment,
        ctx: KernelContext,
        run_ctx: dict,
        first_call: bool = False,
    ) -> None:
        """
        Perform setup that depends on the spectral dimension. Called at each
        iteration of the spectral loop.

        All spectral quantities are computed first; the DISORT state is only
        written to afterwards, because its arrays do not exist before the
        allocation triggered by ``first_call``.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        ctx : KernelContext
            Current spectral context.

        run_ctx : dict
            Run context produced by :meth:`._setup_global`.

        first_call : bool
            If ``True``, allocates DISORT memory (must be called exactly once,
            before :meth:`._solve`).
        """
        logger.debug("EradiateDisortBackend: Spectral loop setup")
        ds = self._state
        atmosphere = exp.atmosphere

        # --- Compute values (no array assignments yet — memory may not be allocated)
        if atmosphere is not None:
            h = atmosphere.geometry.zgrid.layer_height
            sigma_t = atmosphere.eval_sigma_t(ctx.si)
            tau_btt = np.atleast_1d((sigma_t * h).m_as("dimensionless"))

            # Keep the single-scattering albedo strictly below 1
            _dither = 100.0 * np.finfo(float).eps
            ssalb = np.atleast_1d(atmosphere.eval_albedo(ctx.si).m_as("dimensionless"))
            ssalb = np.minimum(ssalb, 1.0 - _dither)

            zgrid = atmosphere.geometry.zgrid
            # Homogeneous atmospheres return scalar optical properties; broadcast
            # them across all layers so the arrays match DISORT's nlyr.
            n_layers = zgrid.n_layers
            if tau_btt.shape[0] != n_layers:
                tau_btt = np.broadcast_to(tau_btt, (n_layers,)).copy()
            if ssalb.shape[0] != n_layers:
                ssalb = np.broadcast_to(ssalb, (n_layers,)).copy()
        else:
            tau_btt = np.array([0.0])
            ssalb = np.array([0.0])
            # Minimal two-level zgrid for altitude resolution when no atmosphere
            from eradiate.radprops import ZGrid

            zgrid = ZGrid([0.0, 1.0] * ureg.m)

        # Second axis is reversed, like the other per-layer arrays below
        pmom = get_pmom(atmosphere, ds.nmom, ctx)[:, ::-1]

        buras_emde_arrays = None
        if self.intensity_correction == "buras_emde" and atmosphere is not None:
            mu_grid, phase_tbt = get_phase(atmosphere, ds.nstr, ctx)
            # Pad the angular grid with one sentinel point just outside each end
            # of [-1, 1], duplicating the edge phase values (hence the "+2" on
            # nphase in _setup_global).
            _eps = 1e-10
            mu_padded = np.concatenate([[-1.0 - _eps], mu_grid, [1.0 + _eps]])
            phase_padded = np.hstack([phase_tbt[:, :1], phase_tbt, phase_tbt[:, -1:]])
            buras_emde_arrays = (mu_padded, np.ascontiguousarray(phase_padded[::-1, :]))

        # Merged utau must be computed before allocate() so ntau is known
        merged_utau, measures_info = compute_measures_info(
            run_ctx["active_measures"],
            tau_btt,
            zgrid,
        )

        irradiance = exp.illumination.irradiance.eval(ctx.si).m_as("W/m^2/nm")
        albedo = exp.surface.bsdf.reflectance.eval(ctx.si).m_as("dimensionless")

        # --- Allocate on first call, then assign all arrays
        if first_call:
            ds.ntau = len(merged_utau)
            # NOTE(review): a second process() call on the same backend
            # re-runs allocate() on an already allocated state — confirm that
            # nanodisort supports this.
            ds.allocate()
            # Save structural info for post-processing (fixed across spectral loop)
            self._postprocess_ctx["measures_info"] = measures_info

        if atmosphere is not None:
            ds.dtauc = tau_btt[::-1]  # DISORT expects top-to-bottom
            ds.ssalb = ssalb[::-1]
        else:
            ds.dtauc = tau_btt
            ds.ssalb = ssalb
        ds.pmom = pmom
        if buras_emde_arrays is not None:
            ds.mu_phase, ds.phase = buras_emde_arrays

        # A flux-only solve (usrang=False) makes CDISORT overwrite numu with
        # nstr internally; restore the user dimensions before re-assigning the
        # angular arrays on subsequent spectral iterations.
        ds.numu = len(run_ctx["mes_mu"])
        ds.nphi = len(run_ctx["mes_phi"])
        ds.umu = run_ctx["mes_mu"]
        ds.phi = run_ctx["mes_phi"]
        ds.utau = merged_utau

        ds.fbeam = irradiance
        ds.umu0 = run_ctx["ill_mu"]
        ds.phi0 = run_ctx["ill_phi"]
        ds.albedo = albedo
        ds.fisot = 0.0
        ds.fluor = 0.0

    def _solve(self) -> None:
        """Run the DISORT solver."""
        logger.debug("EradiateDisortBackend: Running DISORT solver")
        self._state.solve()

    def _collect_results(self, run_ctx: dict) -> dict:
        """
        Collect all relevant outputs from the DISORT state.

        Each output is copied into a fresh NumPy array so that it is not
        affected by the next spectral iteration, which reuses the same state.

        Parameters
        ----------
        run_ctx : dict
            Run context produced by :meth:`._setup_global`.

        Returns
        -------
        dict
            Solver outputs keyed by their DISORT name. The ``"uu"`` entry is
            ``None`` when no radiance measure is active.
        """
        ds = self._state
        return {
            "uu": np.array(ds.uu) if run_ctx["has_radiance"] else None,
            "rfldir": np.array(ds.rfldir),
            "rfldn": np.array(ds.rfldn),
            "flup": np.array(ds.flup),
            "dfdt": np.array(ds.dfdt),
            "uavg": np.array(ds.uavg),
            "uavgdn": np.array(ds.uavgdn),
            "uavgup": np.array(ds.uavgup),
            "uavgso": np.array(ds.uavgso),
        }

    def _get_spectral_indices(
        self, exp: AtmosphereExperiment, measure_index: int
    ) -> Generator[SpectralIndex]:
        """
        Yield the spectral indices of the spectral grid associated with a
        measure, according to the active Eradiate mode.

        This is a generator: the mode is only inspected (and
        ``UnsupportedModeError`` raised) once iteration starts.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        measure_index : int
            Index of the measure whose spectral grid is walked.

        Yields
        ------
        SpectralIndex

        Raises
        ------
        UnsupportedModeError
            If the active mode is neither monochromatic nor CKD.
        """
        # Same mode accessor as postprocess(), queried once
        mode = eradiate.get_mode()

        if mode.is_mono:
            mono_grid: MonoSpectralGrid = exp.spectral_grids[measure_index]
            yield from mono_grid.walk_indices()

        elif mode.is_ckd:
            ckd_grid: CKDSpectralGrid = exp.spectral_grids[measure_index]
            quad_config = exp.ckd_quad_config
            try:
                abs_db = exp.atmosphere.abs_db
            except AttributeError:
                # There is either no atmosphere or no absorption database
                abs_db = None
            yield from ckd_grid.walk_indices(quad_config, abs_db)

        else:
            raise UnsupportedModeError

    def _get_contexts(
        self, exp: AtmosphereExperiment, measure_index: int = 0
    ) -> list[KernelContext]:
        """
        Return the list of spectral contexts for the given measure.

        Uses the spectral grid directly (no Mitsuba kernel required).
        """
        return [
            KernelContext(si) for si in self._get_spectral_indices(exp, measure_index)
        ]

    def process(
        self, exp: AtmosphereExperiment, measure: None | int | str = None
    ) -> None:
        """
        Run the processing step for a given Experiment configuration.

        All measures in the experiment are processed together in a single
        spectral loop. The spectral grid is taken from ``measure`` (default:
        the first measure). Raw results and the data needed by
        :meth:`postprocess` are stored on the backend instance.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        measure : int or str, optional
            Index or string ID of the measure whose spectral grid drives the
            loop. Defaults to the first measure.
        """
        exp.init()

        if measure is None:
            measure_index = 0
        else:
            m = exp.measures.resolve(measure)
            measure_index = exp.measures.get_index(m.id)

        ctxs = self._get_contexts(exp, measure_index)
        ref_ctx = ctxs[0] if ctxs else None
        run_ctx = self._setup_global(exp, ref_ctx=ref_ctx)

        results = {}

        with tqdm.tqdm(
            initial=0,
            total=len(ctxs),
            unit_scale=1.0,
            leave=True,
            bar_format="{desc}{n:g}/{total:g}|{bar}| {elapsed}, ETA={remaining}",
            disable=(config.settings.progress < config.ProgressLevel.SPECTRAL_LOOP)
            or len(ctxs) <= 1,
        ) as pbar:
            for i, ctx in enumerate(ctxs):
                pbar.set_description(
                    f"Spectral loop — {self._name} [{ctx.index_formatted}]"
                )
                # DISORT memory is allocated during the first iteration only
                self._setup_spectral(exp, ctx, run_ctx, first_call=(i == 0))
                self._solve()
                results[ctx.si.as_hashable] = self._collect_results(run_ctx)
                pbar.update()

        self._results = results

        # Store everything postprocess() needs
        self._postprocess_ctx.update(
            {
                "geometry": {
                    "umu": np.array(self._state.umu),
                    "phi": np.array(self._state.phi),
                    "umu0": float(self._state.umu0),
                    "phi0": float(self._state.phi0),
                },
                "spectral_grid": exp.spectral_grids[measure_index],
                "ckd_quads": exp.ckd_quads[measure_index],
            }
        )

    def postprocess(
        self, exp: AtmosphereExperiment, measure: None | int | str = None
    ) -> xr.DataTree:
        """
        Run the postprocessing step and return a DataTree.

        Returns a DataTree with one subtree per measure, keyed by measure ID.
        The ``measure`` argument is accepted for API compatibility but ignored;
        all measures are always included in the output.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        measure : int or str, optional
            Ignored. Present for API compatibility.

        Returns
        -------
        DataTree
            One subtree per measure, keyed by ``/{measure.id}/``.

        Raises
        ------
        UnsupportedModeError
            If the active mode is neither monochromatic nor CKD.

        RuntimeError
            If the data stored by :meth:`process` is missing, i.e.
            :meth:`process` has not been run on this backend.
        """
        mode = eradiate.get_mode()
        if mode.is_mono:
            mode_id = "mono"
        elif mode.is_ckd:
            mode_id = "ckd"
        else:
            raise UnsupportedModeError

        ctx = self._postprocess_ctx

        # Fail with an explicit message rather than a bare KeyError when
        # process() has not populated the post-processing context yet.
        required = ("spectral_grid", "ckd_quads", "geometry", "measures_info")
        missing = [key for key in required if key not in ctx]
        if missing:
            raise RuntimeError(
                "DisortBackend.postprocess() requires process() to be run first "
                f"(missing post-processing context entries: {missing})"
            )

        pipeline = build_disort_pipeline()
        result = pipeline.execute(
            outputs=["datatree"],
            inputs={
                "raw_results": self._results,
                "mode": mode_id,
                "spectral_grid": ctx["spectral_grid"],
                "ckd_quads": ctx["ckd_quads"],
                "geometry": ctx["geometry"],
                "measures_info": ctx["measures_info"],
            },
        )
        return result["datatree"]

    def run(
        self, exp: AtmosphereExperiment, measure: None | int | str = None
    ) -> xr.DataTree:
        """
        Run validation, processing, and postprocessing in sequence.

        Parameters
        ----------
        exp : AtmosphereExperiment
            Processed experiment configuration.

        measure : int or str, optional
            Index or string ID of the measure whose spectral grid drives the
            processing loop. Defaults to the first measure.

        Returns
        -------
        DataTree
            Post-processed results, one subtree per measure.

        Raises
        ------
        TypeError
            If the experiment is not supported by this backend (see
            :meth:`validate`).
        """
        self.validate(exp)
        self.process(exp, measure=measure)
        return self.postprocess(exp, measure=measure)