```"""Utility components for quadrature rules."""

from __future__ import annotations

from enum import Enum

import attrs
import numpy as np

from .attrs import documented, parse_docs
from .util.misc import str_summary_numpy

# String identifiers of the supported quadrature rule types (the same values
# that ``Quad.new`` dispatches on).
GAUSS_LEGENDRE, GAUSS_LOBATTO = "gauss_legendre", "gauss_lobatto"

[docs]@parse_docs
@attrs.define(eq=False, frozen=True)
# NOTE(review): this region appears damaged. The ``class Quad:`` statement
# these decorators should apply to is missing (``Quad`` is referenced by the
# ``-> Quad`` annotations further down), the class body has lost its
# indentation, and ``[docs]`` on the first decorator line looks like a Sphinx
# "view source" link artefact. Restore from the upstream file — TODO confirm.
"""
A data class storing information about a quadrature rule. Nodes and weights
are defined in the [-1, 1] interval. The reference interval can be changed
using the ``interval`` argument of the :meth:`.eval_nodes` and
:meth:`.integrate` methods.
"""

# NOTE(review): fragment of the ``type`` field declaration. The attribute
# name, the opening ``documented(attrs.field(...), ...`` call and the end of
# the ``doc`` string are missing. ``pretty_repr`` reads ``self.type.value``,
# so the field presumably stores an Enum member converted from a string —
# TODO confirm against upstream.
doc="Quadrature type. If a string is passed, it is converted to a "
)

# Quadrature nodes. The ``np.array`` converter stores an ndarray copy of
# whatever is passed in; the shape is cross-checked against ``weights`` by
# ``_nodes_weights_validator``.
nodes: np.ndarray = documented(
    attrs.field(converter=np.array, repr=str_summary_numpy),
    doc="Quadrature rule nodes, defined on the reference interval [-1, 1].",
    type="ndarray",
)

# Quadrature weights, one per node. The ``np.array`` converter stores an
# ndarray copy of the input; ``_nodes_weights_validator`` requires the same
# shape as ``nodes``.
weights: np.ndarray = documented(
    attrs.field(converter=np.array, repr=str_summary_numpy),
    doc="Quadrature rule weights associated with each node, for the "
    "reference interval [-1, 1]. Must have the same shape as the nodes.",
    type="ndarray",
)

@nodes.validator
@weights.validator
def _nodes_weights_validator(self, attribute, value):
    """
    Ensure ``nodes`` and ``weights`` have identical shapes.

    Registered as an attrs validator on both fields, so it runs once per
    field. ``attribute`` (the field being validated) is only used to label
    the error message; ``value`` is ignored because both arrays are read back
    from the instance, which relies on attrs running validators after all
    fields have been assigned.

    Raises
    ------
    ValueError
        If ``self.nodes.shape`` differs from ``self.weights.shape``.
    """
    nodes_shape = self.nodes.shape
    weights_shape = self.weights.shape

    if nodes_shape == weights_shape:
        return

    raise ValueError(
        f"while validating {attribute.name}: nodes and weights arrays "
        f"must have the same shape, got nodes.shape = {nodes_shape} "
        f"and weights.shape = {weights_shape}"
    )

def pretty_repr(self) -> str:
    """
    Return a compact, human-readable description of the rule.

    Returns
    -------
    str
        ``"<type value>, <number of nodes> points"``.
    """
    label = self.type.value
    npoints = self.nodes.size
    return f"{label}, {npoints} points"

@classmethod
def gauss_legendre(cls, n: int) -> Quad:
    """
    Initialize a :class:`.Quad` instance with Gauss-Legendre nodes and
    weights.

    Parameters
    ----------
    n : int
        Number of quadrature points, forwarded to the Gauss-Legendre
        node/weight generator.

    Returns
    -------
    :class:`.Quad`
        Gauss-Legendre quadrature rule on the reference interval [-1, 1].
    """
    # NOTE(review): ``gauss_legendre`` must be a module-level helper returning
    # a ``(nodes, weights)`` pair; it is not among the imports visible here,
    # so confirm it is in scope. Inside a method this name resolves to the
    # module global, not to this classmethod (class scope is not searched).
    nodes, weights = gauss_legendre(n)
    return cls(
        # Tag the rule explicitly so that ``type`` (read by ``pretty_repr``)
        # matches the rule actually built; the type field accepts a string.
        type="gauss_legendre",
        # ``asarray`` avoids a redundant copy: the field converter already
        # builds the stored array with ``np.array``.
        nodes=np.asarray(nodes, dtype=float),
        weights=np.asarray(weights, dtype=float),
    )

@classmethod
def gauss_lobatto(cls, n: int) -> Quad:
    """
    Initialize a :class:`.Quad` instance with Gauss-Lobatto nodes and
    weights.

    Parameters
    ----------
    n : int
        Number of quadrature points, forwarded to the Gauss-Lobatto
        node/weight generator.

    Returns
    -------
    :class:`.Quad`
        Gauss-Lobatto quadrature rule on the reference interval [-1, 1].
    """
    # NOTE(review): ``gauss_lobatto`` must be a module-level helper returning
    # a ``(nodes, weights)`` pair; it is not among the imports visible here,
    # so confirm it is in scope. Inside a method this name resolves to the
    # module global, not to this classmethod (class scope is not searched).
    nodes, weights = gauss_lobatto(n)
    return cls(
        # Tag the rule explicitly so that ``type`` (read by ``pretty_repr``)
        # matches the rule actually built; the type field accepts a string.
        type="gauss_lobatto",
        # ``asarray`` avoids a redundant copy: the field converter already
        # builds the stored array with ``np.array``.
        nodes=np.asarray(nodes, dtype=float),
        weights=np.asarray(weights, dtype=float),
    )

@classmethod
def new(cls, type: str, n: int) -> Quad:
    """
    Initialize a :class:`.Quad` instance of the specified type.

    Parameters
    ----------
    type : str
        Quadrature rule type. Allowed values are:

        * ``gauss_legendre``;
        * ``gauss_lobatto``.

    n : int
        Number of quadrature points, forwarded to the type-specific
        constructor.

    Returns
    -------
    :class:`.Quad`
        Quadrature rule of the requested type.

    Raises
    ------
    ValueError
        If ``type`` is not one of the allowed values.
    """
    # ``type`` shadows the builtin, but it is part of the public signature.
    if type == "gauss_legendre":
        return cls.gauss_legendre(n)

    if type == "gauss_lobatto":
        return cls.gauss_lobatto(n)

    # Reject unsupported types explicitly instead of falling through.
    raise ValueError(
        f"unknown quadrature type {type!r}; "
        "expected 'gauss_legendre' or 'gauss_lobatto'"
    )
def eval_nodes(
    self, interval: tuple[float, float] | None = None
) -> np.ndarray:
    """
    Compute nodes scaled to a specific interval.

    Parameters
    ----------
    interval : tuple of float, optional
        Interval for which nodes are to be scaled as a 2-tuple ``(a, b)``.
        If ``None``, the default [-1, 1] is used.

    Returns
    -------
    ndarray
        Scaled node values. When ``interval`` is ``None``, the stored
        ``nodes`` array itself is returned (not a copy), so callers should
        not modify it in place.
    """
    if interval is None:
        return self.nodes

    a, b = interval
    # Affine map x -> (a + b) / 2 + (b - a) / 2 * x sends [-1, 1] onto [a, b].
    return 0.5 * (a + b + (b - a) * self.nodes)

def integrate(
    self,
    values: np.typing.ArrayLike,
    interval: tuple[float, float] | None = None,
) -> float:
    """
    Evaluate quadrature rule, accounting for interval scaling.

    Parameters
    ----------
    values : array-like
        Integrand values at the quadrature nodes (for a non-default
        ``interval``, at the nodes returned by :meth:`.eval_nodes` for that
        same interval). Combined with the weights through ``np.dot``, so it
        must hold one value per node.

    interval : tuple of float, optional
        Interval on which the integral is being computed as a 2-tuple
        ``(a, b)``. If ``None`` (default), the reference interval [-1, 1]
        is used.

    Returns
    -------
    float
        Quadrature evaluation for the specified interval.
    """
    weighted_sum = float(np.dot(self.weights, values))

    if interval is None:
        return weighted_sum

    # Jacobian (b - a) / 2 of the affine map from [-1, 1] onto [a, b].
    return 0.5 * (interval[1] - interval[0]) * weighted_sum

@property
def str_summary(self) -> str:
"""
Return a summarized representation of the current instance.

Returns
-------
str
Instance summary.
"""