from __future__ import annotations
import itertools
import typing as t
from abc import ABC, abstractmethod
from functools import singledispatchmethod
import numpy as np
import numpy.typing as npt
import pint
import pinttrs
from pinttrs.util import ensure_units
from .ckd_quad import CKDQuadConfig, CKDQuadPolicy
from .index import CKDSpectralIndex, MonoSpectralIndex, SpectralIndex
from .response import BandSRF, DeltaSRF, SpectralResponseFunction, UniformSRF
from .. import converters
from .._mode import ModeFlag, SubtypeDispatcher
from ..attrs import define, documented
from ..constants import SPECTRAL_RANGE_MAX, SPECTRAL_RANGE_MIN
from ..quad import Quad
from ..radprops import AbsorptionDatabase, CKDAbsorptionDatabase, MonoAbsorptionDatabase
from ..units import unit_context_config as ucc
from ..units import unit_registry as ureg
from ..util.misc import deduplicate_sorted, summary_repr
# ------------------------------------------------------------------------------
# Class implementations
# ------------------------------------------------------------------------------
@define
class SpectralGrid(ABC):
    """
    Abstract interface for all spectral grids.

    The static constructors of this class (:meth:`default`, :meth:`arange`,
    :meth:`from_absorption_database`) resolve the concrete implementation
    registered for the currently active mode and delegate to it.
    """

    # Registry mapping spectral modes to concrete spectral grid implementations
    subtypes = SubtypeDispatcher("SpectralGrid")

    @property
    @abstractmethod
    def wavelengths(self):
        """
        Convenience accessor to characteristic wavelengths of this spectral grid.
        """
        pass

    @staticmethod
    def default() -> SpectralGrid:
        """
        Generate a default spectral grid depending on the active mode.
        """
        cls = SpectralGrid.subtypes.resolve()
        return cls.default()

    @staticmethod
    def arange(
        start: float | pint.Quantity,
        stop: float | pint.Quantity,
        step: float | pint.Quantity,
    ) -> SpectralGrid:
        """
        Generate a spectral grid from equally-spaced wavelengths. The returned
        type depends on the currently active mode.

        Parameters
        ----------
        start : quantity or float
            Central wavelength of the first bin. If a unitless value is passed,
            it is interpreted in default wavelength units (usually nm).

        stop : quantity or float
            Wavelength after which bin generation stops. If a unitless value is
            passed, it is interpreted in default wavelength units (usually nm).

        step : quantity or float
            Spectral bin size. If a unitless value is passed, it is interpreted
            in default wavelength units (usually nm).

        Returns
        -------
        SpectralGrid
            Generated spectral grid.
        """
        cls = SpectralGrid.subtypes.resolve()
        return cls.arange(start, stop, step)

    @staticmethod
    def from_absorption_database(abs_db: AbsorptionDatabase) -> SpectralGrid:
        """
        Retrieve the spectral grid from an absorption database. The returned
        type depends on the currently active mode.
        """
        cls = SpectralGrid.subtypes.resolve()
        return cls.from_absorption_database(abs_db)

    def select(self, srf) -> SpectralGrid:
        """
        Select a subset of the spectral grid based on a spectral response
        function.

        Parameters
        ----------
        srf
            A value that is either a :class:`.SpectralResponseFunction` instance
            or convertible to a :class:`.SpectralResponseFunction` by the
            :meth:`.SpectralResponseFunction.convert` method.

        Returns
        -------
        SpectralGrid
            New spectral grid instance covering the extent of the filtering SRF.

        Notes
        -----
        The implementation of this method uses single dispatch based on the type
        of the ``srf`` parameter.
        """
        # Normalize the input to an SRF instance, then hand over to the
        # type-dispatching implementation provided by subclasses
        srf = SpectralResponseFunction.convert(srf)
        return self._select_impl(srf)

    @abstractmethod
    def _select_impl(self, srf: SpectralResponseFunction) -> SpectralGrid:
        # Subclass hook called by select() with an already-converted SRF
        pass

    @abstractmethod
    def merge(self, other: SpectralGrid) -> SpectralGrid:
        """
        Merge two spectral grids, applying a boolean "OR" operation.

        Parameters
        ----------
        other : SpectralGrid
            Other spectral grid, of the same type, to merge with the current
            one.

        Returns
        -------
        SpectralGrid
            A new spectral grid of the same type that merges the two.
        """
        pass

    @abstractmethod
    def walk_indices(self, **kwargs) -> t.Generator[SpectralIndex, None, None]:
        """
        A generator that yields a sequence of spectral index values.

        Parameters
        ----------
        **kwargs
            Implementation-specific keyword arguments.

        Yields
        ------
        .SpectralIndex
            Generated spectral index of a type aligned with the current active
            mode.
        """
        pass
@SpectralGrid.subtypes.register(ModeFlag.SPECTRAL_MODE_MONO)
@define
class MonoSpectralGrid(SpectralGrid):
    """
    A spectral grid consisting of discrete wavelengths, used in monochromatic
    modes.

    Wavelengths are converted to default wavelength config units and stored
    as a sorted 1D array of unique 64-bit floating-point values.
    """

    _wavelengths: pint.Quantity = documented(
        pinttrs.field(
            units=ucc.deferred("wavelength"),
            converter=[
                pinttrs.converters.to_units(ucc.deferred("wavelength")),
                converters.on_quantity(np.atleast_1d),
                converters.on_quantity(lambda x: x.astype(np.float64)),
                # np.unique returns sorted values: no separate sort is needed
                converters.on_quantity(np.unique),
            ],
            repr=summary_repr,
        ),
        doc="Wavelengths.",
        type="quantity",
        init_type="quantity or array-like or float",
    )

    @property
    def wavelengths(self):
        # Inherit docstring
        return self._wavelengths

    def plot(self, ax, lw=0.5, alpha=1.0):
        # Draw one vertical line per wavelength (magnitudes in stored units)
        ax.vlines(self.wavelengths.m, 0, 1, lw=lw, alpha=alpha)
        return ax

    def _repr_html_(self):
        # Rich Jupyter display: render plot() to an inline base64 PNG
        import base64
        import io

        import matplotlib.pyplot as plt
        import seaborn as sns

        fig, ax = plt.subplots(1, 1, figsize=(6, 1))
        self.plot(ax)
        ax.set_xlabel(f"Wavelength [{self.wavelengths.u:~P}]")
        sns.despine(left=True)
        ax.axes.get_yaxis().set_visible(False)

        img = io.BytesIO()
        fig.savefig(img, format="png", bbox_inches="tight")
        plt.close(fig)
        img.seek(0)

        return (
            "<img "
            f'src="data:image/png;base64, {base64.b64encode(img.getvalue()).decode("utf-8")}" '
            "/>"
        )

    @staticmethod
    def default() -> MonoSpectralGrid:
        """
        Generate a default monochromatic spectral grid that covers the default
        spectral range with 1 nm spacing.
        """
        # The +0.1 nm pad makes the upper range bound inclusive with np.arange
        return MonoSpectralGrid(
            wavelengths=np.arange(
                SPECTRAL_RANGE_MIN.m_as(ureg.nm),
                SPECTRAL_RANGE_MAX.m_as(ureg.nm) + 0.1,
                1.0,
            )
            * ureg.nm
        )

    @staticmethod
    def arange(
        start: float | pint.Quantity,
        stop: float | pint.Quantity,
        step: float | pint.Quantity,
    ) -> MonoSpectralGrid:
        """
        Generate a spectral grid from equally-spaced wavelengths.

        Parameters
        ----------
        start : quantity or float
            First wavelength of the grid. If a unitless value is passed, it is
            interpreted in default wavelength units (usually nm).

        stop : quantity or float
            Wavelength at which generation stops (excluded, as with
            :func:`numpy.arange`). If a unitless value is passed, it is
            interpreted in default wavelength units (usually nm).

        step : quantity or float
            Spacing between consecutive wavelengths. If a unitless value is
            passed, it is interpreted in default wavelength units (usually nm).

        Returns
        -------
        MonoSpectralGrid
            Generated spectral grid.
        """
        w_u = ucc.get("wavelength")
        start = ensure_units(start, w_u).m_as(w_u)
        stop = ensure_units(stop, w_u).m_as(w_u)
        step = ensure_units(step, w_u).m_as(w_u)
        return MonoSpectralGrid(wavelengths=np.arange(start, stop, step) * w_u)

    @classmethod
    def from_absorption_database(cls, abs_db: MonoAbsorptionDatabase):
        """
        Retrieve the spectral grid from a monochromatic absorption database.

        Raises
        ------
        TypeError
            If ``abs_db`` is not a :class:`.MonoAbsorptionDatabase`.
        """
        if not isinstance(abs_db, MonoAbsorptionDatabase):
            raise TypeError(
                "while creating MonoSpectralGrid from absorption database: "
                f"expected MonoAbsorptionDatabase, got {type(abs_db).__name__}"
            )

        # The second level of the spectral coverage index holds wavelength
        # values, interpreted in nm
        w = abs_db.spectral_coverage.index.get_level_values(level=1).values * ureg.nm
        return cls(wavelengths=w)

    @singledispatchmethod
    def _select_impl(self, srf: SpectralResponseFunction) -> MonoSpectralGrid:
        # Inherit docstring
        raise NotImplementedError(f"unsupported data type '{type(srf)}'")

    @_select_impl.register
    def _(self, srf: DeltaSRF):
        # Pass SRF wavelengths through
        return MonoSpectralGrid(wavelengths=srf.wavelengths)

    @_select_impl.register
    def _(self, srf: UniformSRF):
        # Keep wavelengths inside the closed interval [wmin, wmax]
        w_m = self.wavelengths.m
        w_u = self.wavelengths.u
        wmin_m, wmax_m = srf.wmin.m_as(w_u), srf.wmax.m_as(w_u)
        w_selected_m = w_m[(w_m >= wmin_m) & (w_m <= wmax_m)]
        return MonoSpectralGrid(wavelengths=w_selected_m * w_u)

    @_select_impl.register
    def _(self, srf: BandSRF):
        # Select all wavelengths for which the SRF evaluates to a nonzero value
        values = srf.eval(self.wavelengths)
        w_selected = self.wavelengths[values.m > 0.0]
        return MonoSpectralGrid(wavelengths=w_selected)

    def merge(self, other: MonoSpectralGrid) -> MonoSpectralGrid:
        # Inherit docstring
        w_u = ucc.get("wavelength")

        # Collect all wavelengths; np.unique both sorts and removes duplicates
        w_m = np.unique(
            np.concatenate((self.wavelengths.m_as(w_u), other.wavelengths.m_as(w_u)))
        )
        return MonoSpectralGrid(wavelengths=w_m * w_u)

    def walk_indices(self, **kwargs) -> t.Generator[MonoSpectralIndex, None, None]:
        """
        A generator that yields one monochromatic spectral index per wavelength
        of the grid, in increasing wavelength order.

        Parameters
        ----------
        **kwargs
            Ignored. Accepted so that mode-agnostic callers can use the keyword
            signature declared by :meth:`SpectralGrid.walk_indices`.

        Yields
        ------
        .MonoSpectralIndex
            Generated spectral index.
        """
        for w in self.wavelengths:
            yield MonoSpectralIndex(w=w)
@SpectralGrid.subtypes.register(ModeFlag.SPECTRAL_MODE_CKD)
@define(init=False)
class CKDSpectralGrid(SpectralGrid):
    """
    A spectral grid that splits the spectral dimension into bins characterized
    by their bounds.

    Parameters
    ----------
    fix_bounds : {"keep_min", "keep_max", "raise", "ignore"}, default: "keep_min"
        The constructor compares the upper bound of each bin with the lower
        bound of the next bin and flags them as mismatched if they differ by a
        nonzero amount no larger than ``epsilon`` (a small overlap or a small
        gap). If this parameter is set to ``"keep_min"``, the upper bound is
        overwritten with the next lower bound; if it is set to
        ``"keep_max"``, the next lower bound is overwritten with the upper
        bound, so that matching bounds effectively match exactly. If it is set
        to ``"raise"``, an exception is raised. If it is set to ``"ignore"``,
        no action is taken.

    epsilon : float, default: 1e-6
        Absolute tolerance for mismatched bound detection, in default
        wavelength config units.

    Raises
    ------
    ValueError
        If ``wmins``, ``wmaxs`` and (if set) ``wcenters`` do not have the same
        shape; if bound mismatch is detected and ``fix_bounds`` is set to
        ``"raise"``; or if bound mismatch is detected and ``fix_bounds`` is not
        a known policy.

    Notes
    -----
    Bins are expected to be sorted by increasing wavelength: bound mismatch
    detection compares consecutive bins and spectral selection relies on
    sorted bounds.
    """

    wmins: pint.Quantity = documented(
        pinttrs.field(units=ucc.deferred("wavelength"), repr=summary_repr),
        doc="Lower bound of all bins. Unitless values are interpreted as default "
        "wavelength config units.",
        type="quantity",
        init_type="quantity or array-like",
    )

    wmaxs: pint.Quantity = documented(
        pinttrs.field(units=ucc.deferred("wavelength"), repr=summary_repr),
        doc="Upper bound of all bins. Unitless values are interpreted as default "
        "wavelength config units.",
        type="quantity",
        init_type="quantity or array-like",
    )

    wcenters: pint.Quantity = documented(
        pinttrs.field(units=ucc.deferred("wavelength"), repr=summary_repr),
        doc="Central wavelength of all bins. Unitless values are interpreted as "
        "default wavelength config units. "
        "If unset, bin centers are computed automatically from bin bounds. "
        "Bin centers are allowed to be different from the middle of the bin "
        "interval and, when the grid is tied to a database, are expected to "
        "match the values of the wavelength coordinate in the database. However, "
        "this is considered as a workaround to deal with poorly indexed databases, "
        "and users should try to set central wavelengths to the middle of spectral "
        "bins.",
        type="quantity, optional",
        init_type="quantity or array-like",
    )

    def __init__(
        self,
        wmins: npt.ArrayLike,
        wmaxs: npt.ArrayLike,
        wcenters: npt.ArrayLike | None = None,
        fix_bounds: t.Literal["keep_min", "keep_max", "raise", "ignore"] = "keep_min",
        epsilon: float = 1e-6,
    ):
        # Ensure consistent units and appropriate dtype. np.array() always
        # copies, so the in-place bound fixes below never alter caller-owned
        # arrays; ndmin=1 lets scalar bounds define a single bin.
        w_u = ucc.get("wavelength")
        wmins_m = np.array(
            ensure_units(wmins, w_u).m_as(w_u), dtype=np.float64, ndmin=1
        )
        wmaxs_m = np.array(
            ensure_units(wmaxs, w_u).m_as(w_u), dtype=np.float64, ndmin=1
        )
        if wmins_m.shape != wmaxs_m.shape:
            raise ValueError(
                "while constructing CKDSpectralGrid: 'wmins' and 'wmaxs' must have "
                f"the same shape (got {wmins_m.shape} and {wmaxs_m.shape})"
            )

        # Detect bound mismatch between each bin's upper bound and the next
        # bin's lower bound. Floating-point noise can push the difference
        # either way, so both small overlaps and small gaps are flagged.
        diff_bounds = wmaxs_m[:-1] - wmins_m[1:]
        fix_locations = (diff_bounds != 0.0) & (np.abs(diff_bounds) <= epsilon)

        if np.any(fix_locations):
            if fix_bounds == "keep_max":
                wmins_m[1:] = np.where(fix_locations, wmaxs_m[:-1], wmins_m[1:])
            elif fix_bounds == "keep_min":
                wmaxs_m[:-1] = np.where(fix_locations, wmins_m[1:], wmaxs_m[:-1])
            elif fix_bounds == "raise":
                raise ValueError(
                    "while constructing CKDSpectralGrid: bin bound mismatch "
                    f"(min: {wmins_m[1:][fix_locations]}; "
                    f"max: {wmaxs_m[:-1][fix_locations]})"
                )
            elif fix_bounds == "ignore":
                pass
            else:
                raise ValueError(f'unknown bound fixing policy "{fix_bounds}"')

        # Define bin centers if necessary (computed from the fixed bounds)
        if wcenters is None:
            wcenters_m = 0.5 * (wmins_m + wmaxs_m)
        else:
            wcenters_m = np.array(
                ensure_units(wcenters, w_u).m_as(w_u), dtype=np.float64, ndmin=1
            )
            if wcenters_m.shape != wmins_m.shape:
                raise ValueError(
                    "while constructing CKDSpectralGrid: 'wcenters' must have the "
                    f"same shape as bin bounds (got {wcenters_m.shape}, "
                    f"expected {wmins_m.shape})"
                )

        # Initialize the object
        self.__attrs_init__(wmins_m * w_u, wmaxs_m * w_u, wcenters_m * w_u)

    @property
    def wavelengths(self):
        # Inherit docstring
        return self.wcenters

    def plot(self, ax, alpha=0.5):
        # Draw each bin as a shaded band with a vertical line at its center,
        # cycling through the seaborn palette
        import seaborn as sns
        from cycler import cycler

        w_u = ucc.get("wavelength")
        color_cycle = cycler(color=sns.color_palette())

        for wmin, wmax, wcenter, color in zip(
            self.wmins.m_as(w_u),
            self.wmaxs.m_as(w_u),
            self.wcenters.m_as(w_u),
            itertools.cycle(color_cycle),
        ):
            c = color["color"]
            ax.fill_between(
                [wmin, wmax], 0, 1, color=c, alpha=alpha, lw=0.5, ls=(0, (5, 5))
            )
            ax.vlines(wcenter, 0, 1, color=c, lw=0.5)

        return ax

    def _repr_html_(self):
        # Rich Jupyter display: render plot() to an inline base64 PNG
        import base64
        import io

        import matplotlib.pyplot as plt
        import seaborn as sns

        w_u = ucc.get("wavelength")
        fig, ax = plt.subplots(1, 1, figsize=(6, 1))
        self.plot(ax)
        ax.set_xlabel(f"Wavelength [{w_u:~P}]")
        sns.despine(left=True)
        ax.axes.get_yaxis().set_visible(False)

        img = io.BytesIO()
        fig.savefig(img, format="png", bbox_inches="tight")
        plt.close(fig)
        img.seek(0)

        return (
            "<img "
            f'src="data:image/png;base64, {base64.b64encode(img.getvalue()).decode("utf-8")}" '
            "/>"
        )

    @staticmethod
    def default() -> CKDSpectralGrid:
        """
        Generate a default CKD spectral grid that covers the default spectral
        range with 10 nm spacing.
        """
        # The +1 nm pad makes the upper range bound reachable as a bin center
        return CKDSpectralGrid.arange(
            start=SPECTRAL_RANGE_MIN.m_as(ureg.nm),
            stop=SPECTRAL_RANGE_MAX.m_as(ureg.nm) + 1.0,
            step=10.0,
        )

    @staticmethod
    def arange(
        start: float | pint.Quantity,
        stop: float | pint.Quantity,
        step: float | pint.Quantity,
    ) -> CKDSpectralGrid:
        """
        Generate a CKD spectral grid with equally-sized bins.

        Parameters
        ----------
        start : quantity or float
            Central wavelength of the first bin. If a unitless value is passed,
            it is interpreted in default wavelength units (usually nm).

        stop : quantity or float
            Wavelength after which bin generation stops: bin centers are
            generated in ``[start, stop)``. If a unitless value is passed, it
            is interpreted in default wavelength units (usually nm).

        step : quantity or float
            Spectral bin size. If a unitless value is passed, it is interpreted
            in default wavelength units (usually nm).

        Returns
        -------
        CKDSpectralGrid
            Generated CKD spectral grid.
        """
        w_u = ucc.get("wavelength")
        start_m = ensure_units(start, w_u).m_as(w_u)
        stop_m = ensure_units(stop, w_u).m_as(w_u)
        width_m = ensure_units(step, w_u).m_as(w_u)

        # Bins are centered on the generated wavelengths
        wcenters_m = np.arange(start_m, stop_m, width_m)
        wmins_m = wcenters_m - 0.5 * width_m
        wmaxs_m = wcenters_m + 0.5 * width_m

        return CKDSpectralGrid(wmins_m * w_u, wmaxs_m * w_u, wcenters_m * w_u)

    @classmethod
    def from_nodes(cls, wnodes: npt.ArrayLike) -> CKDSpectralGrid:
        """
        Create a contiguous CKD spectral grid from a sequence of bin nodes.

        Parameters
        ----------
        wnodes : quantity or array-like
            Bin bounds, in increasing order; N nodes define N - 1 contiguous
            bins. Unitless values are interpreted in default wavelength config
            units.

        Returns
        -------
        CKDSpectralGrid
            Spectral grid whose bin centers are the middle of each bin.
        """
        wmins = wnodes[:-1]
        wmaxs = wnodes[1:]
        return cls(wmins=wmins, wmaxs=wmaxs)

    @classmethod
    def from_absorption_database(cls, abs_db: CKDAbsorptionDatabase) -> CKDSpectralGrid:
        """
        Retrieve the spectral grid from a CKD absorption database.

        Parameters
        ----------
        abs_db : .CKDAbsorptionDatabase

        Returns
        -------
        CKDSpectralGrid
            Spectral grid whose bounds and centers are read from the database
            spectral coverage table.

        Raises
        ------
        TypeError
            If ``abs_db`` is not a :class:`.CKDAbsorptionDatabase`.
        """
        if not isinstance(abs_db, CKDAbsorptionDatabase):
            raise TypeError(
                "while creating CKDSpectralGrid from absorption database: "
                f"expected CKDAbsorptionDatabase, got {type(abs_db).__name__}"
            )

        wmins = abs_db.spectral_coverage["wbound_lower [nm]"].values * ureg.nm
        wmaxs = abs_db.spectral_coverage["wbound_upper [nm]"].values * ureg.nm
        # Centers are taken from the database index (second level) so that
        # they match the database wavelength coordinate exactly
        wcenters = abs_db.spectral_coverage.index.get_level_values(1).values * ureg.nm
        return cls(wmins, wmaxs, wcenters)

    @singledispatchmethod
    def _select_impl(self, srf: SpectralResponseFunction) -> CKDSpectralGrid:
        # Inherit docstring
        raise NotImplementedError(f"unsupported data type '{type(srf)}'")

    @_select_impl.register
    def _(self, srf: DeltaSRF):
        # Select every bin containing at least one of the SRF wavelengths.
        # Intervals are treated as half-open (wmin, wmax]: a wavelength on a
        # shared bound goes to the lower bin.
        w_u = srf.wavelengths.u
        w_m = np.atleast_1d(srf.wavelengths.m)
        wmins_m = self.wmins.m_as(w_u)
        wmaxs_m = self.wmaxs.m_as(w_u)

        # Both searches require bounds sorted by increasing wavelength
        selmin = np.searchsorted(wmins_m, w_m)
        selmax = np.searchsorted(wmaxs_m, w_m) + 1
        hit = selmin == selmax  # Mask where w_m values triggered a bin hit

        # Map hit wavelengths to their bin index (always >= 0 on a hit) and
        # drop duplicates
        selected = np.unique(selmin[hit] - 1)

        # Carry bin centers over: they may be tied to a database coordinate
        return CKDSpectralGrid(
            wmins=self.wmins[selected],
            wmaxs=self.wmaxs[selected],
            wcenters=self.wcenters[selected],
        )

    @_select_impl.register
    def _(self, srf: UniformSRF):
        # Select bins that overlap the open interval (srf.wmin, srf.wmax)
        selected = (self.wmaxs > srf.wmin) & (self.wmins < srf.wmax)
        return CKDSpectralGrid(
            wmins=self.wmins[selected],
            wmaxs=self.wmaxs[selected],
            wcenters=self.wcenters[selected],
        )

    @_select_impl.register
    def _(self, srf: BandSRF):
        w_u = self.wmins.u
        wmins_m = self.wmins.m_as(w_u)
        wmaxs_m = self.wmaxs.m_as(w_u)

        # Build the spectral mesh used to integrate the SRF: all distinct bin
        # bounds, sorted (np.unique sorts)
        w_m = np.unique(np.concatenate((wmins_m, wmaxs_m)))

        # Cumulative SRF integral at each mesh node, starting at 0 on the first
        # node.
        # NOTE(review): assumes integrate_cumulative() returns len(w_m) - 1
        # values (one per mesh interval), as the previous implementation did.
        cumsum = np.concatenate(
            ([0.0], srf.integrate_cumulative(w_m * w_u).m_as(w_u))
        )

        # Locate each bin's bounds on the mesh (exact hits, since the mesh is
        # built from these very values) and keep bins over which the SRF
        # integral is nonzero. Unlike pairing consecutive mesh nodes, this also
        # works when bins are non-contiguous or overlap.
        i_min = np.searchsorted(w_m, wmins_m)
        i_max = np.searchsorted(w_m, wmaxs_m)
        selected = cumsum[i_min] != cumsum[i_max]

        # Build a new spectral grid that only contains selected bins
        return CKDSpectralGrid(
            self.wmins[selected], self.wmaxs[selected], self.wcenters[selected]
        )

    def merge(self, other: CKDSpectralGrid) -> CKDSpectralGrid:
        # Inherit docstring
        w_u = ucc.get("wavelength")

        # Collect spectral bin information as (wmin, wmax, wcenter) rows
        bins = np.stack(
            (
                np.concatenate((self.wmins.m_as(w_u), other.wmins.m_as(w_u))),
                np.concatenate((self.wmaxs.m_as(w_u), other.wmaxs.m_as(w_u))),
                np.concatenate((self.wcenters.m_as(w_u), other.wcenters.m_as(w_u))),
            ),
            axis=1,
        )

        # Sort rows lexicographically by (wmin, wmax, wcenter) and remove exact
        # duplicates; this also handles empty grids
        bins = np.unique(bins, axis=0)

        return CKDSpectralGrid(
            wmins=bins[:, 0] * w_u, wmaxs=bins[:, 1] * w_u, wcenters=bins[:, 2] * w_u
        )

    def walk_quads(
        self,
        ckd_quad_config: CKDQuadConfig,
        abs_db: CKDAbsorptionDatabase | None = None,
    ) -> t.Generator[tuple[pint.Quantity, Quad], None, None]:
        """
        Walk the spectral grid and retrieve, based on a quadrature configuration
        and, if necessary, an absorption database, the spectral quadrature for
        each spectral bin.

        Parameters
        ----------
        ckd_quad_config : .CKDQuadConfig
            CKD quadrature configuration.

        abs_db : .CKDAbsorptionDatabase, optional
            Molecular absorption database used to build quadrature rules for
            each spectral bin. This parameter is required only if an adaptive
            quadrature generation policy is used, otherwise it is ignored.

        Yields
        ------
        w : quantity
            Central wavelength of the current spectral bin.

        quad : .Quad
            Quadrature rule for the current spectral bin.

        Raises
        ------
        ValueError
            If the quadrature policy is not ``FIXED`` and ``abs_db`` is unset.
        """
        # Check parameter consistency
        if ckd_quad_config.policy is not CKDQuadPolicy.FIXED and abs_db is None:
            raise ValueError(
                "while attempting CKD spectral grid walk with policy "
                f"{ckd_quad_config.policy}: `abs_db` must be set (got None)"
            )

        # Walk the spectral grid and get the quadrature for each bin
        for w in self.wcenters:
            yield w, ckd_quad_config.get_quad(abs_db, wcenter=w)

    def walk_indices(
        self,
        ckd_quad_config: CKDQuadConfig,
        abs_db: CKDAbsorptionDatabase | None = None,
    ) -> t.Generator[CKDSpectralIndex, None, None]:
        """
        Walk the spectral grid and retrieve, based on a quadrature configuration
        and, if necessary, an absorption database, the sequence of spectral
        indexes driving the spectral loop.

        Parameters
        ----------
        ckd_quad_config : .CKDQuadConfig
            CKD quadrature configuration.

        abs_db : .CKDAbsorptionDatabase, optional
            Molecular absorption database used to build quadrature rules for
            each spectral bin. This parameter is required only if an adaptive
            quadrature generation policy is used, otherwise it is ignored.

        Yields
        ------
        si : .CKDSpectralIndex
            Generated spectral index.
        """
        # Walk the spectral dimension; for each bin, emit one index per
        # quadrature node mapped onto the [0, 1] interval
        for w, quad in self.walk_quads(ckd_quad_config, abs_db):
            for g in quad.eval_nodes([0, 1]):
                yield CKDSpectralIndex(w=w, g=g)