Source code for eradiate.plot

from __future__ import annotations

import typing as t
import warnings

import matplotlib.pyplot as _plt
import numpy as np
import seaborn as sns
import xarray as xr
import xarray.plot
from cycler import cycler
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from matplotlib.gridspec import SubplotSpec
from xarray.plot import FacetGrid

from eradiate.radprops import CKDAbsorptionDatabase, MonoAbsorptionDatabase


def set_style(rc=None):
    """
    Activate the Eradiate Matplotlib style.

    The style is the Seaborn ``ticks`` theme with ``viridis`` set as the
    default image colormap.

    Parameters
    ----------
    rc : dict, optional
        Extra Matplotlib rc parameters layered on top of the Eradiate style.
        Entries passed here take precedence over the style's own settings.

    Raises
    ------
    ModuleNotFoundError
        If Seaborn cannot be imported. A warning pointing to the installation
        instructions is emitted before the exception propagates.
    """
    try:
        import seaborn
    except ModuleNotFoundError:
        warnings.warn(
            "To use Eradiate's Matplotlib style, you must install Seaborn.\n"
            "See instructions on https://seaborn.pydata.org/installing.html."
        )
        raise

    overrides = {} if rc is None else rc
    # User-supplied entries come last so that they win over the defaults
    style_rc = {"image.cmap": "viridis", **overrides}
    seaborn.set_theme(style="ticks", rc=style_rc)
def detect_axes(from_=None):
    """
    Try and extract a :class:`~matplotlib.axes.Axes` list from a data structure.

    Parameters
    ----------
    from_ : matplotlib figure or axes or list of axes or :class:`xarray.plot.FacetGrid`, optional
        Data structure to get an :class:`~matplotlib.axes.Axes` list from.
        If unset, :func:`matplotlib.pyplot.gca()` is used.

    Returns
    -------
    list of matplotlib axes
        Extracted list of :class:`~matplotlib.axes.Axes`. When ``from_`` is
        already a list of axes, that same list object is returned.

    Raises
    ------
    TypeError
        If ``from_`` is of unsupported type (this includes lists holding
        anything other than :class:`~matplotlib.axes.Axes` instances).
    """
    if from_ is None:
        from_ = _plt.gca()

    if isinstance(from_, Figure):
        return from_.axes

    if isinstance(from_, Axes):
        return [from_]

    if isinstance(from_, FacetGrid):
        # Recent xarray releases expose the axes array as ``axs`` and emit a
        # DeprecationWarning when ``axes`` is accessed; older releases only
        # provide ``axes``.
        try:
            grid_axes = from_.axs
        except AttributeError:
            grid_axes = from_.axes
        return list(grid_axes.flatten())

    if isinstance(from_, list) and all(isinstance(x, Axes) for x in from_):
        return from_

    raise TypeError("unsupported type")
def get_axes_from_facet_grid(
    facet_grid: xarray.plot.FacetGrid, exclude: str | None = None
) -> list[Axes]:
    """
    Extract a flat list of :class:`~matplotlib.axes.Axes` from a
    :class:`~xarray.plot.FacetGrid`.

    Parameters
    ----------
    facet_grid : :class:`xarray.plot.FacetGrid`
        Object to extract a list of :class:`~matplotlib.axes.Axes`.

    exclude : str, optional
        Exclude selected :class:`~matplotlib.axes.Axes` from the returned list.
        Supported values:

        * ``"lower_left"``: exclude lower-left corner.

        Any other value leaves the list unfiltered.

    Returns
    -------
    list of matplotlib axes
        List of extracted :class:`~matplotlib.axes.Axes` objects, in row-major
        order.
    """
    # Recent xarray releases expose the axes array as ``axs`` and emit a
    # DeprecationWarning when ``axes`` is accessed; older releases only
    # provide ``axes``.
    try:
        grid = facet_grid.axs
    except AttributeError:
        grid = facet_grid.axes

    keep = np.full((grid.size,), True, dtype=bool)
    n_rows, n_cols = grid.shape

    # This filtering system can be extended in the future
    if exclude == "lower_left":
        # First cell of the last row in the flattened (row-major) array
        keep[(n_rows - 1) * n_cols] = False

    return list(grid.flatten()[keep])
def remove_xylabels(from_=None) -> None:
    """
    Blank out the x- and y-axis labels of every axes found in ``from_``.

    Parameters
    ----------
    from_ : matplotlib figure or axes or list of axes or :class:`xarray.plot.FacetGrid`, optional
        Data structure resolved to an :class:`~matplotlib.axes.Axes` list by
        :func:`detect_axes`. If unset, :func:`matplotlib.pyplot.gca()` is used.

    See Also
    --------
    :func:`detect_axes`
    """
    targets = detect_axes(from_)
    for axes in targets:
        # Setting an empty string clears the label text
        for set_label in (axes.set_xlabel, axes.set_ylabel):
            set_label("")
def remove_xyticks(from_=None) -> None:
    """
    Hide the x and y axes of every axes found in ``from_``.

    Both axis objects are made invisible, which hides everything they draw
    (ticks and tick labels included).

    Parameters
    ----------
    from_ : matplotlib figure or axes or list of axes or :class:`xarray.plot.FacetGrid`, optional
        Data structure resolved to an :class:`~matplotlib.axes.Axes` list by
        :func:`detect_axes`. If unset, :func:`matplotlib.pyplot.gca()` is used.

    See Also
    --------
    :func:`detect_axes`
    """
    targets = detect_axes(from_)
    for axes in targets:
        for axis in (axes.get_xaxis(), axes.get_yaxis()):
            axis.set_visible(False)
def make_ticks(
    num_ticks: int, limits: t.Sequence[float]
) -> tuple[list[float], list[str]]:
    """
    Generate ticks and their respective tickmarks.

    Parameters
    ----------
    num_ticks : int
        Number of ticks to generate, including the limits of the given range.
        With a single tick, only the lower limit is returned; with zero or
        fewer, both returned lists are empty.

    limits : tuple[float, float]
        List of two values (in radians), limiting the ticks inclusive.

    Returns
    -------
    steps : list of float
        Values for the ticks.

    marks : list of str
        Tick values converted to degrees as string tickmarks.
    """
    span = float(limits[1] - limits[0])
    # A single tick has no spacing to compute: avoid dividing by zero and
    # simply place it on the lower limit.
    step_width = span / (num_ticks - 1) if num_ticks > 1 else 0.0

    steps = [limits[0] + step_width * i for i in range(num_ticks)]
    marks = [f"{step / np.pi * 180}°" for step in steps]

    return steps, marks
# -- Data dashboard definitions ------------------------------------------------


class PiecewiseNorm(Normalize):
    """
    Piecewise-linear normalization driven by a set of levels.

    The sorted levels are paired with evenly spaced values in [0, 1]; data
    values are normalized by linear interpolation between those pairs.

    Parameters
    ----------
    levels : array-like
        Data values used as interpolation nodes. They are sorted internally.

    clip : bool, default: False
        Forwarded to :class:`matplotlib.colors.Normalize`. Interpolation
        already saturates outside the level range.
    """

    def __init__(self, levels, clip=False):
        # Input levels
        self._levels = np.sort(levels)
        # Corresponding normalized values between 0 and 1
        self._normed = np.linspace(0, 1, len(levels))
        super().__init__(vmin=None, vmax=None, clip=clip)

    def __call__(self, value, clip=None):
        # Linearly interpolate to get the normalized value
        return np.ma.masked_array(np.interp(value, self._levels, self._normed))

    def inverse(self, value):
        # Map normalized values back to data values by swapping the roles of
        # the two interpolation tables.
        return np.interp(value, self._normed, self._levels)


def _chunk_array(array, len_chunk):
    """
    Split ``array`` into chunks of approximately ``len_chunk`` elements.

    The number of chunks is ``round(len(array) / len_chunk)``, with a minimum
    of one; the split itself is delegated to :func:`numpy.array_split`.
    """
    n_chunks = max(1, round(len(array) / len_chunk))
    return np.array_split(array, n_chunks)


def dashboard_particle_dataset(
    ds: xr.Dataset,
    title: str | None = None,
    phase_wavelength_chunk_size: int = 5,
    palette="muted",
):
    """
    Create a summary dashboard plot for a particle dataset. The generated plot
    displays the extinction coefficient, the albedo and phase function.

    Parameters
    ----------
    ds : Dataset
        Particle (aerosol, cloud, etc.) dataset. The ``phase``, ``sigma_t``
        and ``albedo`` variables and the ``w`` and ``mu`` coordinates are
        read; only the ``i=0, j=0`` component of ``phase`` is plotted.

    title : str, optional
        A title for the plot.

    phase_wavelength_chunk_size : int, default: 5
        The produced dashboard shows the phase function clustered in chunks on
        a number of polar plots set by this chunk size.

    palette : str, list, dict, or matplotlib.colors.Colormap, optional
        A Seaborn palette specification used to colour the phase function line
        plots.

    Returns
    -------
    fig : matplotlib figure
        The created figure.

    axs : dict
        Mapping from mosaic labels (``"sigma_t"``, ``"albedo"``,
        ``"phase_<n>"``) to axes. Entries of unused phase panels are removed
        from the figure but remain in this mapping.
    """
    df = ds.phase.isel(i=0, j=0, drop=True).to_dataframe()[["phase"]]
    df = df.sort_values(["w", "mu"]).reset_index()
    df["theta"] = np.arccos(df["mu"])
    phase_min = df["phase"].min()
    phase_max = df["phase"].max()

    # Each chunk is summarized by the wavelength bounds it covers
    w_chunks = [
        (chunk.min(), chunk.max())
        for chunk in _chunk_array(np.unique(df["w"]), phase_wavelength_chunk_size)
    ]

    naxs_phase = len(w_chunks)
    ncols = 2
    nrows_phase = naxs_phase // ncols + (1 if naxs_phase % ncols else 0)
    height_ratios = [0.5, 0.5] + [1 for _ in range(nrows_phase)]

    fig, axs = plt.subplot_mosaic(
        [
            ["sigma_t"] * ncols,
            ["albedo"] * ncols,
        ]
        + [
            [f"phase_{row * ncols + col}" for col in range(ncols)]
            for row in range(nrows_phase)
        ],
        height_ratios=height_ratios,
        figsize=(4 * ncols, 4 * sum(height_ratios)),
        subplot_kw={"projection": "polar"},
        layout="constrained",
    )

    # All mosaic axes are created polar; swap the two spectral panels for
    # regular axes spanning their whole grid row.
    for i_var, var in enumerate(["sigma_t", "albedo"]):
        gs = axs[var].get_gridspec()
        start = i_var * ncols
        stop = start + ncols - 1
        axs[var].remove()
        axs[var] = plt.subplot(SubplotSpec(gs, start, stop), projection=None)
        ds[var].plot(ax=axs[var], ls=":", marker=".")

    # Legend title: default value, overridden by metadata if any. It does not
    # depend on the wavelength chunk, so it is built once.
    lgd_title = "wavelength"
    for field in ["long_name", "standard_name"]:
        if field in ds.w.attrs:
            lgd_title = ds.w.attrs[field]
            break
    if "units" in ds.w.attrs:
        lgd_title += f" [{ds.w.attrs['units']}]"

    for i_phase, (wmin, wmax) in enumerate(w_chunks):
        _df = df.where((df["w"] >= wmin) & (df["w"] <= wmax)).dropna()
        ax = axs[f"phase_{i_phase}"]

        sns.lineplot(
            ax=ax,
            data=_df,
            x="theta",
            y="phase",
            hue="w",
            hue_norm=PiecewiseNorm(levels=_df["w"].values),
            legend="full",
            palette=palette,
        )
        # Same radial range on every panel so they can be compared visually
        ax.set_rlim([phase_min * 0.5, phase_max * 5.0])
        ax.set_rscale("log")
        ax.set_thetagrids(np.arange(0, 181, 45))
        ax.set_thetamax(180)
        ax.set_xlabel("")
        ax.set_ylabel("")

        # Deal with legend entries
        lgd = ax.legend(
            loc="upper center",
            ncol=3,
            bbox_to_anchor=(0.5, 0.1),
            title=lgd_title,
        )
        new_labels = [f"{w:.3g}" for w in np.unique(_df["w"].values)]
        for text, label in zip(lgd.get_texts(), new_labels):
            text.set_text(label)

    # Clear empty axes
    for i in range(naxs_phase, nrows_phase * ncols):
        axs[f"phase_{i}"].remove()

    if title:
        fig.suptitle(title)

    return fig, axs


def absorption_database_spectral_coverage_mono(
    db: MonoAbsorptionDatabase, wrange=None, title=None
):
    """
    Plot the spectral coverage of a monochromatic absorption database.

    One histogram of the spectral coverage index values is drawn per file
    listed in the database index, each with its own colour.

    Parameters
    ----------
    db : MonoAbsorptionDatabase
        Database to visualize. Its ``_index`` table (``"filename"`` column)
        and ``spectral_coverage`` table are read.

    wrange : optional
        Currently unused by this function.

    title : str, optional
        A title for the plot.

    Returns
    -------
    fig : matplotlib figure
    ax : matplotlib axes
    """
    df_index = db._index
    fig, ax = plt.subplots(1, 1, figsize=(6, 1.5))
    color_cycle = cycler(color=sns.color_palette())

    # Calling the cycler yields an endless iterator, so zip() stops with the
    # file list.
    for filename, style in zip(df_index["filename"], color_cycle()):
        w = db.spectral_coverage.loc[filename].index.values
        sns.histplot(x=w, color=style["color"], ax=ax)

    ax.set_xlabel("Wavelength [nm]")

    if title:
        ax.set_title(title)

    sns.despine(ax=ax, offset=10)

    return fig, ax


def absorption_database_spectral_coverage_ckd(
    db: CKDAbsorptionDatabase, wrange=None, title=None
):
    """
    Plot the spectral coverage of a CKD absorption database.

    Each row of the spectral coverage table is drawn as a translucent band,
    coloured per file. Markers annotate the lowest and highest wavelength
    bounds found in the (possibly filtered) index.

    Parameters
    ----------
    db : CKDAbsorptionDatabase
        Database to visualize. Its ``_index`` table (``"filename"``,
        ``"wl_min [nm]"`` and ``"wl_max [nm]"`` columns) and
        ``spectral_coverage`` table are read.

    wrange : sequence of two floats, optional
        If set, only index entries lying entirely within
        ``[wrange[0], wrange[1]]`` are plotted.

    title : str, optional
        A title for the plot.

    Returns
    -------
    fig : matplotlib figure
    ax : matplotlib axes
    """
    df_index = db._index
    wmins = df_index["wl_min [nm]"]
    wmaxs = df_index["wl_max [nm]"]

    if wrange is not None:
        select = (wmins >= wrange[0]) & (wmaxs <= wrange[1])
        df_index = df_index.loc[select]
        wmins = df_index["wl_min [nm]"]
        wmaxs = df_index["wl_max [nm]"]

    # Overall bounds do not change from file to file
    w_lower = wmins.min()
    w_upper = wmaxs.max()

    fig, ax = plt.subplots(1, 1, figsize=(6, 1))
    color_cycle = cycler(color=sns.color_palette())

    for i, (filename, style) in enumerate(zip(df_index["filename"], color_cycle())):
        color = style["color"]
        spectral_coverage = db.spectral_coverage.loc[filename]

        for _, (wmin, wmax) in spectral_coverage.iterrows():
            ax.fill_between([wmin, wmax], 0, 1, color=color, alpha=0.5)

        # The first file carries the lower-bound marker, placed to the left;
        # every other file draws the upper-bound marker, placed to the right.
        # NOTE(review): with more than two files the upper-bound marker is
        # drawn several times -- possibly meant for the last file only; confirm.
        if i == 0:
            w, marker, xytext, ha = w_lower, 8, (-6, 0), "right"
        else:
            w, marker, xytext, ha = w_upper, 9, (6, 0), "left"

        ax.scatter(w, 0.5, color=color, marker=marker)
        ax.annotate(
            f"{w:4.2f}",
            (w, 0.5),
            color=color,
            textcoords="offset points",
            xytext=xytext,
            ha=ha,
            va="center",
        )

    ax.set_xlabel("Wavelength [nm]")
    ax.set_yticks([])

    if title:
        ax.set_title(title)

    sns.despine(ax=ax, top=True, left=True, right=True, bottom=False)

    return fig, ax