# Source code for pytensor.tensor.linalg.constructors

import pytensor
from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import basic as ptb
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg._lazy import scipy_linalg
from pytensor.tensor.linalg.dtype_utils import _largest_common_dtype
from pytensor.tensor.variable import TensorVariable
from pytensor.utils import unzip


class BaseBlockDiagonal(Op):
    """Common machinery for ops that assemble a block-diagonal matrix.

    The op is parameterized by ``n_inputs``, the number of 2-D blocks it
    accepts. Subclasses supply ``make_node`` and ``perform``; this base class
    provides the gufunc signature, the gradient, shape inference and input
    validation.
    """

    __props__: tuple[str, ...] = ("n_inputs",)

    def __init__(self, n_inputs):
        """Store ``n_inputs`` and build the matching gufunc signature.

        Raises
        ------
        ValueError
            If ``n_inputs`` is not greater than 1.
        """
        # One "(m<i>,n<i>)" core signature per input block, e.g.
        # "(m0,n0),(m1,n1)->(m,n)" for two inputs.
        per_input = [f"(m{k},n{k})" for k in range(n_inputs)]
        self.gufunc_signature = ",".join(per_input) + "->(m,n)"

        if n_inputs <= 1:
            raise ValueError("n_inputs must be greater than 1")
        self.n_inputs = n_inputs

    def pullback(self, inputs, outputs, gout):
        """Return, for each input, the matching diagonal block of ``gout[0]``."""
        block_shapes = pt.stack([inp.shape for inp in inputs])
        # The running sum of the block shapes gives the exclusive end offset of
        # every block along both axes; subtracting the block's own shape gives
        # its start offset.
        stops = block_shapes.cumsum(0)
        starts = stops - block_shapes

        input_grads = []
        for k, _ in enumerate(inputs):
            row_idx = pt.arange(starts[k, 0], stops[k, 0])
            col_idx = pt.arange(starts[k, 1], stops[k, 1])
            input_grads.append(gout[0][ptb.ix_(row_idx, col_idx)])
        return input_grads

    def infer_shape(self, fgraph, nodes, shapes):
        """Output shape: the input shapes summed separately along each axis."""
        # ``unzip`` splits the per-input 2-entry shapes into one group of
        # first-axis lengths and one group of second-axis lengths.
        row_counts, col_counts = unzip(shapes, n=2, strict=True)
        total_rows = pt.add(*row_counts)
        total_cols = pt.add(*col_counts)
        return [(total_rows, total_cols)]

    def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
        """Convert ``matrices`` with ``as_tensor_func`` and check their count and rank.

        Parameters
        ----------
        matrices
            The raw inputs; there must be exactly ``self.n_inputs`` of them.
        as_tensor_func
            Callable applied to every input to turn it into a tensor variable.

        Returns
        -------
        list
            The converted inputs, in the original order.

        Raises
        ------
        ValueError
            If the number of inputs differs from ``self.n_inputs``.
        TypeError
            If any converted input does not have exactly 2 dimensions.
        """
        if len(matrices) != self.n_inputs:
            raise ValueError(
                f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
            )

        tensors = [as_tensor_func(raw) for raw in matrices]
        for tensor in tensors:
            if tensor.type.ndim != 2:
                raise TypeError("All inputs must have dimension 2")
        return tensors


class BlockDiagonal(BaseBlockDiagonal):
    """Block-diagonal op for dense 2-D inputs.

    ``make_node`` builds the symbolic output, inferring the static shape where
    possible, and ``perform`` evaluates it via ``scipy_linalg.block_diag``.
    """

    __props__ = ("n_inputs",)

    def make_node(self, *matrices):
        """Build the ``Apply`` node for the given input matrices.

        Inputs are converted with ``pt.as_tensor`` and validated by
        ``_validate_and_prepare_inputs``. The output dtype comes from
        ``_largest_common_dtype`` applied to the converted inputs.
        """
        inputs = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
        out_dtype = _largest_common_dtype(inputs)

        # An output axis has a known static length only when every input has a
        # known static length along that same axis; otherwise it is ``None``.
        static_shape = []
        for per_axis in zip(*(inp.type.shape for inp in inputs)):
            if any(length is None for length in per_axis):
                static_shape.append(None)
            else:
                static_shape.append(sum(per_axis))

        output = pytensor.tensor.matrix(shape=tuple(static_shape), dtype=out_dtype)
        return Apply(self, inputs, [output])

    def perform(self, node, inputs, output_storage, params=None):
        """Compute the block-diagonal matrix and store it in the node's output dtype."""
        out_dtype = node.outputs[0].type.dtype
        dense = scipy_linalg.block_diag(*inputs)
        output_storage[0][0] = dense.astype(out_dtype)


def block_diag(*matrices: TensorVariable):
    """
    Construct a block diagonal matrix from a sequence of input tensors.

    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

    [[A, 0, 0],
     [0, B, 0],
     [0, 0, C]]

    Parameters
    ----------
    A, B, C ... : tensors
        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and
        all inputs should have at least 2 dimensions.

    Returns
    -------
    out: tensor
        The block diagonal matrix formed from the input matrices. When a single input is given, that input is
        returned (converted to a tensor variable if it is not one already).

    Raises
    ------
    ValueError
        If no inputs are given.

    Examples
    --------
    Create a block diagonal matrix from two 2x2 matrices:

    .. code-block:: python

        import numpy as np
        import pytensor.tensor as pt
        from pytensor.tensor.linalg import block_diag

        A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
        B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))

        result = block_diag(A, B)
        print(result.eval())

    Out:

    .. code-block:: text

        [[1 2 0 0]
         [3 4 0 0]
         [0 0 5 6]
         [0 0 7 8]]
    """
    if len(matrices) == 1:
        # A single block is already block diagonal. Coerce it so the function
        # always returns a tensor variable, even for array-like inputs.
        return pt.as_tensor(matrices[0])

    # Blockwise applies the core op over any leading batch dimensions.
    _block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
    return _block_diagonal_matrix(*matrices)