from __future__ import annotations
import logging
import typing as t
import attrs
import numpy as np
import pint
import pinttr
import xarray as xr
from .index import CKDSpectralIndex
from .spectral_set import SpectralSet
from ..attrs import documented, parse_docs
from ..constants import SPECTRAL_RANGE_MAX, SPECTRAL_RANGE_MIN
from ..quad import Quad, QuadType
from ..units import to_quantity
from ..units import unit_context_config as ucc
from ..units import unit_registry as ureg
from ..util.misc import round_to_multiple
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# CKD bin data classes
# ------------------------------------------------------------------------------
@parse_docs
@attrs.define(eq=False, frozen=True, slots=True)
class Bin:
    """
    A data class representing a spectral bin in CKD modes.

    Notes
    -----
    A bin is more than a spectral interval. It is associated with a
    quadrature rule.
    """

    wmin: pint.Quantity = documented(
        pinttr.field(
            units=ucc.deferred("wavelength"),
            on_setattr=None,  # frozen instance: on_setattr must be disabled
        ),
        doc="Bin lower spectral bound.\n\nUnit-enabled field "
        '(default: ucc["wavelength"]).',
        type="quantity",
        init_type="quantity or float",
    )

    wmax: pint.Quantity = documented(
        pinttr.field(
            units=ucc.deferred("wavelength"),
            on_setattr=None,  # frozen instance: on_setattr must be disabled
        ),
        doc="Bin upper spectral bound.\n\nUnit-enabled field "
        '(default: ucc["wavelength"]).',
        type="quantity",
        init_type="quantity or float",
    )

    @wmin.validator
    @wmax.validator
    def _wbounds_validator(self, attribute, value):
        # Registered on both bounds: the check compares the two fields with
        # each other rather than inspecting ``value`` alone.
        if not self.wmin < self.wmax:
            raise ValueError(
                f"while validating {attribute.name}: wmin must be lower than wmax"
            )

    quad: Quad = documented(
        attrs.field(
            factory=lambda: Quad.gauss_legendre(2),
            repr=lambda x: x.str_summary,
            validator=attrs.validators.instance_of(Quad),
        ),
        doc="Quadrature rule attached to the CKD bin.",
        type=":class:`.Quad`",
    )

    @property
    def width(self) -> pint.Quantity:
        """quantity : Bin spectral width."""
        return self.wmax - self.wmin

    @property
    def wcenter(self) -> pint.Quantity:
        """quantity : Bin central wavelength."""
        return 0.5 * (self.wmin + self.wmax)

    @property
    def pretty_repr(self) -> str:
        """str : Pretty representation of the bin."""
        # Bounds are always displayed in nm with one decimal, regardless of
        # the units the bin was created with.
        units = ureg.Unit("nm")
        wrange = (
            f"[{self.wmin.m_as(units):.1f}, {self.wmax.m_as(units):.1f}] {units:~P}"
        )
        quad = self.quad.pretty_repr()
        return f"{wrange}: {quad}"

    def spectral_indices(self) -> t.Generator[CKDSpectralIndex, None, None]:
        """
        Generate the spectral indices attached to this bin.

        Yields
        ------
        :class:`.CKDSpectralIndex`
            One index per quadrature node, with ``w`` set to the bin central
            wavelength and ``g`` set to the node value evaluated on the
            [0, 1] interval.
        """
        # The instance is frozen, so the central wavelength is constant over
        # the iteration and can be computed once.
        wcenter = self.wcenter
        for value in self.quad.eval_nodes(interval=[0.0, 1.0]):
            yield CKDSpectralIndex(w=wcenter, g=value)
# ------------------------------------------------------------------------------
# CKD quadrature setup classes
# ------------------------------------------------------------------------------
@parse_docs
@attrs.define
class QuadSpec:
    """
    Abstract base class for all quadrature specification patterns.

    Each subclass defines a strategy used to generate a spectral quadrature
    corresponding to a CKD dataset and must implement the strategy in the
    :meth:`make_quad`.
    """

    @staticmethod
    def default() -> QuadSpecFixed:
        """
        Return the default spectral quadrature (Gauss-Legendre, 16 *g*-points).
        """
        return QuadSpecFixed(n=16, quad_type="gauss_legendre")

    @staticmethod
    def from_dict(
        value: dict[str, t.Any],
    ) -> QuadSpecFixed | QuadSpecMinError | QuadSpecErrorThreshold:
        """
        Create a quadrature specification subtype from a dictionary. The
        dictionary must have a ``type`` entry, whose value maps to a given
        quadrature specification subtype as follows:

        * ``fixed``: :class:`.QuadSpecFixed`
        * ``minimize_error`` (or ``minimize``): :class:`.QuadSpecMinError`
        * ``error_threshold`` (or ``threshold``):
          :class:`.QuadSpecErrorThreshold`

        The remaining entries are forwarded to the ``from_dict()`` constructor
        of the selected subtype. The input dictionary is not modified.

        Parameters
        ----------
        value : dict
            A dictionary mapping parameter names to their respective values.

        Returns
        -------
        QuadSpec

        Raises
        ------
        ValueError
            If the ``type`` entry is missing or does not match any known
            quadrature specification.
        """
        # Work on a shallow copy: popping "type" from the caller's dictionary
        # would make it unusable for a second call.
        params = dict(value)

        try:
            subtype: str = params.pop("type")
        except KeyError as err:
            raise ValueError("dictionary input must have a 'type' entry") from err

        if subtype == "fixed":
            spec_cls = QuadSpecFixed
        elif subtype in {"minimize", "minimize_error"}:
            spec_cls = QuadSpecMinError
        elif subtype in {"threshold", "error_threshold"}:
            spec_cls = QuadSpecErrorThreshold
        else:
            raise ValueError(f"Unknown quadrature specification '{subtype}'")

        return spec_cls.from_dict(params)

    @classmethod
    def convert(cls, value: t.Any) -> QuadSpec:
        """
        Attempt conversion to a :class:`.QuadSpec` instance. If `value` is a
        dictionary, it is passed to :meth:`.from_dict`; otherwise, it is left
        unchanged.
        """
        if isinstance(value, dict):
            return cls.from_dict(value)
        else:
            return value

    def make_quad(self, dataset: xr.Dataset) -> Quad:
        """
        Apply the quadrature generation strategy and generate a quadrature rule
        for a given dataset.

        Parameters
        ----------
        dataset : Dataset
            An xarray dataset following the CKD absorption data format, for
            which a quadrature rule is generated.

        Returns
        -------
        .Quad
        """
        raise NotImplementedError
@parse_docs
@attrs.define
class QuadSpecFixed(QuadSpec):
    """
    Fixed number of quadrature points [``fixed``]

    The same number of quadrature points is used for all bins. With this
    specification, the quadrature type is not read from the dataset: it is
    set explicitly through the ``quad_type`` field.
    """

    n: int = documented(
        attrs.field(),
        doc="Number of quadrature points",
        type="int",
    )

    quad_type: QuadType = documented(
        attrs.field(default="gauss_legendre", converter=QuadType),
        doc="Quadrature type",
        type=".QuadType",
        init_type=".QuadType or str",
        default='"gauss_legendre"',
    )

    @classmethod
    def from_dict(cls, value: dict[str, t.Any]) -> QuadSpecFixed:
        """
        Build an instance from a dictionary whose entries are passed to the
        constructor as keyword arguments.
        """
        params = dict(value)
        return cls(**params)

    def make_quad(self, dataset: xr.Dataset) -> Quad:
        # Inherit docstring
        # The dataset is not inspected: the rule depends only on the fields.
        rule = Quad.new(type=self.quad_type, n=self.n)
        return rule
@parse_docs
@attrs.define
class QuadSpecMinError(QuadSpec):
    """
    Error-minimizing number of quadrature points [``minimize_error``]

    Find the number of quadrature points that minimizes the error on the
    atmospheric transmittance. The quadrature type is read from the
    ``quadrature_type`` attribute of the dataset's ``ng`` variable and falls
    back to ``"gauss_legendre"`` when that attribute is absent.
    """

    nmax: int | None = documented(
        attrs.field(
            default=None,
            converter=attrs.converters.optional(int),
        ),
        doc="Maximum number of quadrature points",
        type="int or None",
        init_type="int, optional",
    )

    @classmethod
    def from_dict(cls, value: dict[str, t.Any]) -> QuadSpecMinError:
        """
        Build an instance from a dictionary whose entries are passed to the
        constructor as keyword arguments.
        """
        params = dict(value)
        return cls(**params)

    def make_quad(self, dataset: xr.Dataset) -> Quad:
        # Inherit docstring
        # Point count first (capped by nmax), then the type stored with ng.
        npoints = ng_minimum(error=dataset.error, ng_max=self.nmax)
        return Quad.new(
            type=dataset.ng.attrs.get("quadrature_type", "gauss_legendre"),
            n=npoints,
        )
@parse_docs
@attrs.define
class QuadSpecErrorThreshold(QuadSpec):
    """
    Error-threshold number of quadrature points [``error_threshold``]

    Find the number of quadrature points for which the error on the
    atmospheric transmittance falls below a specified threshold. The
    quadrature type is read from the ``quadrature_type`` attribute of the
    dataset's ``ng`` variable and falls back to ``"gauss_legendre"`` when that
    attribute is absent.
    """

    threshold: float = documented(
        attrs.field(),
        doc="Error threshold value",
        type="float",
    )

    nmax: int | None = documented(
        attrs.field(
            default=None,
            converter=attrs.converters.optional(int),
        ),
        doc="Maximum number of quadrature points",
        type="int or None",
        init_type="int, optional",
    )

    @classmethod
    def from_dict(cls, value: dict[str, t.Any]) -> QuadSpecErrorThreshold:
        """
        Build an instance from a dictionary whose entries are passed to the
        constructor as keyword arguments.
        """
        params = dict(value)
        return cls(**params)

    def make_quad(self, dataset: xr.Dataset) -> Quad:
        # Inherit docstring
        ng_attrs = dataset.ng.attrs
        quad_type = ng_attrs.get("quadrature_type", "gauss_legendre")
        # Point count satisfying the threshold, capped by nmax.
        npoints = ng_threshold(
            error=dataset.error,
            threshold=self.threshold,
            ng_max=self.nmax,
        )
        return Quad.new(type=quad_type, n=npoints)
def ng_minimum(error: xr.DataArray, ng_max: int | None = None):
    """
    Find the number of quadrature points that minimizes the error.

    Only the first entry along the ``w`` dimension of the error data is
    inspected.

    Parameters
    ----------
    error : DataArray
        Error data, with an ``ng`` coordinate and a ``w`` dimension.

    ng_max : int, optional
        Maximum number of quadrature points. If not provided, it is set to
        the largest ``ng`` value found in the error data.

    Returns
    -------
    int
        Number of quadrature points that minimizes the error, capped at
        `ng_max`.
    """
    upper = int(error.ng.max()) if ng_max is None else ng_max

    first_w = error.isel(w=0)
    at_minimum = first_w == first_w.min()
    # First ng value at which the error reaches its minimum.
    best = int(error.ng.where(at_minimum, drop=True)[0])

    if best > upper:
        return upper
    return best
def ng_threshold(
    error: xr.DataArray,
    threshold: float,
    ng_max: int | None = None,
):
    """
    Find the number of quadrature points for which the error is (strictly)
    below a specified threshold value.

    Only the first entry along the ``w`` dimension of the error data is
    inspected.

    Parameters
    ----------
    error : DataArray
        Error data, with an ``ng`` coordinate and a ``w`` dimension.

    threshold : float
        Error threshold.

    ng_max : int, optional
        Maximum number of quadrature points. If not provided, it is set to
        the largest ``ng`` value found in the error data.

    Returns
    -------
    int
        First ``ng`` value whose error is below the threshold, capped at
        `ng_max`. If no value satisfies the threshold, `ng_max` is returned.
    """
    upper = int(error.ng.max()) if ng_max is None else ng_max

    first_w = error.isel(w=0)
    candidates = error.ng.where(first_w < threshold, drop=True)

    # Nothing meets the threshold: fall back to the maximum point count.
    if candidates.size == 0:
        return upper

    first = int(candidates[0])
    return upper if first > upper else first
# ------------------------------------------------------------------------------
# Bin set data class
# ------------------------------------------------------------------------------
@parse_docs
@attrs.define(eq=False, frozen=True, slots=True)
class BinSet(SpectralSet):
    """
    A data class representing a bin set used in CKD mode.

    See Also
    --------
    :class:`.WavelengthSet`
    """

    bins: list[Bin] = documented(
        attrs.field(
            converter=list,
            validator=attrs.validators.deep_iterable(
                member_validator=attrs.validators.instance_of(Bin)
            ),
        ),
        doc="Set of bins.",
        type="list of :class:`.Bin`",
        init_type="iterable of :class:`.Bin`",
    )

    def spectral_indices(self) -> t.Generator[CKDSpectralIndex, None, None]:
        """
        Generate the spectral indices of all bins, bin after bin, in the
        order in which the bins are stored.
        """
        for b in self.bins:
            yield from b.spectral_indices()

    @property
    def wavelengths(self) -> pint.Quantity:
        """
        Return the central wavelength of all bins (same as :attr:`wcenters`).
        """
        return self.wcenters

    @property
    def wcenters(self) -> pint.Quantity:
        """
        Return the central wavelength of all bins.
        """
        units = ucc.get("wavelength")
        return [b.wcenter.m_as(units) for b in self.bins] * units

    @property
    def wmins(self) -> pint.Quantity:
        """
        Return the lower bound of all bins.
        """
        units = ucc.get("wavelength")
        return [b.wmin.m_as(units) for b in self.bins] * units

    @property
    def wmaxs(self) -> pint.Quantity:
        """
        Return the upper bound of all bins.
        """
        units = ucc.get("wavelength")
        return [b.wmax.m_as(units) for b in self.bins] * units

    @classmethod
    @ureg.wraps(None, (None, "nm", "nm", "nm", None), strict=False)
    def arange(
        cls,
        start: pint.Quantity,
        stop: pint.Quantity,
        step: pint.Quantity = 10.0 * ureg.nm,
        quad: Quad | None = None,
    ) -> BinSet:
        """
        Generate a bin set with linearly spaced bins.

        Parameters
        ----------
        start : quantity or float
            Lower bound of first bin. If a float is passed, it is interpreted as
            being in units of nm.

        stop : quantity or float
            Upper bound of last bin. If a float is passed, it is interpreted as
            being in units of nm.

        step : quantity or float, default: 10 nm
            Bin width. If a float is passed, it is interpreted as being in units
            of nm.

        quad : .Quad, optional
            Quadrature rule (same for all bins in the set). Defaults to
            a one-point Gauss-Legendre quadrature.

        Returns
        -------
        :class:`.BinSet`
            Generated bin set.
        """
        if quad is None:
            quad = Quad.gauss_legendre(1)

        # The ureg.wraps decorator is declared with "nm" for start, stop and
        # step, so they are handled as plain nm magnitudes from here on.
        wmins = np.arange(start, stop, step)
        wmaxs = wmins + step

        wunits = ucc.get("wavelength")
        bins = [
            Bin(
                wmin=(wmin * ureg.nm).to(wunits),
                wmax=(wmax * ureg.nm).to(wunits),
                quad=quad,
            )
            for wmin, wmax in zip(wmins, wmaxs)
        ]

        return cls(bins)

    @classmethod
    def from_srf(
        cls,
        srf: xr.Dataset,
        step: pint.Quantity = 10.0 * ureg.nm,
        quad: Quad | None = None,
    ) -> BinSet:
        """
        Generate a bin set with linearly spaced bins covering the spectral
        range of a spectral response function.

        Parameters
        ----------
        srf : Dataset
            Spectral response function dataset.

        step : quantity
            Wavelength step.

        quad : .Quad, optional
            Quadrature rule (same for all bins in the set). Defaults to
            a one-point Gauss-Legendre quadrature.

        Returns
        -------
        :class:`.BinSet`
            Generated bin set.
        """
        wavelengths = to_quantity(srf.w)
        wmin = wavelengths.min()
        wmax = wavelengths.max()

        # The range is padded by one step on each side of the SRF bounds.
        return cls.arange(start=wmin - step, stop=wmax + step, step=step, quad=quad)

    @classmethod
    def from_wavelength_bounds(
        cls, wmin: pint.Quantity, wmax: pint.Quantity, quad: Quad | None = None
    ) -> BinSet:
        """
        Generate a bin set from explicit lower and upper bin bounds.

        Parameters
        ----------
        wmin : quantity
            Lower bound of each bin (scalar or array).

        wmax : quantity
            Upper bound of each bin (scalar or array). Bounds are paired with
            `wmin` element by element.

        quad : .Quad, optional
            Quadrature rule (same for all bins in the set). Defaults to
            a one-point Gauss-Legendre quadrature.

        Returns
        -------
        :class:`.BinSet`
            Generated bin set.
        """
        if quad is None:
            quad = Quad.gauss_legendre(1)

        return cls(
            bins=[
                Bin(wmin=_wmin, wmax=_wmax, quad=quad)
                for _wmin, _wmax in zip(np.atleast_1d(wmin), np.atleast_1d(wmax))
            ]
        )

    @classmethod
    def from_absorption_data(
        cls,
        datasets: xr.Dataset | t.Sequence[xr.Dataset],
        quad_spec: QuadSpec | None = None,
    ) -> BinSet:
        """
        Generate a bin set from one or several absorption datasets.

        Parameters
        ----------
        datasets : Dataset or sequence of Dataset
            Absorption dataset.

        quad_spec : .QuadSpec, optional
            Quadrature rule specification, used to generate the quadrature
            rule of each dataset. Defaults to :meth:`.QuadSpec.default`.

        Returns
        -------
        :class:`.BinSet`
            Generated bin set.

        Raises
        ------
        ValueError
            If the bin bounds of a dataset are neither a length nor an
            inverse length.

        Notes
        -----
        Assumes that the absorption datasets have a ``wbounds`` data variable.
        """
        if isinstance(datasets, xr.Dataset):
            datasets = [datasets]

        if quad_spec is None:
            quad_spec = QuadSpec.default()

        bins = []

        for dataset in datasets:
            # make quadrature rule
            quad = quad_spec.make_quad(dataset)

            # determine wavelength bounds
            wlower = to_quantity(dataset.wbounds.sel(wbv="lower"))
            wupper = to_quantity(dataset.wbounds.sel(wbv="upper"))

            if wlower.check("[length]"):
                wmin = wlower
                wmax = wupper
            elif wlower.check("[length]^-1"):
                wmin = (1.0 / wupper).to("nm")  # min wavelength is max wavenumber
                wmax = (1.0 / wlower).to("nm")  # max wavelength is min wavenumber
            else:
                raise ValueError(
                    f"Invalid dimensionality for dataset spectral coordinate; "
                    f"expected [length] or [length]^-1 "
                    f"(got {wlower.dimensionality})"
                )

            binset = cls.from_wavelength_bounds(wmin=wmin, wmax=wmax, quad=quad)
            bins.extend(binset.bins)

        return cls(bins=bins)

    @classmethod
    def default(cls):
        """
        Generate a default bin set, which covers Eradiate's default spectral
        range with 10 nm-wide bins.
        """
        wmin = round_to_multiple(SPECTRAL_RANGE_MIN.m_as(ureg.nm), 10.0, "nearest")
        wmax = round_to_multiple(SPECTRAL_RANGE_MAX.m_as(ureg.nm), 10.0, "nearest")
        dw = 10.0

        # Go through cls, like the other alternate constructors, so that a
        # subclass gets an instance of its own type.
        return cls.arange(
            start=wmin * ureg.nm, stop=wmax * ureg.nm, step=dw * ureg.nm
        )