Source code for eradiate.kernel._render

from __future__ import annotations

import logging
import typing as t
import warnings

import attrs
import drjit as dr
import mitsuba as mi
from mitsuba.python.util import SceneParameters as _MitsubaSceneParameters
from tqdm.auto import tqdm

from ._kernel_dict import UpdateMapTemplate
from .. import config
from ..attrs import define, documented, frozen
from ..contexts import KernelContext
from ..rng import SeedState, root_seed_state

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
#                         Parameter lookup strategies
# ------------------------------------------------------------------------------


@frozen
class TypeIdLookupStrategy:
    """
    Parameter ID lookup strategy matching on a Mitsuba type and an object ID.

    Instances are callables meant to be invoked on each node visited during
    a Mitsuba scene tree traversal. They receive the current node and,
    optionally, the path of that node in the scene tree. When the node is an
    instance of :attr:`node_type` and its ID equals :attr:`node_id`, the full
    parameter path is returned; otherwise the call returns ``None``.
    """

    node_type: type = documented(
        attrs.field(validator=attrs.validators.instance_of(type)),
        doc="Type of the node which will be looked up.",
        type="type",
    )

    node_id: str = documented(
        attrs.field(validator=attrs.validators.instance_of(str)),
        doc="ID of the node which will be looked up.",
        type="str",
    )

    parameter_relpath: str = documented(
        attrs.field(validator=attrs.validators.instance_of(str)),
        doc="Parameter path relative to its parent object.",
        type="str",
    )

    def __call__(self, node, node_path: str | None = None) -> str | None:
        # Guard clause: both the type and the ID must match
        if not isinstance(node, self.node_type) or node.id() != self.node_id:
            return None

        # Without a node path, the relative path is the full parameter path
        if node_path is None:
            return f"{self.parameter_relpath}"

        return f"{node_path}.{self.parameter_relpath}"
# ------------------------------------------------------------------------------
#                            Mitsuba scene traversal
# ------------------------------------------------------------------------------
@define
class MitsubaObjectWrapper:
    """
    Container bundling a Mitsuba object with its parameter map and an update
    map template that can be rendered to modify the scene parameters.

    See Also
    --------
    :func:`mi_traverse`
    """

    obj: mi.Object = documented(
        attrs.field(
            # Scenes get an abbreviated placeholder in the repr
            repr=lambda x: "Scene[...]" if isinstance(x, mi.Scene) else str(x)
        ),
        doc="Mitsuba object.",
        type="mitsuba.Object",
    )

    parameters: mi.SceneParameters | None = documented(
        attrs.field(
            default=None,
            # Matched by class name; anything else is displayed as "None"
            repr=lambda x: (
                "SceneParameters[...]"
                if x.__class__.__name__ == "SceneParameters"
                else str(None)
            ),
        ),
        doc="Mitsuba scene parameter map.",
        type="mitsuba.SceneParameters",
        init_type="mitsuba.SceneParameters, optional",
        default="None",
    )

    umap_template: UpdateMapTemplate | None = documented(
        attrs.field(
            default=None,
            repr=lambda x: (
                "UpdateMapTemplate[...]"
                if isinstance(x, UpdateMapTemplate)
                else str(x)
            ),
        ),
        doc="An update map template, which can be rendered and used to update "
        "Mitsuba scene parameters depending on context information.",
        type=".UpdateMapTemplate",
        init_type=".UpdateMapTemplate, optional",
        default="None",
    )

    def drop_parameters(self) -> None:
        """
        Shrink the scene parameter table :attr:`.parameters` so that it only
        holds the entries referenced by the parameter update map template
        :attr:`.umap_template`.

        Entries without a lookup strategy are kept under their own name.
        Entries with a lookup strategy are kept under their looked up
        parameter ID; if that ID is undefined, a warning is emitted and the
        entry is not added to the list of kept keys. Nothing happens when no
        update map template is set.
        """
        if self.umap_template is None:
            return

        kept = []

        for name, param in self.umap_template.items():
            if param.lookup_strategy is None:
                # No lookup involved: the template key is the parameter key
                kept.append(name)
                continue

            if param.parameter_id is None:
                warnings.warn(
                    f"Parameter '{name}' has a lookup strategy but the "
                    "associated parameter ID is undefined; was a "
                    "parameter lookup performed during the Mitsuba "
                    "scene traversal?"
                )
                continue

            kept.append(param.parameter_id)

        self.parameters.keep(kept)
class SceneParameters(_MitsubaSceneParameters):
    """
    Mitsuba scene parameter map extended with an alias table.

    The ``aliases`` dictionary is consulted by :meth:`set_dirty` to map a
    dot-free node name back to a dotted key while climbing the node
    hierarchy.
    """

    def __init__(self, properties=None, hierarchy=None, aliases=None):
        super().__init__(properties, hierarchy)
        self.aliases = {} if aliases is None else aliases

    def set_dirty(self, key: str):
        # Inherit docstring
        value, _, node, flags = self.properties[key]

        if (flags & mi.ParamFlags.NonDifferentiable.value) and dr.grad_enabled(
            value
        ):
            mi.Log(
                mi.LogLevel.Warn,
                f"Parameter '{key}' is marked as non-differentiable but has "
                "gradients enabled, unexpected results may occur!",
            )

        # Climb from the node owning the parameter up to the root, flagging
        # each visited node with the name of the child entry that changed
        current_key = key
        while node is not None:
            parent, depth = self.hierarchy[node]

            if parent is None:
                # Root node: the remaining key is used as-is
                name = current_key
            else:
                if "." not in current_key and depth > 0:
                    # We've hit the top level from an ID-aliased node:
                    # Resolve the alias to finish climbing the hierarchy
                    current_key = self.aliases[current_key]
                current_key, name = current_key.rsplit(".", 1)

            self.nodes_to_update.setdefault((depth, node), set()).add(name)
            node = parent

        return self.properties[key]
def mi_traverse(
    obj: mi.Object,
    umap_template: UpdateMapTemplate | None = None,
    name_id_override: str | list[str] | bool | None = None,
) -> MitsubaObjectWrapper:
    """
    Traverse a node of the Mitsuba scene graph and return a container holding
    the Mitsuba scene, its parameter map and an updated parameter update map.

    Parameters
    ----------
    obj : mitsuba.Object
        Mitsuba scene graph node to be traversed.

    umap_template : .UpdateMapTemplate, optional
        An additional update map template which is to be updated during
        traversal. This is used to perform parameter lookup during traversal.
        The template's data mapping is copied (``data.copy()``) into a new
        template before traversal; the entries themselves are shared, so
        successful lookups set ``parameter_id`` on the passed entries.

    name_id_override : str or list of str or bool, optional
        If set, this argument will be used to select nodes in the scene tree
        whose names will be "pinned" to their ID. Passed values are used as
        regular expressions, with all that it implies regarding ID string
        matching. A list or a tuple is interpreted as a sequence of patterns;
        any other value is treated as a single pattern. If this parameter is
        set to ``True``, a regex that matches anything is used. ``None`` and
        ``False`` disable name pinning.

    Returns
    -------
    MitsubaObjectWrapper
        A container holding the traversed object, the corresponding parameter
        map and the parameter update map (if any).

    Warns
    -----
    UserWarning
        If some parameter lookups registered in the update map template did
        not succeed during traversal.

    Notes
    -----
    This is a reimplementation of the :func:`mitsuba.traverse` function.
    """
    import re

    umap_template = (
        UpdateMapTemplate(data=umap_template.data.copy())
        if umap_template is not None
        else UpdateMapTemplate()
    )

    # Pending lookups: entries with a lookup strategy but no parameter ID yet
    lookups = {
        k: v
        for k, v in umap_template.items()
        if v.parameter_id is None and v.lookup_strategy is not None
    }

    # Normalize the name pinning specification to a list of patterns
    if name_id_override is None or name_id_override is False:
        patterns = []
    elif name_id_override is True:
        patterns = [r".*"]
    elif isinstance(name_id_override, (list, tuple)):
        patterns = list(name_id_override)
    else:
        patterns = [name_id_override]

    regexps = [re.compile(pattern).match for pattern in patterns]

    class SceneTraversal(mi.TraversalCallback):
        def __init__(
            self,
            node,
            parent=None,
            properties=None,
            hierarchy=None,
            prefixes=None,
            name=None,
            depth=0,
            flags=+mi.ParamFlags.Differentiable,
            aliases=None,
        ):
            mi.TraversalCallback.__init__(self)
            # These containers are shared by all callbacks of one traversal
            self.properties = dict() if properties is None else properties
            self.hierarchy = dict() if hierarchy is None else hierarchy
            self.prefixes = set() if prefixes is None else prefixes
            self.aliases = dict() if aliases is None else aliases

            # Pin the node name to its ID if the ID matches a pattern, and
            # remember the path-based name it replaces as an alias
            node_id = node.id()
            if regexps and node_id:
                for matches in regexps:
                    if matches(node_id):
                        if node_id != name:
                            self.aliases[node_id] = name
                        name = node_id
                        break

            # Make the name unique by appending a counter suffix if needed
            if name is not None:
                ctr, name_len = 1, len(name)
                while name in self.prefixes:
                    name = f"{name[:name_len]}_{ctr}"
                    ctr += 1
                self.prefixes.add(name)

            self.name = name
            self.node = node
            self.depth = depth
            self.hierarchy[node] = (parent, depth)
            self.flags = flags

            # Try and recover a parameter ID from this node
            for lookup_key, uparam in list(lookups.items()):
                lookup_result = uparam.lookup_strategy(self.node, self.name)
                if lookup_result is not None:
                    uparam.parameter_id = lookup_result
                    # Remove successful lookups to accelerate future searches
                    del lookups[lookup_key]

        def put_parameter(self, name, ptr, flags, cpptype=None):
            name = name if self.name is None else self.name + "." + name

            flags = self.flags | flags
            # Non-differentiable parameters shouldn't be flagged as discontinuous
            if (flags & mi.ParamFlags.NonDifferentiable) != 0:
                flags = flags & ~mi.ParamFlags.Discontinuous

            # NOTE(review): OR-ing self.flags again re-adds a Discontinuous
            # bit carried by self.flags even after it was masked out above.
            # Kept unchanged here; confirm against mitsuba.traverse before
            # altering.
            self.properties[name] = (ptr, cpptype, self.node, self.flags | flags)

        def put_object(self, name, node, flags):
            # Skip missing children and nodes which were already visited
            if node is None or node in self.hierarchy:
                return

            cb = SceneTraversal(
                node=node,
                parent=self.node,
                properties=self.properties,
                hierarchy=self.hierarchy,
                prefixes=self.prefixes,
                name=name if self.name is None else f"{self.name}.{name}",
                depth=self.depth + 1,
                flags=self.flags | flags,
                aliases=self.aliases,
            )
            node.traverse(cb)

    cb = SceneTraversal(obj)
    obj.traverse(cb)

    # Check if there are unsuccessful lookups
    if lookups:
        warnings.warn(
            "There were unsuccessful Mitsuba scene parameter lookups: "
            f"{list(lookups.keys())}",
            stacklevel=2,
        )

    return MitsubaObjectWrapper(
        obj=obj,
        parameters=SceneParameters(cb.properties, cb.hierarchy, cb.aliases),
        umap_template=umap_template,
    )
# ------------------------------------------------------------------------------
#                             Mitsuba scene render
# ------------------------------------------------------------------------------
def mi_render(
    mi_scene: MitsubaObjectWrapper,
    ctxs: list[KernelContext],
    sensors: None | int | list[int] = None,
    spp: int = 0,
    seed_state: SeedState | None = None,
) -> dict[t.Any, dict[str, mi.Bitmap]]:
    """
    Render a Mitsuba scene multiple times given specified contexts and sensor
    indices.

    Parameters
    ----------
    mi_scene : .MitsubaObjectWrapper
        Mitsuba scene to render. Its parameter map and update map template
        are both used, and must therefore be set.

    ctxs : list of :class:`.KernelContext`
        List of contexts used to generate the parameter update table at each
        iteration.

    sensors : int or list of int, optional
        Sensor indices to render. If ``None`` (default), all sensors are
        rendered.

    spp : int, optional, default: 0
        Number of samples per pixel, forwarded to :func:`mitsuba.render`.
        If set to 0 (default), the value set in the original scene
        definition takes precedence.

    seed_state : .SeedState, optional
        Seed state used to generate seeds to initialize Mitsuba's RNG at
        each run. If unset, Eradiate's root seed state is used.

    Returns
    -------
    dict
        A nested dictionary. The first level is keyed by
        ``ctx.si.as_hashable`` for each context; the second level is keyed
        by sensor ID (``sensor.id()``) and holds the rendered bitmaps.
        Within a context, sensors sharing the same ID overwrite one another.

    Notes
    -----
    This function wraps sequential calls to :func:`mitsuba.render`.
    """
    if seed_state is None:
        logger.debug("Using default RNG seed generator")
        seed_state = root_seed_state

    # The sensor selection does not depend on the context: normalize it once
    sensor_indices = [sensors] if isinstance(sensors, int) else sensors

    results = {}

    # Loop on contexts
    with tqdm(
        initial=0,
        total=len(ctxs),
        unit_scale=1.0,
        leave=True,
        bar_format="{desc}{n:g}/{total:g}|{bar}| {elapsed}, ETA={remaining}",
        disable=(config.settings.progress < config.ProgressLevel.SPECTRAL_LOOP)
        or len(ctxs) <= 1,
    ) as pbar:
        for ctx in ctxs:
            pbar.set_description(
                f"Eradiate [{ctx.index_formatted}]",
                refresh=True,
            )

            logger.debug("Updating Mitsuba scene parameters")
            mi_scene.parameters.update(mi_scene.umap_template.render(ctx))

            # Pair each selected sensor with its index in the scene
            scene_sensors = mi_scene.obj.sensors()
            if sensor_indices is None:
                mi_sensors = list(enumerate(scene_sensors))
            else:
                mi_sensors = [(i, scene_sensors[i]) for i in sensor_indices]

            # Loop on sensors
            for i_sensor, mi_sensor in mi_sensors:
                # Render sensor: one fresh seed is drawn per run
                seed = int(seed_state.next().squeeze())
                logger.debug(
                    'Running Mitsuba for sensor "%s" with seed value %s',
                    mi_sensor.id(),
                    seed,
                )
                mi.render(mi_scene.obj, sensor=i_sensor, seed=seed, spp=spp)

                # Store result in a new Bitmap object
                siah = ctx.si.as_hashable
                results.setdefault(siah, {})[mi_sensor.id()] = mi.Bitmap(
                    mi_sensor.film().bitmap()
                )

            pbar.update()

    return results