"""
A collection of tools which don't really fit anywhere else.
"""
from __future__ import annotations
import functools
import inspect
import re
import typing as t
from collections import OrderedDict
from numbers import Number
import numpy as np
import numpy.typing as npt
import pint
import xarray as xr
[docs]
class cache_by_id:
    """
    Cache the result of a function based on the ID of its arguments.

    This decorator stores the last value returned by the wrapped callable and
    serves it again as long as the callable is invoked with the very same
    argument objects.

    Warnings
    --------
    The main difference with
    :func:`functools.lru_cache(maxsize=1) <functools.lru_cache>` is that the
    cache is referenced by positional argument IDs instead of hashes.
    Therefore, this decorator can be used with NumPy arrays; but it's also
    unsafe, because mutating an argument won't trigger a recompute, while it
    actually should! **Use with great care!**

    Notes
    -----
    * Meant to be used as a decorator.
    * The wrapped function may only have positional arguments.
    * Works with functions and methods.

    Examples
    --------
    >>> @cache_by_id
    ... def f(x, y):
    ...     print("Calling f")
    ...     return x, y
    >>> f(1, 2)
    Calling f
    (1, 2)
    >>> f(1, 2)
    (1, 2)
    >>> f(1, 1)
    Calling f
    (1, 1)
    >>> f(1, 1)
    (1, 1)
    """

    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func
        self._cached_value = None
        self._cached_index = None

    def __call__(self, *args):
        # The cache key is the tuple of argument object IDs, not hashes
        key = tuple(map(id, args))
        if key == self._cached_index:
            return self._cached_value
        self._cached_index = key
        self._cached_value = self.func(*args)
        return self._cached_value

    def __get__(self, instance, owner):
        # Descriptor protocol support so the decorator also works on methods;
        # see https://stackoverflow.com/questions/30104047 for full explanation
        return functools.partial(self.__call__, instance)
[docs]
class LoggingContext(object):
    """
    This context manager allows for a temporary override of logger settings:
    an optional level override and an optional extra handler, both restored
    or removed on exit.
    """

    # Adapted from the Python logging cookbook:
    # https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            # Remember the previous level so it can be restored on exit
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, exc_type, exc_value, traceback):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
        # Implicit return of None => don't swallow exceptions
[docs]
class Singleton(type):
    """
    A simple singleton implementation (as a metaclass). See [1]_ for details.

    References
    ----------
    .. [1] `Creating a singleton in Python on Stack Overflow <https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python>`_.

    Examples
    --------
    .. testsetup:: singleton

       from eradiate.util.misc import Singleton

    .. doctest:: singleton

       >>> class MySingleton(metaclass=Singleton): ...
       >>> my_singleton1 = MySingleton()
       >>> my_singleton2 = MySingleton()
       >>> my_singleton1 is my_singleton2
       True

    .. testcleanup:: singleton

       del Singleton
    """

    # Maps each class using this metaclass to its unique instance
    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            # First instantiation: build the instance and memoize it
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
def camel_to_snake(name):
    """
    Convert a CamelCase identifier to snake_case.
    """
    # Adapted from
    # https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    # First pass: separate a capitalized word from whatever precedes it
    # ("HTTPResponse" -> "HTTP_Response"); second pass: separate a lowercase
    # letter or digit from a following capital, then lowercase everything.
    partial = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", partial).lower()
[docs]
def deduplicate(value: t.Sequence, preserve_order: bool = True) -> list:
    """
    Remove duplicates from a sequence.

    Parameters
    ----------
    value : sequence
        Sequence to remove duplicates from.

    preserve_order : bool, optional, default: True
        If ``True``, preserve item ordering. The first occurrence of duplicated
        items is kept. Setting to ``False`` may slightly improve performance.

    Returns
    -------
    list
        List of values with duplicates removed.
    """
    if not preserve_order:
        # Fastest path, but item ordering is lost
        return list(set(value))
    # OrderedDict keys act as an insertion-ordered set
    return list(OrderedDict.fromkeys(value))
def deduplicate_sorted(value: t.Sequence, cmp: t.Callable | None = None) -> list:
    """
    Remove duplicates from a sorted sequence, keeping the first item of each
    run of consecutive equal items.

    Parameters
    ----------
    value : sequence
        Sorted sequence to remove duplicates from.

    cmp : callable, optional
        Binary predicate used to compare adjacent items for equality.
        Defaults to ``==``.

    Returns
    -------
    list
        Deduplicated list. Empty if `value` is empty.
    """
    if cmp is None:
        cmp = lambda x, y: x == y  # noqa: E731

    # Bug fix: an empty input used to raise IndexError on value[0]
    if not value:
        return []

    result = [value[0]]
    for i in range(1, len(value)):
        # Keep only items which differ from their immediate predecessor
        if not cmp(value[i], value[i - 1]):
            result.append(value[i])
    return result
[docs]
def flatten(d: t.Mapping, sep: str = ".", name: str = "") -> dict:
    """
    Flatten a nested dictionary.

    Parameters
    ----------
    d : dict
        Dictionary to be flattened.

    sep : str, optional, default: "."
        Flattened dict key separator.

    name : str, optional, default: ""
        Path to the parent dictionary. By default, no parent name is defined.

    Returns
    -------
    dict
        A flattened copy of `d`.

    See Also
    --------
    :func:`.nest`, :func:`.set_nested`
    """
    flat = {}

    for key, value in d.items():
        path = f"{name}{sep}{key}" if name else key
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, prefixing their keys with the
            # current path
            flat.update(flatten(value, sep=sep, name=path))
        else:
            flat[path] = value

    return flat
[docs]
def fullname(obj: t.Any) -> str:
    """
    Get the fully qualified name of `obj`. Aliases will be dereferenced.
    """
    owner = get_class_that_defined_method(obj)
    # For methods, qualify with the module of the defining class; otherwise
    # use the object's own module.
    module = obj.__module__ if owner is None else owner.__module__
    return f"{module}.{obj.__qualname__}"
[docs]
def get_class_that_defined_method(meth: t.Any) -> type | None:
    """
    Get the class which defined a method, if relevant. Otherwise, return
    ``None``.

    Parameters
    ----------
    meth : Any
        Callable to inspect: a bound method, a plain function, a
        :class:`functools.partial` or a descriptor object.

    Returns
    -------
    type or None
        The class whose namespace defines `meth`, or ``None`` if no defining
        class can be determined.
    """
    # See https://stackoverflow.com/questions/3589311/get-defining-class-of-unbound-method-object-in-python-3/25959545#25959545
    # Unwrap functools.partial objects and inspect the wrapped callable
    if isinstance(meth, functools.partial):
        return get_class_that_defined_method(meth.func)
    # Bound methods (incl. bound builtins with a __self__): walk the MRO of
    # the instance's class and return the first class defining the name
    if inspect.ismethod(meth) or (
        inspect.isbuiltin(meth)
        and getattr(meth, "__self__", None) is not None
        and getattr(meth.__self__, "__class__", None)
    ):
        for cls in inspect.getmro(meth.__self__.__class__):
            if meth.__name__ in cls.__dict__:
                return cls
        meth = getattr(meth, "__func__", meth)  # fallback to __qualname__ parsing
    # Plain functions: derive the owner name from __qualname__
    # ("Outer.method" -> "Outer") and look it up in the defining module
    if inspect.isfunction(meth):
        cls = getattr(
            inspect.getmodule(meth),
            meth.__qualname__.split(".<locals>", 1)[0].rsplit(".", 1)[0],
            None,
        )
        if isinstance(cls, type):
            return cls
    return getattr(meth, "__objclass__", None)  # handle special descriptor objects
[docs]
def is_vector3(value: t.Any):
    """
    Check if value can be interpreted as a 3-vector.

    Parameters
    ----------
    value
        Value to be checked.

    Returns
    -------
    bool
        ``True`` if a value can be interpreted as a 3-vector.
    """
    # Pint quantities are checked based on their magnitude
    if isinstance(value, pint.Quantity):
        return is_vector3(value.magnitude)

    # Accept NumPy arrays and non-string sequences only
    is_array_like = isinstance(value, np.ndarray) or (
        isinstance(value, t.Sequence) and not isinstance(value, str)
    )
    if not is_array_like or len(value) != 3:
        return False

    # All three components must be numeric
    return all(isinstance(component, Number) for component in value)
[docs]
def natsort_alphanum_key(x):
    """
    Simple sort key natural order for string sorting. See [2]_ for details.

    See Also
    --------
    `Sorting HOWTO <https://docs.python.org/3/howto/sorting.html>`_

    References
    ----------
    .. [2] `Natural sorting on Stack Overflow <https://stackoverflow.com/a/11150413/3645374>`_.
    """
    # Split on digit runs (capturing group keeps them), then compare digit
    # chunks numerically and text chunks case-insensitively
    return tuple(
        int(chunk) if chunk.isdigit() else chunk.lower()
        for chunk in re.split("([0-9]+)", x)
    )
[docs]
def natsorted(l):  # noqa
    """
    Sort a list of strings in natural (alphanumeric) order.

    Parameters
    ----------
    l : iterable
        Iterable of strings to sort.

    Returns
    -------
    list
        New list sorted using :func:`natsort_alphanum_key` as sort key.
    """
    return sorted(l, key=natsort_alphanum_key)
[docs]
def nest(d: t.Mapping, sep: str = ".") -> dict:
    """
    Turn a flat dictionary into a nested dictionary.

    Parameters
    ----------
    d : dict
        Dictionary to be unflattened.

    sep : str, optional, default: "."
        Flattened dict key separator.

    Returns
    -------
    dict
        A nested copy of `d`.

    See Also
    --------
    :func:`.flatten`, :func:`.set_nested`
    """
    nested = {}

    # Each flat key is a sep-separated path into the output dictionary
    for flat_key, value in d.items():
        set_nested(nested, flat_key, value, sep)

    return nested
[docs]
def onedict_value(d: t.Mapping) -> t.Any:
    """
    Get the value of a single-entry dictionary.

    Parameters
    ----------
    d : mapping
        A single-entry mapping.

    Returns
    -------
    object
        Unwrapped value.

    Raises
    ------
    ValueError
        If `d` does not have exactly one element.

    Notes
    -----
    This function is basically ``next(iter(d.values()))`` with a safeguard.

    Examples
    --------
    .. testsetup:: onedict_value

       from eradiate.util.misc import onedict_value

    .. doctest:: onedict_value

       >>> onedict_value({"foo": "bar"})
       'bar'

    .. testcleanup:: onedict_value

       del onedict_value
    """
    if len(d) != 1:
        # Bug fix: the message previously lacked the closing parenthesis
        raise ValueError(f"dictionary has wrong length (expected 1, got {len(d)})")

    return next(iter(d.values()))
def round_to_multiple(number, multiple, direction="nearest"):
    """
    Round a number to a multiple of a given step.

    Parameters
    ----------
    number : float
        Number to round.

    multiple : float
        Step whose multiple `number` is rounded to.

    direction : {"nearest", "up", "down"}, optional, default: "nearest"
        Rounding direction.

    Returns
    -------
    float
        `number` rounded to a multiple of `multiple`.

    Raises
    ------
    ValueError
        If `direction` is not one of ``"nearest"``, ``"up"`` or ``"down"``.
    """
    quotient = number / multiple

    if direction == "nearest":
        # Note: Python's round() applies banker's rounding to .5 ties
        return multiple * round(quotient)
    elif direction == "up":
        return multiple * np.ceil(quotient)
    elif direction == "down":
        return multiple * np.floor(quotient)
    else:
        # Bug fix: an unknown direction used to fall back silently to
        # "nearest", hiding typos in the argument; fail loudly instead
        raise ValueError(
            f"unknown direction '{direction}' "
            "(expected 'nearest', 'up' or 'down')"
        )
[docs]
def set_nested(d: t.MutableMapping, path: str, value: t.Any, sep: str = ".") -> None:
    """
    Set values in a nested dictionary using a flat path.

    Parameters
    ----------
    d : mutable mapping
        Dictionary to operate on. Modified in place.

    path : str
        Path to the value to be set.

    value
        Value to which `path` is to be set.

    sep : str, optional, default: "."
        Separator used to decompose `path`.

    See Also
    --------
    :func:`.flatten`, :func:`.nest`
    """
    # Annotation fix: this function mutates `d`, so the read-only t.Mapping
    # hint was incorrect; t.MutableMapping reflects the actual contract.
    *parents, last = path.split(sep)

    # Walk down the path, creating intermediate dictionaries as needed
    for key in parents:
        d = d.setdefault(key, {})

    d[last] = value
def str_summary_numpy(x):
    """
    Return a compact, shape-prefixed string summary of a NumPy array, e.g.
    ``array<3>([1 2 3])``.
    """
    # String conversion must happen inside the printoptions context so the
    # truncation and float formatting settings apply
    with np.printoptions(
        threshold=4, edgeitems=2, formatter={"float_kind": lambda v: f"{v:g}"}
    ):
        prefix = "array<" + ",".join(str(dim) for dim in x.shape) + ">("
        body = str(x)

    # Indent continuation lines of a multiline repr to align with the prefix
    lines = body.split("\n")
    if len(lines) > 1:
        body = ("\n" + " " * len(prefix)).join(lines)

    return f"{prefix}{body})"
[docs]
@functools.singledispatch
def summary_repr(value):
    """
    Return a summarized repr for `value`.

    The default implementation simply delegates to :func:`repr`; specialized
    types register dedicated implementations via ``summary_repr.register``.
    """
    return f"{value!r}"
@summary_repr.register
def _(ds: xr.Dataset):
    # Summarized repr for an xarray Dataset: mentions the source file path
    # (from the dataset encoding) when one is recorded
    details = []

    if "source" in ds.encoding:
        details.append(f"source={ds.encoding['source']!r}")

    suffix = " | " + ", ".join(details) if details else ""
    return f"<xarray.Dataset{suffix}>"
@summary_repr.register
def _(da: xr.DataArray):
    # Summarized repr for an xarray DataArray: shows name (if accessible),
    # dimensions and source file path (when recorded in the encoding)
    details = []

    try:
        details.append(f"name={da.name!r}")
    except AttributeError:
        pass

    details.append(f"dims={list(da.dims)!r}")

    if "source" in da.encoding:
        details.append(f"source={da.encoding['source']!r}")

    suffix = " | " + ", ".join(details) if details else ""
    return f"<xarray.DataArray{suffix}>"
@summary_repr.register
def _(x: pint.Quantity):
    """
    Return a brief summary representation of a Pint quantity.
    """
    # Summarize the magnitude as a vector and append compact ("~") units
    magnitude = summary_repr_vector(x.m)
    units = f"{x.u:~}"
    return f"{magnitude} {units}"
[docs]
def summary_repr_vector(a: np.ndarray, edgeitems: int = 4):
    """
    Return a brief summary representation of a NumPy vector: the full array
    when short, otherwise the first and last `edgeitems` values separated by
    an ellipsis.
    """
    n = len(a)

    # Short vectors are rendered in full
    if n <= edgeitems * 2 + 1:
        return np.array2string(a)

    # Long vectors: show only the head and tail
    head = np.array2string(a[:edgeitems]).strip("[]")
    tail = np.array2string(a[n - edgeitems :]).strip("[]")
    return f"[{head} ... {tail}]"
[docs]
def find_runs(
    x: npt.ArrayLike,
) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]:
    """
    Find runs of consecutive items in an array.

    Parameters
    ----------
    x : array-like
        Input array.

    Returns
    -------
    tuple(array-like, array-like, array-like)
        Run values, run starts, run lengths.

    Notes
    -----
    Credit: Alistair Miles
    Source: https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065
    """
    x = np.asanyarray(x)
    if x.ndim != 1:
        raise ValueError("only 1D array supported")

    n = x.shape[0]

    # Degenerate case: no data means no runs
    if n == 0:
        return np.array([]), np.array([]), np.array([])

    # A run starts at index 0 and wherever a value differs from its
    # predecessor
    is_start = np.empty(n, dtype=bool)
    is_start[0] = True
    np.not_equal(x[:-1], x[1:], out=is_start[1:])

    starts = np.nonzero(is_start)[0]
    values = x[is_start]
    # Lengths are differences between consecutive starts (with n appended as
    # the sentinel end of the last run)
    lengths = np.diff(np.append(starts, n))

    return values, starts, lengths
[docs]
class MultiGenerator:
    """
    This iterator aggregates several iterables and makes sure that items that
    have already been served are not repeated.

    Notes
    -----
    * Items must be hashable (they are tracked in a set).
    * Bug fixes vs the previous implementation: an empty generator list used
      to raise ``IndexError`` at construction (it now yields nothing), and
      long runs of duplicates could exhaust the call stack because
      ``__next__`` recursed once per skipped item (it now loops).
    """

    def __init__(self, generators):
        # `generators`: a sequence of iterables, consumed in order
        self.generators = generators
        self._i_generator = 0
        # Guard against an empty generator list
        self._current_iterator = (
            iter(self.generators[0]) if self.generators else iter(())
        )
        # Items already served, never repeated
        self._visited = set()

    def __iter__(self):
        return self

    def __next__(self):
        # Iterative rather than recursive: skipping duplicates or exhausted
        # sub-iterators must not grow the call stack
        while True:
            try:
                item = next(self._current_iterator)
            except StopIteration:
                if self._i_generator >= len(self.generators) - 1:
                    # All sub-iterators exhausted: end of iteration
                    raise
                # Move on to the next sub-iterator
                self._i_generator += 1
                self._current_iterator = iter(self.generators[self._i_generator])
                continue

            if item not in self._visited:
                self._visited.add(item)
                return item