Source code for eradiate.scenes.core

from __future__ import annotations

import typing as t
import warnings
from abc import ABC, abstractmethod
from collections import abc as collections_abc
from typing import Mapping, Sequence

import attrs
import mitsuba as mi
import numpy as np
import pint
import pinttr
from pinttr.util import ensure_units

from ..attrs import documented, parse_docs
from ..contexts import KernelDictContext
from ..exceptions import KernelVariantError
from ..units import unit_context_config as ucc
from ..units import unit_registry as ureg
from ..util.misc import onedict_value


def _kernel_dict_get_mitsuba_variant() -> str:
    variant = mi.variant()

    if variant:
        return variant
    else:
        raise KernelVariantError(
            "a kernel variant must be selected to create a KernelDict instance"
        )


[docs]@parse_docs @attrs.define class KernelDict(collections_abc.MutableMapping): """ A dictionary-like object designed to contain a scene specification appropriate for instantiation with :func:`~mitsuba.core.load_dict`. :class:`.KernelDict` keeps track of the variant it has been created with and performs minimal checks to help prevent inconsistent scene creation. """ data: dict = documented( attrs.field( factory=dict, converter=dict, ), doc="Scene dictionary.", default="{}", type="dict", ) post_load: dict = documented( attrs.field( factory=dict, converter=dict, ), doc="Post-load update dictionary.", default="{}", type="dict", ) variant: str = documented( attrs.field( factory=_kernel_dict_get_mitsuba_variant, validator=attrs.validators.instance_of(str), ), doc="Kernel variant for which the dictionary is created. Defaults to " "currently active variant (if any; otherwise raises).", type="str", default=":func:`mitsuba.set_variant`", ) def __getitem__(self, k): return self.data.__getitem__(k) def __delitem__(self, v): return self.data.__delitem__(v) def __len__(self): return self.data.__len__() def __iter__(self): return self.data.__iter__() def __setitem__(self, k, v): try: self.data.__getitem__(k) warnings.warn( f"Duplicate key '{k}' will be overwritten. Are you trying to " "add scene elements with duplicate IDs to this KernelDict?" ) except KeyError: pass return self.data.__setitem__(k, v)
[docs] def check(self) -> None: """ Perform basic checks on the dictionary: * check that the ``"type"`` parameter is included; * check if the variant for which the kernel dictionary was created is the same as the current one. Raises ------ ValueError If the ``"type"`` parameter is missing. :class:`.KernelVariantError` If the variant for which the kernel dictionary was created is not the same as the current one """ variant = mi.variant() if self.variant != variant: raise KernelVariantError( f"scene dictionary created for kernel variant '{self.variant}', " f"incompatible with current variant '{variant}'" ) if "type" not in self: raise ValueError("kernel scene dictionary is missing a 'type' parameter")
def fix(self) -> None: if "type" not in self.data: self.data["type"] = "scene"
[docs] def load( self, strip: bool = True, post_load_update: bool = True ) -> "mitsuba.Object": """ Call :func:`~mitsuba.core.load_dict` on self. In addition, a post-load update can be applied. If the encapsulated dictionary misses a ``"type"`` key, it will be promoted to a scene dictionary through the addition of ``{"type": "scene"}``. For instance, it means that .. code:: python { "shape1": {"type": "sphere"}, "shape2": {"type": "sphere"}, } will be interpreted as .. code:: python { "type": "scene", "shape1": {"type": "sphere"}, "shape2": {"type": "sphere"}, } .. note:: Requires a valid selected operational mode. Parameters ---------- strip : bool If ``True``, if ``data`` has no ``'type'`` entry and if ``data`` consists of one nested dictionary, it will be loaded directly. For instance, it means that .. code:: python {"phase": {"type": "rayleigh"}} will be stripped to .. code:: python {"type": "rayleigh"} post_load_update : bool If ``True``, use :func:`~mitsuba.python.util.traverse` and update loaded scene parameters according to data stored in ``post_load``. Returns ------- :class:`mitsuba.core.Object` Loaded Mitsuba object. """ d = self.data d_extra = {} if "type" not in self: if len(self) == 1 and strip: # Extract plugin dictionary d = onedict_value(d) else: # Promote to scene dictionary d_extra = {"type": "scene"} obj = mi.load_dict({**d, **d_extra}) if self.post_load and post_load_update: params = mi.traverse(obj) params.keep(list(self.post_load.keys())) for k, v in self.post_load.items(): params[k] = v params.update() return obj
[docs] def add(self, *elements: SceneElement, ctx: KernelDictContext) -> None: """ Merge the content of a :class:`~eradiate.scenes.core.SceneElement` or another dictionary object with the current :class:`KernelDict`. Parameters ---------- *elements : :class:`SceneElement` :class:`~eradiate.scenes.core.SceneElement` instances to add to the scene dictionary. ctx : :class:`.KernelDictContext` A context data structure containing parameters relevant for kernel dictionary generation. *This argument is keyword-only and required.* """ for element in elements: self.update(element.kernel_dict(ctx))
[docs] def merge(self, other: KernelDict) -> None: """ Merge another :class:`.KernelDict` with the current one. Parameters ---------- other : :class:`.KernelDict` A kernel dictionary whose main and post-load dictionaries will be used to update the current one. """ if self.variant != other.variant: raise KernelVariantError("merged kernel dicts must share the same variant") self.data.update(other.data) self.post_load.update(other.post_load)
[docs] @classmethod def from_elements( cls, *elements: SceneElement, ctx: KernelDictContext ) -> KernelDict: """ Create a new :class:`.KernelDict` from one or more scene elements. Parameters ---------- *elements : :class:`SceneElement` :class:`~eradiate.scenes.core.SceneElement` instances to add to the scene dictionary. ctx : :class:`.KernelDictContext` A context data structure containing parameters relevant for kernel dictionary generation. *This argument is keyword-only and required.* Returns ------- :class:`KernelDict` Created scene kernel dictionary. """ result = cls() result.add(*elements, ctx=ctx) return result
[docs]@parse_docs @attrs.define class SceneElement(ABC): """ Abstract class for all scene elements. This abstract base class provides a basic template for all scene element classes. It is written using the `attrs <https://www.attrs.org>`_ library. """ id: t.Optional[str] = documented( attrs.field( default=None, validator=attrs.validators.optional(attrs.validators.instance_of(str)), ), doc="User-defined object identifier.", type="str or None", init_type="str, optional", default="None", ) def _kernel_dict_id(self) -> t.Dict: """ Return a scene dictionary entry with the object's ``id`` field if it is not ``None``. """ result = {} if self.id is not None: result["id"] = self.id return result
[docs] @abstractmethod def kernel_dict(self, ctx: KernelDictContext) -> KernelDict: """ Return a dictionary suitable for kernel scene configuration. Parameters ---------- ctx : :class:`.KernelDictContext` A context data structure containing parameters relevant for kernel dictionary generation. Returns ------- :class:`.KernelDict` Kernel dictionary which can be loaded as a Mitsuba object. """ pass
[docs]@parse_docs @attrs.frozen class BoundingBox: """ A basic data class representing an axis-aligned bounding box with unit-valued corners. Notes ----- Instances are immutable. """ min: pint.Quantity = documented( pinttr.field( units=ucc.get("length"), on_setattr=None, # frozen instance: on_setattr must be disabled ), type="quantity", init_type="array-like or quantity", doc="Min corner.", ) max: pint.Quantity = documented( pinttr.field( units=ucc.get("length"), on_setattr=None, # frozen instance: on_setattr must be disabled ), type="quantity", init_type="array-like or quantity", doc="Max corner.", ) @min.validator @max.validator def _min_max_validator(self, attribute, value): if not self.min.shape == self.max.shape: raise ValueError( f"while validating {attribute.name}: 'min' and 'max' must " f"have the same shape (got {self.min.shape} and {self.max.shape})" ) if not np.all(np.less(self.min, self.max)): raise ValueError( f"while validating {attribute.name}: 'min' must be strictly " "less than 'max'" )
[docs] @classmethod def convert( cls, value: t.Union[t.Sequence, t.Mapping, np.typing.ArrayLike, pint.Quantity] ) -> t.Any: """ Attempt conversion of a value to a :class:`BoundingBox`. Parameters ---------- value Value to convert. Returns ------- any If `value` is an array-like, a quantity or a mapping, conversion will be attempted. Otherwise, `value` is returned unmodified. """ if isinstance(value, (np.ndarray, pint.Quantity)): return cls(value[0, :], value[1, :]) elif isinstance(value, Sequence): return cls(*value) elif isinstance(value, Mapping): return cls(**pinttr.interpret_units(value, ureg=ureg)) else: return value
@property def shape(self): """ tuple: Shape of `min` and `max` arrays. """ return self.min.shape @property def extents(self) -> pint.Quantity: """ :class:`pint.Quantity`: Extent in all dimensions. """ return self.max - self.min @property def units(self): """ :class:`pint.Unit`: Units of `min` and `max` arrays. """ return self.min.units
[docs] def contains(self, p: np.typing.ArrayLike, strict: bool = False) -> bool: """ Test whether a point lies within the bounding box. Parameters ---------- p : quantity or array-like An array of shape (3,) (resp. (N, 3)) representing one (resp. N) points. If a unitless value is passed, it is interpreted as ``ucc["length"]``. strict : bool If ``True``, comparison is done using strict inequalities (<, >). Returns ------- result : array of bool or bool ``True`` iff ``p`` in within the bounding box. """ p = np.atleast_2d(ensure_units(p, ucc.get("length"))) cmp = ( np.logical_and(p > self.min, p < self.max) if strict else np.logical_and(p >= self.min, p <= self.max) ) return np.all(cmp, axis=1)