"""Eigendecomposition Ops (module ``pytensor.tensor.linalg.decomposition.eigen``)."""

import warnings
from typing import cast

import numpy as np

from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import as_tensor_variable, diag, eye, tril, triu
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg._lazy import scipy_linalg
from pytensor.tensor.linalg.dtype_utils import linalg_real_output_dtype
from pytensor.tensor.math import sub, switch
from pytensor.tensor.type import Variable, tensor, vector
from pytensor.tensor.type_other import NoneTypeT


def _zero_disconnected(outputs, grads):
    l = []
    for o, g in zip(outputs, grads, strict=True):
        if isinstance(g.type, DisconnectedType):
            l.append(o.zeros_like())
        else:
            l.append(g)
    return l


class Eig(Op):
    """
    Compute the eigenvalues and right eigenvectors of a square array.

    Both outputs are always complex: their dtype is the input dtype promoted
    with ``complex64``. ``numpy.linalg.eig`` may return real arrays when the
    imaginary parts vanish, but PyTensor needs a dtype fixed at graph-build time.
    """

    __props__: tuple[str, ...] = ()
    # Can't use numpy directly in Blockwise, because of the dynamic dtype
    # gufunc_spec = ("numpy.linalg.eig", 1, 2)
    gufunc_signature = "(m,m)->(m),(m,m)"

    def make_node(self, x):
        """Build the Apply node for a single square matrix ``x``.

        Raises
        ------
        ValueError
            If both static dimensions of ``x`` are known and differ.
        """
        x = as_tensor_variable(x)
        assert x.ndim == 2

        M, N = x.type.shape

        if M is not None and N is not None and M != N:
            raise ValueError(
                f"Input to Eig must be a square matrix, got static shape: ({M}, {N})"
            )

        # The input must be square, so a static length known on either axis
        # applies to both (previously a known N was dropped when M was None).
        n = M if M is not None else N

        dtype = np.promote_types(x.dtype, np.complex64)

        w = tensor(dtype=dtype, shape=(n,))
        v = tensor(dtype=dtype, shape=(n, n))

        return Apply(self, [x], [w, v])

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        dtype = np.promote_types(x.dtype, np.complex64)

        w, v = np.linalg.eig(x)

        # If the imaginary part of the eigenvalues is zero, numpy automatically casts them to real. We require
        # a statically known return dtype, so we have to cast back to complex to avoid dtype mismatch.
        outputs[0][0] = w.astype(dtype, copy=False)
        outputs[1][0] = v.astype(dtype, copy=False)

    def infer_shape(self, fgraph, node, shapes):
        (x_shapes,) = shapes
        n, _ = x_shapes

        return [(n,), (n, n)]

    def pullback(self, inputs, outputs, output_grads):
        """Always raises: gradients of complex outputs are not supported."""
        raise NotImplementedError(
            "Gradients for Eig is not implemented because it always returns complex values, "
            "for which autodiff is not yet supported in PyTensor (PRs welcome :) ).\n"
            "If you know that your input has strictly real-valued eigenvalues (e.g. it is a "
            "symmetric matrix), use pt.linalg.eigh instead."
        )


def eig(x: TensorLike):
    """
    Compute the eigenvalues and right eigenvectors of a square matrix (or a batch of them).

    The outputs are complex for every input dtype, so no gradient is available
    for this operation (PyTensor cannot yet differentiate complex values).

    When the input is known to have strictly real eigenvalues (for instance a
    symmetric matrix), prefer `pytensor.tensor.linalg.eigh`.

    Parameters
    ----------
    x: TensorLike
        Square matrix, or array of such matrices

    Returns
    -------
    The outputs of ``Blockwise(Eig())`` applied to ``x``: the eigenvalues and the
    right eigenvectors.
    """
    batched_eig = Blockwise(Eig())
    return batched_eig(x)


class Eigh(Op):
    """
    Return the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.

    Optionally solves the generalized eigenvalue problem ``A @ v = w * B @ v``
    when a second matrix *b* is provided (delegated to ``scipy.linalg.eigh``).

    Parameters
    ----------
    lower : bool
        Whether the lower (True) or upper (False) triangle of the input(s) is used.
    UPLO : str, optional
        Deprecated. When given, ``lower`` is set to ``UPLO == "L"``.
    overwrite_a, overwrite_b : bool
        Allow scipy to destroy the buffer of ``a`` / ``b``. The two are mutually
        exclusive, and ``overwrite_b`` requires the generalized (two-input) form.
    driver : {'evr', 'evd'}
        LAPACK driver for the standard problem. It is not forwarded to scipy
        when ``b`` is given.
    """

    __props__ = ("lower", "overwrite_a", "overwrite_b", "driver")

    def __init__(
        self,
        lower: bool = True,
        UPLO: str | None = None,
        overwrite_a: bool = False,
        overwrite_b: bool = False,
        driver: str = "evr",
    ):
        if UPLO is not None:
            warnings.warn(
                "UPLO is deprecated and will be removed in a future version. Use the ``lower`` argument "
                "instead.",
                stacklevel=2,
                category=DeprecationWarning,
            )
            lower = UPLO == "L"

        if driver not in ("evr", "evd"):
            raise ValueError(
                f"Invalid driver: {driver!r}. Must be one of 'evr', 'evd'."
            )

        if overwrite_a and overwrite_b:
            raise ValueError(
                "overwrite_a and overwrite_b are mutually exclusive: pytensor "
                "tracks at most one destroyed input per output."
            )

        self.lower = lower
        self.overwrite_a = overwrite_a
        self.overwrite_b = overwrite_b
        self.driver = driver

        # Output 1 (eigenvectors) is the one that lands in the destroyed buffer.
        if self.overwrite_a:
            self.destroy_map = {1: [0]}
        elif self.overwrite_b:
            self.destroy_map = {1: [1]}

    def make_node(self, a, b=None):
        """Build the Apply node for ``a`` (and optionally ``b``).

        ``b`` may be omitted, ``None``, or a variable of ``NoneTypeT``; all three
        select the standard (single-input) problem.

        Raises
        ------
        ValueError
            If ``a`` has a statically non-square shape, or if ``overwrite_b`` is
            set but no ``b`` is given (``destroy_map`` would otherwise reference
            a nonexistent input).
        """
        a = as_tensor_variable(a)
        assert a.ndim == 2
        M, N = a.type.shape

        if M is not None and N is not None and M != N:
            raise ValueError(
                f"Input to Eigh must be a square matrix, got static shape: ({M}, {N})"
            )

        has_b = b is not None and not (
            isinstance(b, Variable) and isinstance(b.type, NoneTypeT)
        )

        if has_b:
            b = as_tensor_variable(b)
            assert b.ndim == 2
            inputs = [a, b]
        else:
            if self.overwrite_b:
                raise ValueError(
                    "overwrite_b=True requires the generalized form with a second input"
                )
            inputs = [a]

        # Square input: a static length known on either axis applies to both.
        n = M if M is not None else N

        w_dtype = linalg_real_output_dtype(*[x.type.dtype for x in inputs])
        # Eigenvalues of a Hermitian problem are real, but scipy returns complex
        # eigenvectors for complex input; type v with the complex counterpart.
        if any(np.dtype(x.type.dtype).kind == "c" for x in inputs):
            v_dtype = np.result_type(w_dtype, np.complex64).name
        else:
            v_dtype = w_dtype

        w = tensor(dtype=w_dtype, shape=(n,))
        v = tensor(dtype=v_dtype, shape=(n, n))

        return Apply(self, inputs, [w, v])

    def perform(self, node, inputs, outputs):
        (w, v) = outputs
        if len(inputs) == 2:
            # Generalized eigenproblem: 'evr'/'evd' are standard-problem drivers
            # and are not valid for scipy's generalized solver, so self.driver is
            # not forwarded (scipy picks its own generalized driver).
            w[0], v[0] = scipy_linalg.eigh(
                inputs[0],
                b=inputs[1],
                lower=self.lower,
                overwrite_a=self.overwrite_a,
                overwrite_b=self.overwrite_b,
            )
        else:
            w[0], v[0] = scipy_linalg.eigh(
                inputs[0],
                lower=self.lower,
                overwrite_a=self.overwrite_a,
                driver=self.driver,
            )

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        # overwrite_a and overwrite_b are mutually exclusive; prefer overwrite_a
        # arbitrarily (memory savings are identical)
        new_props = self._props_dict()  # type: ignore
        if 0 in allowed_inplace_inputs:
            new_props["overwrite_a"] = True
        elif 1 in allowed_inplace_inputs:
            new_props["overwrite_b"] = True
        else:
            return self
        return type(self)(**new_props)

    def infer_shape(self, fgraph, node, shapes):
        n = shapes[0][0]
        return [(n,), (n, n)]

    def pullback(self, inputs, outputs, output_grads):
        r"""Symbolic gradient of ``eigh``.

        For the standard symmetric problem,

        .. math::

            A V = V \operatorname{diag}(w), \qquad V^T V = I,

        define

        .. math::

            F_{ij} =
            \begin{cases}
                \frac{1}{w_j - w_i}, & i \ne j, \\
                0, & i = j .
            \end{cases}

        Then the pullback is

        .. math::

            C = V^T g_V,
            \qquad
            M = \operatorname{diag}(g_w) + F \odot C,
            \qquad
            g_A = V M V^T.

        For the generalized symmetric-definite problem,

        .. math::

            A V = B V \operatorname{diag}(w), \qquad V^T B V = I,

        the pullback is

        .. math::

            C = V^T g_V,
            \qquad
            M = \operatorname{diag}(g_w) + F \odot C,

        .. math::

            g_A = V M V^T,

        .. math::

            g_B =
            -V \left(M \operatorname{diag}(w)\right) V^T
            - \frac12 V \operatorname{diag}(\operatorname{diag}(C)) V^T.

        The gradients are folded onto the triangle selected by ``lower``
        (the only triangle the forward computation reads), so each
        off-diagonal entry there receives ``g_ij + g_ji``.

        These formulas assume distinct eigenvalues. When eigenvalues are repeated,
        the factors ``1 / (w_j - w_i)`` are singular and the eigenvector gradient is
        not uniquely defined. They use plain transposes (``v.T``), i.e. they are
        the real-symmetric formulas.
        """
        w, v = outputs
        gw, gv = _zero_disconnected([w, v], output_grads)

        # F_ij = 1/(w_j - w_i) for i != j, 0 on diagonal
        w_diff = sub.outer(w, w).T
        F = switch(eye(w.shape[0], dtype="bool"), 0.0, 1.0 / w_diff)

        if len(inputs) == 1:
            inner = diag(gw) + F * (v.T @ gv)
            g = v @ inner @ v.T

            if self.lower:
                out = tril(g) + triu(g, k=1).T
            else:
                out = triu(g) + tril(g, k=-1).T
            return [out]
        else:
            C = v.T @ gv
            inner = diag(gw) + F * C

            ga = v @ inner @ v.T
            # inner * w[None, :] scales column j by w_j, i.e. M @ diag(w)
            gb = -v @ (inner * w[None, :]) @ v.T
            gb = gb - 0.5 * v @ diag(diag(C)) @ v.T

            if self.lower:
                ga_sym = tril(ga) + triu(ga, k=1).T
                gb_sym = tril(gb) + triu(gb, k=1).T
            else:
                ga_sym = triu(ga) + tril(ga, k=-1).T
                gb_sym = triu(gb) + tril(gb, k=-1).T
            return [ga_sym, gb_sym]


def eigh(
    a: TensorLike,
    b: TensorLike | None = None,
    lower: bool = True,
    UPLO: str | None = None,
    driver: str = "evr",
) -> list[Variable]:
    """
    Return the eigenvalues and eigenvectors of a symmetric/Hermitian matrix.

    Parameters
    ----------
    a : TensorLike
        Symmetric/Hermitian matrix (or batch thereof).
    b : TensorLike, optional
        Second matrix for the generalized eigenvalue problem ``A v = w B v``.
        Must be positive-definite. If ``None`` (or a symbolic None variable),
        the standard eigenvalue problem is solved.
    lower : bool
        Whether to use the lower or upper triangle of a (and b, if provided).
        Default is True
    UPLO : {'L', 'U'}, optional
        Whether to use the lower or upper triangle of a (and b, if provided).
        Default is 'L' (lower). UPLO is deprecated and will be removed in a
        future version. Use the ``lower`` argument instead.
    driver : {'evr', 'evd'}, optional
        LAPACK driver to use for the standard problem. ``'evr'`` (default) uses
        the MRRR algorithm, the fastest general-purpose driver. This is the
        default used by Scipy. ``'evd'`` uses divide-and-conquer, matching
        NumPy, JAX, and MLX. Ignored when ``b`` is given: the generalized
        problem is always built with ``driver="evd"``.

    Returns
    -------
    w : Variable
        Eigenvalues of the system, in ascending order.
    v : Variable
        Eigenvectors of the system, stored as columns: ``v[..., :, i]`` is the
        eigenvector for ``w[..., i]``.
    """
    if UPLO is not None:
        warnings.warn(
            "UPLO is deprecated and will be removed in a future version. "
            "Use the ``lower`` argument instead.",
            stacklevel=2,
            category=DeprecationWarning,
        )
        lower = UPLO == "L"

    # A symbolic None (a variable of NoneTypeT) means "no b", exactly as
    # Eigh.make_node interprets it; route it to the single-input signature.
    has_b = b is not None and not (
        isinstance(b, Variable) and isinstance(b.type, NoneTypeT)
    )

    if not has_b:
        signature = "(m,m)->(m),(m,m)"
        return cast(
            list[Variable],
            Blockwise(Eigh(lower=lower, driver=driver), signature=signature)(a),
        )

    # Generalized eigenproblem always uses divide-and-conquer
    signature = "(m,m),(m,m)->(m),(m,m)"
    return cast(
        list[Variable],
        Blockwise(Eigh(lower=lower, driver="evd"), signature=signature)(a, b),
    )
class Eigvalsh(Op):
    """
    Eigenvalues of a Hermitian/symmetric matrix.

    With a single input the standard problem ``A v = w v`` is solved; with a
    second (positive-definite) input the generalized problem ``A v = w B v``
    is solved. Computation is delegated to ``scipy.linalg.eigvalsh``.
    """

    __props__ = ("lower", "overwrite_a", "overwrite_b")

    def __init__(self, lower=True, overwrite_a=False, overwrite_b=False):
        assert lower in [True, False]
        if overwrite_a and overwrite_b:
            raise ValueError(
                "overwrite_a and overwrite_b are mutually exclusive: pytensor "
                "tracks at most one destroyed input per output. "
            )
        self.lower = lower
        self.overwrite_a = overwrite_a
        self.overwrite_b = overwrite_b
        # The single output (eigenvalues) is declared as destroying the input
        # scipy may overwrite.
        if overwrite_a:
            self.destroy_map = {0: [0]}
        elif overwrite_b:
            self.destroy_map = {0: [1]}

    def make_node(self, a, b=None):
        """Build the Apply node for ``a`` (and optionally ``b``).

        Raises
        ------
        ValueError
            If ``a`` has a statically non-square shape, or if ``overwrite_b`` is
            set but no ``b`` is given.
        """
        a = as_tensor_variable(a)
        assert a.ndim == 2
        M, N = a.type.shape
        if M is not None and N is not None and M != N:
            raise ValueError(
                f"Input to eigvalsh must be square, got {a} with shape ({M}, {N})"
            )

        if b is None or (isinstance(b, Variable) and isinstance(b.type, NoneTypeT)):
            if self.overwrite_b:
                raise ValueError(
                    "overwrite_b=True requires the generalized form with a second input"
                )
            inputs = [a]
            probe_dtype = a.type.dtype
        else:
            b = as_tensor_variable(b)
            assert b.ndim == 2
            probe_dtype = np.result_type(a.type.dtype, b.type.dtype)
            inputs = [a, b]

        # Probe scipy for the output dtype (eigenvalues are always real)
        probe = np.zeros((1, 1), dtype=probe_dtype)
        out_dtype = scipy_linalg.eigvalsh(probe).dtype.name

        # Square input: a static length known on either axis applies to both.
        n = N if N is not None else M
        w = vector(dtype=out_dtype, shape=(n,))
        return Apply(self, inputs, [w])

    def infer_shape(self, fgraph, node, shapes):
        n = shapes[0][0]
        return [
            (n,),
        ]

    def perform(self, node, inputs, outputs):
        (w,) = outputs
        if len(inputs) == 2:
            w[0] = scipy_linalg.eigvalsh(
                a=inputs[0],
                b=inputs[1],
                lower=self.lower,
                overwrite_a=self.overwrite_a,
                overwrite_b=self.overwrite_b,
            )
        else:
            w[0] = scipy_linalg.eigvalsh(
                a=inputs[0],
                b=None,
                lower=self.lower,
                overwrite_a=self.overwrite_a,
            )

    def pullback(self, inputs, outputs, g_outputs):
        """Gradient w.r.t. ``a`` (and ``b``), folded onto the ``lower`` triangle.

        Eigenvectors are recomputed symbolically with ``eigh``; then
        ``gA = V diag(gw) V^T`` and, for the generalized problem,
        ``gB = -V diag(gw * w) V^T``.
        """
        (gw,) = g_outputs
        if len(inputs) == 1:
            (a,) = inputs
            w, v = eigh(a, lower=self.lower)
            gA = v @ diag(gw) @ v.T
            if self.lower:
                gA = tril(gA) + triu(gA, k=1).T
            else:
                gA = triu(gA) + tril(gA, k=-1).T
            return [gA]
        else:
            a, b = inputs
            w, v = eigh(a, b, lower=self.lower)
            gA = v @ diag(gw) @ v.T
            gB = -v @ diag(gw * w) @ v.T

            if self.lower:
                gA = tril(gA) + triu(gA, k=1).T
                gB = tril(gB) + triu(gB, k=1).T
            else:
                gA = triu(gA) + tril(gA, k=-1).T
                gB = triu(gB) + tril(gB, k=-1).T
            return [gA, gB]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        # overwrite_a and overwrite_b are mutually exclusive (PyTensor tracks at most one destroyed
        # input per output). When both can be destroyed, we prefer overwrite_a.
        new_props = self._props_dict()  # type: ignore
        if 0 in allowed_inplace_inputs:
            new_props["overwrite_a"] = True
        elif 1 in allowed_inplace_inputs:
            new_props["overwrite_b"] = True
        else:
            return self
        return type(self)(**new_props)


def eigvalsh(
    a: TensorLike,
    b: TensorLike | None = None,
    lower: bool = True,
) -> Variable:
    """
    Compute the eigenvalues of a symmetric/Hermitian matrix.

    This is identical to ``eigh(a, b, lower)[0]``, but more efficient when only the eigenvalues are needed.

    Parameters
    ----------
    a : TensorLike
        Symmetric/Hermitian matrix (or batch thereof).
    b : TensorLike, optional
        Second matrix for the generalized eigenvalue problem ``A v = w B v``. Must be positive-definite. If ``None``
        (or a symbolic None variable), the standard eigenvalue problem is solved.
    lower : bool, optional
        Whether to use the lower or upper triangle of a (and b). Default True.

    Returns
    -------
    w : TensorVariable
        Eigenvalues of the system, in ascending order.
    """
    op = Eigvalsh(lower=lower)
    # Treat a symbolic None like a Python None, matching Eigvalsh.make_node.
    if b is None or (isinstance(b, Variable) and isinstance(b.type, NoneTypeT)):
        signature = "(m,m)->(m)"
        return cast(Variable, Blockwise(op, signature=signature)(a))

    signature = "(m,m),(m,m)->(m)"
    return cast(Variable, Blockwise(op, signature=signature)(a, b))