Source code for eradiate.scenes.phase._tabulated

import attrs
import numpy as np
import pint
import xarray as xr

import eradiate

from ._core import PhaseFunction
from ..core import KernelDict
from ...attrs import documented, parse_docs
from ...ckd import Bindex
from ...contexts import KernelDictContext, SpectralContext
from ...exceptions import UnsupportedModeError
from ...units import unit_registry as ureg


def _ensure_magnitude_array(q: pint.Quantity) -> pint.Quantity:
    """Convert to Quantity whose magnitude is an array"""
    if np.isscalar(q.magnitude):
        return ureg.Quantity(np.atleast_1d(q.magnitude), q.units)
    else:
        return q


def _convert_data(da: xr.DataArray) -> xr.DataArray:
    """
    If the coordinate mu is not strictly increasing, return a reindexed
    copy of the DataArray where it is.
    Else, return the input DataArray.
    """
    if np.all(da["mu"].values[1:] > da["mu"].mu.values[:-1]):
        return da
    elif np.all(da["mu"].values[1:] < da["mu"].mu.values[:-1]):
        return da.reindex(mu=da.mu[::-1])
    else:
        return da.reindex(mu=np.sort(da.mu.values))


def _validate_data(instance, attribute, value):
    """
    Bounds of mu coordinate must be [-1.0, 1.0].
    """
    mu_min = value["mu"].values.min()
    mu_max = value["mu"].values.max()

    if not (mu_min == -1.0 and mu_max == 1.0):
        raise ValueError(
            f"Coordinate 'mu' bounds must be -1.0, 1.0 (got {mu_min}, {mu_max})."
        )


[docs]@parse_docs @attrs.define class TabulatedPhaseFunction(PhaseFunction): r""" Tabulated phase function [``tab_phase``]. A lookup table-based phase function. The ``data`` field is a :class:`~xarray.DataArray` with wavelength and angular dimensions. Notes ----- * The :math:`\mu` coordinate must cover the :math:`[-1, 1]` interval but there is no constraint on value ordering or spacing. In particular, irregular :math:`\mu` grids are supported. * For optimal performance, providing phase function data on a regular, sorted :math:`\mu` grid is recommended. """ data: xr.DataArray = documented( attrs.field( converter=_convert_data, validator=[attrs.validators.instance_of(xr.DataArray), _validate_data], kw_only=True, ), type="DataArray", doc="Value table as a data array with wavelength (``w``), scattering " "angle cosine (``mu``), and scattering phase matrix row " "(``i``) and column (``j``) indices (integer) as coordinates. " "This parameter has no default.", )
[docs] def eval(self, spectral_ctx: SpectralContext) -> np.ndarray: r""" Evaluate phase function based on a spectral context. This method dispatches evaluation to specialised methods depending on the active mode. Parameters ---------- spectral_ctx : :class:`.SpectralContext` A spectral context data structure containing relevant spectral parameters (*e.g.* wavelength in monochromatic mode, bin and quadrature point index in CKD mode). Returns ------- ndarray Evaluated phase function as a 1D array. """ if eradiate.mode().is_mono: return self.eval_mono(spectral_ctx.wavelength).squeeze() elif eradiate.mode().is_ckd: return self.eval_ckd(spectral_ctx.bindex).squeeze() else: raise UnsupportedModeError(supported=("monochromatic", "ckd"))
[docs] def eval_mono(self, w: pint.Quantity) -> np.ndarray: """ Evaluate phase function in monochromatic modes. Parameters ---------- w : :class:`pint.Quantity` Wavelength values at which the phase function is to be evaluated. Returns ------- ndarray Evaluated phase function as a 1D or 2D array depending on the shape of `w` (angle dimension comes last). """ w_units = self.data.w.attrs["units"] return ( self.data.isel(i=0, j=0) .interp( w=w.m_as(w_units), kwargs=dict(bounds_error=True), ) .data )
[docs] def eval_ckd(self, *bindexes: Bindex) -> np.ndarray: """ Evaluate phase function in CKD modes. Parameters ---------- *bindexes : :class:`.Bindex` One or several CKD bindexes for which to evaluate the phase function. Returns ------- ndarray Evaluated phase function as a 1D or 2D array depending on the number of passed bindexes (angle dimension comes last). """ w_units = ureg(self.data.w.units) w = [bindex.bin.wcenter.m_as(w_units) for bindex in bindexes] * w_units return self.eval_mono(w)
[docs] def kernel_dict(self, ctx: KernelDictContext) -> KernelDict: # Evaluate phase function phase_values = self.eval(ctx.spectral_ctx) # Retrieve mu grid mu = self.data.mu.values dmu = mu[1:] - mu[:-1] # Create kernel dict if np.allclose(dmu, dmu[0]): # mu grid is regular return KernelDict( { self.id: { "type": "tabphase", "values": ",".join(map(str, phase_values)), } } ) else: # mu grid is irregular return KernelDict( { self.id: { "type": "tabphase_irregular", "values": ",".join(map(str, phase_values)), "nodes": ",".join(map(str, mu)), } } )