From cc9f605bbf25ce7e1da4e980322175ab54107775 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Sun, 29 Oct 2023 13:39:24 -0400 Subject: [PATCH] [DOC] Polish internal documentation (#36) * [REF] Make `_tensors_to_sync` abstract and strictly return tuples This simplifies the implementation for recursive matrices and moves the `NotImplementedError` error for data parallel support closer to the method that needs to be implemented. * [REF] Implement `{TriuTopLeft,TriuBottomRight}DiagonalMatrix` recursively * [DEL] Redundant documentation * [REF] Extract function that checks for square-matrix input * [REF] Extract function to compute the boundary index * [FIX] Documentation formatting and correctness * [FIX] Typo in error message * [REF] Use `block_diag` * [REF] Extract `_get_boundary` * [REF] Implement `Tril{BottomRight,TopLeft}DiagonalMatrix` via recursion * [DOC] Add internal sections, start polishing docstrings * [DOC] Polish docstring of `BlockDiagonalMatrixTemplate` * [DOC] Add example demonstrating usage of `BlockDiagonalMatrixTemplate` * [DOC] Polish docstring of `HierarchicalMatrixTemplate` * [DOC] Add code examples * [DOC] Warn that symmetry is not verifies internally * [DOC] Add interface to internal section * [DOC] Polish docstrings of matrix structures * [REF] Replace all `.dim()` with `.ndim` * [FIX] Darglint --- docs/interface.md | 7 ++ docs/structures.md | 59 ++++++++++ docs/templates.md | 24 +++++ makefile | 5 +- mkdocs.yml | 6 +- singd/structures/base.py | 69 ++++++------ singd/structures/blockdiagonal.py | 138 +++++++++++++++--------- singd/structures/dense.py | 28 +++-- singd/structures/diagonal.py | 29 ++++- singd/structures/hierarchical.py | 129 +++++++++++++++------- singd/structures/recursive.py | 136 +++++++++++++++++------ singd/structures/trilbottomrightdiag.py | 23 ++-- singd/structures/triltoeplitz.py | 48 ++++++--- singd/structures/triltopleftdiag.py | 17 +-- singd/structures/triubottomrightdiag.py | 23 ++-- singd/structures/triutoeplitz.py | 34 ++++-- singd/structures/triutopleftdiag.py | 23 ++-- singd/structures/utils.py | 2 +- 18 files changed, 577 insertions(+), 223 deletions(-) create mode 100644 docs/interface.md create mode 100644 docs/structures.md create mode 100644 docs/templates.md diff --git a/docs/interface.md b/docs/interface.md new file mode 100644 index 0000000..e8999e8 --- /dev/null +++ b/docs/interface.md @@ -0,0 +1,7 @@ +This section lists the interface for structured matrices, that is the operations +they need to implement to work in SINGD. It serves **for internal purposes +only**. This is useful for developers that wish to add a new structured matrix +class to the code that cannot be constructed with one of the available +templates. + +::: singd.structures.base.StructuredMatrix diff --git a/docs/structures.md b/docs/structures.md new file mode 100644 index 0000000..82c8657 --- /dev/null +++ b/docs/structures.md @@ -0,0 +1,59 @@ +Here we provide a list of structured matrices. This list is meant **for internal +purposes only**. It exists because it is more convenient to read the rendered +LaTeX code rather than the docstring source. 
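As a quick orientation for readers of this page, the classes listed below all share the `StructuredMatrix` interface: they are built from their compact representation (or obtained from a dense symmetric matrix via `from_dense`) and can always be converted back into a regular PyTorch tensor via `to_dense`. A minimal usage sketch, assuming the import paths shown in the directives below:

```python
import torch

from singd.structures.dense import DenseMatrix
from singd.structures.diagonal import DiagonalMatrix

# Build a structure from its compact representation (here: the diagonal) ...
diag = DiagonalMatrix(torch.ones(4))

# ... or represent an existing dense symmetric matrix.
dense = DenseMatrix.from_dense(torch.eye(4))

# Every structure can be converted back into a regular PyTorch tensor.
print(diag.to_dense().shape, dense.to_dense().shape)
```

Note that, for performance reasons, symmetry of the inputs is not verified by the constructors and must be ensured by the caller.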
+ +::: singd.structures.dense.DenseMatrix + options: + members: + - __init__ + +::: singd.structures.hierarchical.Hierarchical15_15Matrix + options: + members: + - __init__ + +# DIAGONAL + +::: singd.structures.diagonal.DiagonalMatrix + options: + members: + - __init__ + +::: singd.structures.blockdiagonal.Block30DiagonalMatrix + options: + members: + - __init__ + +# LOWER-TRIANGULAR + +::: singd.structures.triltoeplitz.TrilToeplitzMatrix + options: + members: + - __init__ + +::: singd.structures.trilbottomrightdiag.TrilBottomRightDiagonalMatrix + options: + members: + - __init__ + +::: singd.structures.triltopleftdiag.TrilTopLeftDiagonalMatrix + options: + members: + - __init__ + +# UPPER-TRIANGULAR + +::: singd.structures.triutoeplitz.TriuToeplitzMatrix + options: + members: + - __init__ + +::: singd.structures.triubottomrightdiag.TriuBottomRightDiagonalMatrix + options: + members: + - __init__ + +::: singd.structures.triutopleftdiag.TriuTopLeftDiagonalMatrix + options: + members: + - __init__ diff --git a/docs/templates.md b/docs/templates.md new file mode 100644 index 0000000..dd06594 --- /dev/null +++ b/docs/templates.md @@ -0,0 +1,24 @@ +Here we provide a list of templates that can be used to create new structured +matrices. This list is meant **for internal purposes only**. It exists because +it is more convenient to read the rendered LaTeX code rather than the docstring +source. + +::: singd.structures.blockdiagonal.BlockDiagonalMatrixTemplate + options: + members: + - + +::: singd.structures.hierarchical.HierarchicalMatrixTemplate + options: + members: + - + +::: singd.structures.recursive.RecursiveBottomLeftMatrixTemplate + options: + members: + - + +::: singd.structures.recursive.RecursiveTopRightMatrixTemplate + options: + members: + - diff --git a/makefile b/makefile index 35779f4..c7e7fce 100644 --- a/makefile +++ b/makefile @@ -58,10 +58,9 @@ install-test: .PHONY: test test-light test: - @pytest -vx --run-optional-tests=expensive --cov=singd test - + @pytest -vx --run-optional-tests=expensive --cov=singd --doctest-modules test singd test-light: - @pytest -vx --cov=singd test + @pytest -vx --cov=singd --doctest-modules test singd .PHONY: install-lint diff --git a/mkdocs.yml b/mkdocs.yml index a10d83e..2754c19 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,6 +10,10 @@ nav: - Code Examples: generated/gallery - API Documentation: api.md - Developer Notes: develop.md + - Internal: + - Structures: structures.md + - Templates: templates.md + - Interface: interface.md theme: name: material features: @@ -34,7 +38,7 @@ plugins: options: show_root_heading: true show_source: true - show_bases: false + show_bases: true show_signature_annotations: true separate_signature: true docstring_section_style: list diff --git a/singd/structures/base.py b/singd/structures/base.py index f8a9412..a8bf4e3 100644 --- a/singd/structures/base.py +++ b/singd/structures/base.py @@ -25,22 +25,27 @@ class StructuredMatrix(ABC): a new structured matrix class with SINGD. The minimum amount of work to add a new structured matrix class requires - implementing the `to_dense`, `from_dense` methods. - The other operations will then use a naive implementation which internally + implementing the following methods: + + - `to_dense` + - `from_dense` + + All other operations will then use a naive implementation which internally re-constructs unstructured dense matrices. By default, these operations will trigger a warning which can be used to identify functions that can be implemented more efficiently using structure. 
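To illustrate this minimal contract, here is a sketch of a custom structure (a hypothetical `ScaledIdentityMatrix` that stores a single scalar) which only implements `from_dense` and `to_dense` and inherits the naive fallbacks for everything else; supporting data-parallel training would additionally require `_tensors_to_sync`, as noted below:

```python
import torch
from torch import Tensor

from singd.structures.base import StructuredMatrix


class ScaledIdentityMatrix(StructuredMatrix):
    """Hypothetical structure: a multiple of the identity, stored as one scalar."""

    def __init__(self, scale: Tensor, dim: int) -> None:
        self._scale = scale  # 0d tensor holding the scaling factor
        self._dim = dim

    @classmethod
    def from_dense(cls, sym_mat: Tensor) -> "ScaledIdentityMatrix":
        # Project a dense symmetric matrix onto the structure via its average diagonal.
        return cls(sym_mat.diagonal().mean(), sym_mat.shape[0])

    def to_dense(self) -> Tensor:
        return self._scale * torch.eye(
            self._dim, dtype=self._scale.dtype, device=self._scale.device
        )
```

All remaining operations (`__matmul__`, `from_inner`, ...) then work through the dense fallbacks of the base class, at the cost of the performance warnings described above.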
- If you want to support data parallel training, you also have to implement - the `tensors_to_sync` method. + Note: + If you want to support data parallel training, you also have to implement + the `tensors_to_sync` method. Attributes: WARN_NAIVE: Warn the user if a method falls back to a naive implementation of this base class. This indicates a method that should be implemented to save memory and run time by considering the represented structure. - Default: ``True``. + Default: `True`. WARN_NAIVE_EXCEPTIONS: Set of methods that should not trigger a warning even - if ``WARN_NAIVE`` is ``True``. This can be used to silence warnings for + if `WARN_NAIVE` is `True`. This can be used to silence warnings for methods for which it is too complicated to leverage a specific structure and which should therefore call out to this class's implementation without performance warnings. @@ -55,9 +60,7 @@ def _tensors_to_sync(self) -> Tuple[Tensor, ...]: This is used to support distributed data parallel training. - # noqa: DAR202 - - Returns: + Returns: # noqa: DAR202 A tuple of tensors that need to be synchronized across devices. Raises: @@ -68,9 +71,7 @@ def _tensors_to_sync(self) -> Tuple[Tensor, ...]: def __matmul__( self, other: Union[StructuredMatrix, Tensor] ) -> Union[StructuredMatrix, Tensor]: - """Multiply onto a matrix (@ operator). - - (https://peps.python.org/pep-0465/) + """Multiply onto a matrix ([@ operator](https://peps.python.org/pep-0465/)). Args: other: Another matrix which will be multiplied onto. Can be represented @@ -99,14 +100,12 @@ def from_dense(cls, sym_mat: Tensor) -> StructuredMatrix: are non-zero. Warning: - We do not verify whether ``mat`` is symmetric internally. + We do not verify whether `mat` is symmetric internally. Args: sym_mat: A symmetric dense matrix which will be converted into a structured one. - # noqa: DAR202 - Returns: Structured matrix. @@ -119,8 +118,6 @@ def from_dense(cls, sym_mat: Tensor) -> StructuredMatrix: def to_dense(self) -> Tensor: """Return a dense tensor representing the structured matrix. - # noqa: DAR202 - Returns: A dense PyTorch tensor representing the matrix. @@ -165,7 +162,7 @@ def __sub__(self, other: StructuredMatrix) -> StructuredMatrix: return self + (other * (-1.0)) def rmatmat(self, mat: Tensor) -> Tensor: - """Multiply the structured matrix's transpose onto a matrix (``self.T @ mat``). + """Multiply the structured matrix's transpose onto a matrix (`self.T @ mat`). Args: mat: A dense matrix that will be multiplied onto. @@ -183,7 +180,7 @@ def _warn_naive_implementation(cls, fn_name: str): This suggests that a child class does not implement a specialized version that is usually more efficient. - You can turn off the warning by setting the ``WARN_NAIVE`` class attribute. + You can turn off the warning by setting the `WARN_NAIVE` class attribute. Args: fn_name: Name of the function whose naive version is being called. @@ -207,17 +204,17 @@ def all_reduce( parallel training. Args: - op: The reduction operation to perform (default: ``dist.ReduceOp.AVG``). - group: The process group to work on. If ``None``, the default process group + op: The reduction operation to perform (default: `dist.ReduceOp.AVG`). + group: The process group to work on. If `None`, the default process group will be used. - async_op: If ``True``, this function will return a - ``torch.distributed.Future`` object. + async_op: If `True`, this function will return a + `torch.distributed.Future` object. 
Otherwise, it will block until the reduction completes - (default: ``False``). + (default: `False`). Returns: - If ``async_op`` is ``True``, a (tuple of) ``torch.distributed.Future`` - object(s), else ``None``. + If `async_op` is `True`, a (tuple of) `torch.distributed.Future` + object(s), else `None`. """ handles = [] for tensor in self._tensors_to_sync: @@ -236,16 +233,16 @@ def all_reduce( ############################################################################### def from_inner(self, X: Union[Tensor, None] = None) -> StructuredMatrix: - """Extract the represented structure from ``self.T @ X @ X^T @ self``. + """Extract the represented structure from `self.T @ X @ X^T @ self`. - We can recycle terms by writing ``self.T @ X @ X^T @ self`` as ``S @ S^T`` - with ``S := self.T @ X``. + We can recycle terms by writing `self.T @ X @ X^T @ self` as `S @ S^T` + with `S := self.T @ X`. Args: - X: Optional arbitrary 2d tensor. If ``None``, ``X = I`` will be used. + X: Optional arbitrary 2d tensor. If `None`, `X = I` will be used. Returns: - The structured matrix extracted from ``self.T @ X @ X^T @ self``. + The structured matrix extracted from `self.T @ X @ X^T @ self`. """ self._warn_naive_implementation("from_inner") S_dense = self.to_dense().T if X is None else self.rmatmat(X) @@ -256,13 +253,13 @@ def from_inner(self, X: Union[Tensor, None] = None) -> StructuredMatrix: # integrating this interface into existing implementations of sparse IF-KFAC # easier, as they have access to the input/gradient covariance matrices. def from_inner2(self, XXT: Tensor) -> StructuredMatrix: - """Extract the represented structure from ``self.T @ XXT @ self``. + """Extract the represented structure from `self.T @ XXT @ self`. Args: XXT: 2d square symmetric matrix. Returns: - The structured matrix extracted from ``self.T @ XXT @ self``. + The structured matrix extracted from `self.T @ XXT @ self`. """ self._warn_naive_implementation("from_inner2") dense = self.to_dense() @@ -291,7 +288,7 @@ def diag_add_(self, value: float) -> StructuredMatrix: diag_add_(dense, value) # NOTE `self` is immutable, so we have to update its state with the following - # hack (otherwise, the call ``a.diag_add_(b)`` will not modify ``a``). See + # hack (otherwise, the call `a.diag_add_(b)` will not modify `a`). See # https://stackoverflow.com/a/37658673 and https://stackoverflow.com/q/1015592. new = self.from_dense(dense) self.__dict__.update(new.__dict__) @@ -302,9 +299,9 @@ def infinity_vector_norm(self) -> Tensor: The infinity vector norm is the absolute value of the largest entry. Note that this is different from the infinity matrix norm, compare - (here)[https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html] + [here](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html) and - (here)[https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html]. + [here](https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html). 
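The difference between the two norms is easy to see on a small example with the linked PyTorch functions (operating on a plain tensor here, not a structured matrix):

```python
import torch
from torch import linalg

A = torch.tensor([[1.0, -2.0], [3.0, 4.0]])

# Infinity *vector* norm: absolute value of the largest entry (flattened view).
vec_norm = linalg.vector_norm(A, ord=float("inf"))  # 4.0

# Infinity *matrix* norm: maximum absolute row sum.
mat_norm = linalg.matrix_norm(A, ord=float("inf"))  # |3| + |4| = 7.0

print(vec_norm.item(), mat_norm.item())
```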
Note: This assumes that all tensors in `self._tensors_to_sync` contain diff --git a/singd/structures/blockdiagonal.py b/singd/structures/blockdiagonal.py index 074a303..4d5296c 100644 --- a/singd/structures/blockdiagonal.py +++ b/singd/structures/blockdiagonal.py @@ -1,4 +1,4 @@ -"""Block-diagonal dense matrix implemented in the ``StructuredMatrix`` interface.""" +"""Block-diagonal dense matrix implemented in the `StructuredMatrix` interface.""" from __future__ import annotations @@ -19,55 +19,87 @@ class BlockDiagonalMatrixTemplate(StructuredMatrix): - """Template for symmetric block-diagonal dense matrix. + r"""Template class for symmetric block-diagonal dense matrix. - `` - [[A₁, 0, ..., 0, 0], - [0, A₂, 0, ..., 0], - [0, 0, ..., 0, 0], - [0, ..., 0, A_N, 0], - [0, 0, 0, 0, B]] - `` + Note: + This is a template class. To define an actual class, inherit from this class, + then specify the `BLOCK_DIM` class attribute. See the example below. + + Block-diagonal matrices have the following structure: + + \( + \begin{pmatrix} + \mathbf{A}_1 & \mathbf{0} & \cdots & \cdots & \mathbf{0} \\ + \mathbf{0} & \mathbf{A}_2 & \mathbf{0} & \cdots & \mathbf{0} \\ + \vdots & \ddots & \ddots & \ddots & \vdots \\ + \mathbf{0} & \cdots & \mathbf{0} & \mathbf{A}_N & \mathbf{0} & \\ + \mathbf{0} & \cdots & \cdots & \mathbf{0} & \mathbf{B} + \end{pmatrix} + \in \mathbb{R}^{(N D + D') \times (N D + D')} + \) where - - ``A₁, ..., A_N`` are symmetric matrices of size ``block_dim`` - - ``B`` is a symmetric matrix of size ``last_dim`` if ``block_dim`` does not divide - the total matrix dimension. - Note: - This is a template class. To define an actual class, inherit from this class, - then specify the ``BLOCK_DIM`` class attribute. + - \(\mathbf{A}_n = \mathbf{A}_n^\top \in \mathbb{R}^{D \times D}\) are symmetric + matrices containing the diagonal blocks of block dimension \(D\). + - \(\mathbf{B} = \mathbf{B}^\top \in \mathbb{R}^{D' \times D'}\) is a symmetric + matrix containing the last block of dimension \(D' < D\), which can be empty + if \(D\) divides the matrix dimension. Attributes: - BLOCK_DIM: The dimension of a diagonal block. + BLOCK_DIM: The dimension of a diagonal block (\(D\)). + + Examples: + >>> from torch import ones + >>> + >>> class Block2DiagonalMatrix(BlockDiagonalMatrixTemplate): + ... '''Class to represent block-diagonal matrices with 2x2 blocks.''' + ... BLOCK_DIM = 2 + >>> + >>> # A block-diagonal matrix of total dimension 7x7 + >>> blocks, last = ones(3, 2, 2), 2 * ones(1, 1) + >>> mat = Block2DiagonalMatrix(blocks, last) + >>> mat.to_dense() + tensor([[1., 1., 0., 0., 0., 0., 0.], + [1., 1., 0., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0., 0.], + [0., 0., 1., 1., 0., 0., 0.], + [0., 0., 0., 0., 1., 1., 0.], + [0., 0., 0., 0., 1., 1., 0.], + [0., 0., 0., 0., 0., 0., 2.]]) """ BLOCK_DIM: int def __init__(self, blocks: Tensor, last: Tensor) -> None: - """Store the matrix internally. + r"""Store the matrix internally. Args: - blocks: The diagonal blocks ``A₁, A₂, ..., A_N`` of the matrix, supplied - as a tensor of shape ``[N, BLOCK_DIM, BLOCK_DIM]``. If there are no - blocks, has shape ``[0, BLOCK_DIM, BLOCK_DIM]``. - last: The last block if ``BLOCK_DIM`` which contains the remaining matrix - if ``BLOCK_DIM`` does not divide the matrix dimension. - Has shape ``[last_dim, last_dim]`` where ``last_dim`` may be zero. + blocks: The diagonal blocks + \(\{\mathbf{A}_n = \mathbf{A}_n^\top\}_{n = 1}^N\), + supplied as a tensor of shape `[N, BLOCK_DIM, BLOCK_DIM]`. 
If there are + no blocks, this argument has shape `[0, BLOCK_DIM, BLOCK_DIM]`. + last: The last block \(\mathbf{B} = \mathbf{B}^\top\) which contains the + remaining matrix if `BLOCK_DIM` does not divide the matrix dimension. + Has shape `[last_dim, last_dim]` where `last_dim` may be zero. + + Note: + For performance reasons, symmetry is not checked internally and must + be ensured by the caller. Raises: ValueError: If the passed tensors have incorrect shape. """ - if blocks.dim() != 3: + if blocks.ndim != 3: raise ValueError( - f"Diagonal blocks must be 3-dimensional, got {blocks.dim()}." + f"Diagonal blocks must be 3-dimensional, got {blocks.ndim}." ) if blocks.shape[1] != blocks.shape[2] != self.BLOCK_DIM: raise ValueError( f"Diagonal blocks must be square with dimension {self.BLOCK_DIM}," f" got {blocks.shape[1:]} instead." ) - if last.dim() != 2 or last.shape[0] != last.shape[1]: + if last.ndim != 2 or last.shape[0] != last.shape[1]: raise ValueError(f"Last block must be square, got {last.shape}.") if last.shape[0] >= self.BLOCK_DIM or last.shape[1] >= self.BLOCK_DIM: raise ValueError( @@ -94,10 +126,10 @@ def from_dense(cls, mat: Tensor) -> BlockDiagonalMatrixTemplate: Args: mat: A dense and symmetric square matrix which will be approximated by a - ``BlockDiagonalMatrixTemplate``. + `BlockDiagonalMatrixTemplate`. Returns: - ``BlockDiagonalMatrixTemplate`` approximating the passed matrix. + `BlockDiagonalMatrixTemplate` approximating the passed matrix. """ num_blocks = mat.shape[0] // cls.BLOCK_DIM @@ -142,20 +174,20 @@ def __matmul__( Args: other: A matrix which will be multiplied onto. Can be represented by a - PyTorch tensor or a ``BlockDiagonalMatrix``. + PyTorch tensor or a `BlockDiagonalMatrix`. Returns: Result of the multiplication. If a PyTorch tensor was passed as argument, the result will be a PyTorch tensor. If a block-diagonal matrix was passed, - the result will be returned as a ``BlockDiagonalMatrixTemplate``. + the result will be returned as a `BlockDiagonalMatrixTemplate`. Raises: - ValueError: If ``other``'s shape is incompatible. + ValueError: If `other`'s shape is incompatible. """ if isinstance(other, Tensor): num_blocks, last_dim = self._blocks.shape[0], self._last.shape[0] total_dim = num_blocks * self.BLOCK_DIM + last_dim - if other.shape[0] != total_dim or other.dim() != 2: + if other.shape[0] != total_dim or other.ndim != 2: raise ValueError( f"Expect matrix with {total_dim} rows. Got {other.shape}." ) @@ -216,7 +248,7 @@ def __mul__(self, other: float) -> BlockDiagonalMatrixTemplate: return self.__class__(other * self._blocks, other * self._last) def rmatmat(self, mat: Tensor) -> Tensor: - """Multiply ``mat`` with the transpose of the structured matrix. + """Multiply `mat` with the transpose of the structured matrix. Args: mat: A matrix which will be multiplied by the transpose of the represented @@ -232,24 +264,24 @@ def rmatmat(self, mat: Tensor) -> Tensor: ############################################################################### def from_inner(self, X: Union[Tensor, None] = None) -> BlockDiagonalMatrixTemplate: - """Represent the matrix block-diagonal of ``self.T @ X @ X^T @ self``. - - Let ``K := self``. We can first re-write ``K.T @ X @ X^T @ K`` into - ``S @ S.T`` where ``S = K.T @ X``. Next, note that ``S`` has block structure: - Write ``K := blockdiag(K₁, K₂, ...)`` and write ``X`` as a stack of matrices - ``X = vstack(X₁, X₂, ...)`` where ``Xᵢ`` is associated with the ``i``th diagonal - block. Then ``S = vstack( K₁.T @ X₁, K₂ @ X₂, ...) 
= vstack(S₁ S₂, ...)`` where - we have introduced ``Sᵢ = Kᵢ.T @ Xᵢ``. Consequently, ``S @ S.T`` consists of - blocks ``(i, j)`` with structure ``Sᵢ @ Sⱼ.T``. We are only interested in the + """Represent the matrix block-diagonal of `self.T @ X @ X^T @ self`. + + Let `K := self`. We can first re-write `K.T @ X @ X^T @ K` into + `S @ S.T` where `S = K.T @ X`. Next, note that `S` has block structure: + Write `K := blockdiag(K₁, K₂, ...)` and write `X` as a stack of matrices + `X = vstack(X₁, X₂, ...)` where `Xᵢ` is associated with the `i`th diagonal + block. Then `S = vstack( K₁.T @ X₁, K₂ @ X₂, ...) = vstack(S₁ S₂, ...)` where + we have introduced `Sᵢ = Kᵢ.T @ Xᵢ`. Consequently, `S @ S.T` consists of + blocks `(i, j)` with structure `Sᵢ @ Sⱼ.T`. We are only interested in the diagonal blocks. So we need to compute - ``Sᵢ @ Sᵢ.T = (Kᵢ.T @ Xᵢ) @ (Kᵢ.T @ Xᵢ).T`` for all ``i``. + `Sᵢ @ Sᵢ.T = (Kᵢ.T @ Xᵢ) @ (Kᵢ.T @ Xᵢ).T` for all `i`. Args: - X: Optional arbitrary 2d tensor. If ``None``, ``X = I`` will be used. + X: Optional arbitrary 2d tensor. If `None`, `X = I` will be used. Returns: - A ``DiagonalMatrix`` representing matrix block diagonal of - ``self.T @ X @ X^T @ self``. + A `DiagonalMatrix` representing matrix block diagonal of + `self.T @ X @ X^T @ self`. """ if X is None: S_blocks, S_last = self._blocks, self._last @@ -351,12 +383,22 @@ def eye( class Block30DiagonalMatrix(BlockDiagonalMatrixTemplate): - """Block-diagonal matrix with blocks of size 30.""" + """Block-diagonal matrix with blocks of size 30. + + Note: + See the template class `BlockDiagonalMatrixTemplate` for a mathematical + description. + """ BLOCK_DIM = 30 class Block3DiagonalMatrix(BlockDiagonalMatrixTemplate): - """Block-diagonal matrix with blocks of size 3.""" + """Block-diagonal matrix with blocks of size 3. + + Note: + See the template class `BlockDiagonalMatrixTemplate` for a mathematical + description. + """ BLOCK_DIM = 3 diff --git a/singd/structures/dense.py b/singd/structures/dense.py index 3259c3b..550873c 100644 --- a/singd/structures/dense.py +++ b/singd/structures/dense.py @@ -1,4 +1,4 @@ -"""Dense matrix implemented in the ``StructuredMatrix`` interface.""" +"""Dense matrix implemented in the `StructuredMatrix` interface.""" from __future__ import annotations @@ -10,16 +10,31 @@ class DenseMatrix(StructuredMatrix): - """Unstructured dense matrix implemented in the ``StructuredMatrix`` interface.""" + r"""Unstructured dense matrix implemented in the `StructuredMatrix` interface. + + \[ + \begin{pmatrix} + \mathbf{A} + \end{pmatrix} + \quad \text{with} \quad + \mathbf{A} = \mathbf{A}^\top\,. + \] + + """ WARN_NAIVE: bool = False # Fall-back to naive base class implementations OK def __init__(self, mat: Tensor) -> None: - """Store the dense matrix internally. + r"""Store the dense matrix internally. + + Note: + For performance reasons, symmetry is not checked internally and must + be ensured by the caller. Args: - mat: A dense square matrix. + mat: A symmetric matrix representing \(\mathbf{A}\). """ + self._check_square(mat) self._mat = mat @property @@ -38,11 +53,10 @@ def from_dense(cls, sym_mat: Tensor) -> DenseMatrix: """Construct from a PyTorch tensor. Args: - sym_mat: A dense symmetric matrix which will be represented as - ``DenseMatrix``. + sym_mat: A dense symmetric matrix that will be represented as `DenseMatrix`. Returns: - ``DenseMatrix`` representing the passed matrix. + `DenseMatrix` representing the passed matrix. 
""" return cls(sym_mat) diff --git a/singd/structures/diagonal.py b/singd/structures/diagonal.py index ac16b53..abb3956 100644 --- a/singd/structures/diagonal.py +++ b/singd/structures/diagonal.py @@ -11,13 +11,36 @@ class DiagonalMatrix(StructuredMatrix): - """Diagonal matrix implemented in the ``StructuredMatrix`` interface.""" + r"""Diagonal matrix implemented in the ``StructuredMatrix`` interface. + + A diagonal matrix is defined as + + \( + \begin{pmatrix} + d_1 & 0 & \cdots & 0 \\ + 0 & d_2 & \ddots & \vdots \\ + \vdots & \ddots & \ddots & 0 \\ + 0 & \cdots & \ddots & d_K \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} + \quad + \text{with} + \quad + \mathbf{d} + := + \begin{pmatrix} + d_1 \\ + d_2 \\ + \vdots \\ + d_K \\ + \end{pmatrix} \in \mathbb{R}^K + \) + """ def __init__(self, mat_diag: Tensor) -> None: - """Store the dense matrix internally. + r"""Store the dense matrix internally. Args: - mat_diag: A 1d tensor representing the matrix diagonal. + mat_diag: A 1d tensor representing the matrix diagonal \(\mathbf{d}\). """ self._mat_diag = mat_diag diff --git a/singd/structures/hierarchical.py b/singd/structures/hierarchical.py index 8172e86..d7bbdd2 100644 --- a/singd/structures/hierarchical.py +++ b/singd/structures/hierarchical.py @@ -19,58 +19,95 @@ class HierarchicalMatrixTemplate(StructuredMatrix): - """Template for hierarchical matrices. + r"""Template class for creating hierarchical matrices. - ``[[A, B ], - [ 0, C, 0], - [ 0, D, E],]`` + Note: + This is a template class. To define an actual class, inherit from this class, + then specify the `MAX_K1` and `MAX_K2` class attributes. See the example below. - where (denoting ``K`` the matrix dimension) + Hierarchical matrices have the following structure: - - ``A`` is dense square and has shape ``[K1, K1]`` - - ``B`` is dense rectangular of shape ``[K1, K - K1]`` - - ``C`` is a diagonal matrix of shape ``[K - K1 - K2, K - K1 - K2]`` - - ``D`` is dense rectangular of shape ``[K2, K - K1 - K2]`` - - ``E`` is dense square and has shape ``[K2, K2]`` + \( + \begin{pmatrix} + \mathbf{A} & \mathbf{B}_1 & \mathbf{B}_2 \\ + \mathbf{0} & \mathbf{C} & \mathbf{0} \\ + \mathbf{0} & \mathbf{D} & \mathbf{E} \\ + \end{pmatrix} + \in \mathbb{R}^{K \times K} + \) - Note: - This is a template class. To define an actual class, inherit from this class, - then specify the ``MAX_K1`` and ``MAX_K2`` class attributes. + where (denoting + \(\mathbf{B} := \begin{pmatrix}\mathbf{B}_1 & \mathbf{B}_2\end{pmatrix}\)) + + - \(\mathbf{A} \in \mathbb{R}^{K_1 \times K_1}\) is dense symmetric + - \(\mathbf{B} \in \mathbb{R}^{K_1 \times (K - K_1)}\) is dense rectangular + - \(\mathbf{C} \in \mathbb{R}^{(K - K_2 - K_1) \times (K - K_2 - K_1)}\) is diagonal + - \(\mathbf{D} \in \mathbb{R}^{K_2 \times (K - K_2 - K_1)}\) is dense rectangular + - \(\mathbf{E} \in \mathbb{R}^{K_2 \times K_2}\) is dense symmetric - Given specific values for ``K1, K2``, if the matrix to be represented is not + For fixed values of \(K_1, K_2\), if the matrix to be represented is not big enough to fit all structures, we use the following prioritization: - 1. If ``K <= K1``, start by filling ``A``. - 2. If ``K1 < K <= K1+K2``, fill ``A`` and start filling ``B`` and ``E``. - 3. If ``K1+K2 < K``, use all structures. + 1. If \(K \le K_1\), start by filling \(\mathbf{A}\). + 2. If \(K_1 < K \le K_1+K_2\), fill \(\mathbf{A}\) and start filling \(\mathbf{B}\) + and \(\mathbf{E}\). + 3. If \(K_1+K_2 < K\), use all structures. Attributes: - MAX_K1: Maximum dimension of the top left. 
- MAX_K2: Maximum dimension of the bottom right block. + MAX_K1: Maximum dimension \(K_1\) of the top left block \(\mathbf{A}\). + MAX_K2: Maximum dimension \(K_2\) of the bottom right block \(\mathbf{E}\). + + Examples: + >>> from torch import ones + >>> + >>> class Hierarchical2_3Matrix(HierarchicalMatrixTemplate): + ... '''Hierarchical matrix with 2x2 top left and 3x3 bottom right block.''' + ... MAX_K1 = 2 + ... MAX_K2 = 3 + >>> + >>> # A hierarchical matrix with total dimension K=7 + >>> A, C, E = ones(2, 2), 3 * ones(2), 5 * ones(3, 3) + >>> B, D = 2 * ones(2, 5), 4 * ones(3, 2) + >>> mat = Hierarchical2_3Matrix(A, B, C, D, E) + >>> mat.to_dense() + tensor([[1., 1., 2., 2., 2., 2., 2.], + [1., 1., 2., 2., 2., 2., 2.], + [0., 0., 3., 0., 0., 0., 0.], + [0., 0., 0., 3., 0., 0., 0.], + [0., 0., 4., 4., 5., 5., 5.], + [0., 0., 4., 4., 5., 5., 5.], + [0., 0., 4., 4., 5., 5., 5.]]) """ MAX_K1: int MAX_K2: int def __init__(self, A: Tensor, B: Tensor, C: Tensor, D: Tensor, E: Tensor): - """Store the structural components internally. - - Please read the class docstring for more information. + r"""Store the structural components internally. Args: - A: Dense square matrix of shape ``[K1, K1]`` or smaller. - B: Dense rectangular matrix of shape ``[K1, K - K1]``. - C: Diagonal of shape ``[K - K1 - K2]``. - D: Dense rectangular matrix of shape ``[K2, K - K1 - K2]``. - E: Dense square matrix of shape ``[K2, K2]`` or smaller. + A: Dense symmetric matrix of shape `[K1, K1]` or smaller representing + \(\mathbf{A}\). + B: Dense rectangular matrix of shape `[K1, K - K1]` representing + \(\mathbf{B}\). + C: Vector of shape `[K - K1 - K2]` representing the diagonal of + \(\mathbf{C}\). + D: Dense rectangular matrix of shape `[K2, K - K1 - K2]` representing + \(\mathbf{D}\). + E: Dense symmetric matrix of shape `[K2, K2]` or smaller representing + \(\mathbf{E}\). + + Note: + For performance reasons, symmetry is not checked internally and must + be ensured by the caller. Raises: ValueError: If the shapes of the arguments are invalid. """ - if A.dim() != 2 or B.dim() != 2 or C.dim() != 1 or D.dim() != 2 or E.dim() != 2: + if A.ndim != 2 or B.ndim != 2 or C.ndim != 1 or D.ndim != 2 or E.ndim != 2: raise ValueError( "Invalid tensor dimensions. Expected 2, 2, 1, 2, 2." - + f" Got {A.dim()}, {B.dim()}, {C.dim()}, {D.dim()}, {E.dim()}." + + f" Got {A.ndim}, {B.ndim}, {C.ndim}, {D.ndim}, {E.ndim}." ) self._check_square(A, name="A") self._check_square(E, name="E") @@ -118,10 +155,10 @@ def from_dense(cls, sym_mat: Tensor) -> HierarchicalMatrixTemplate: Args: sym_mat: A dense symmetric matrix which will be represented as - ``Hierarchical``. + `Hierarchical`. Returns: - ``HierarchicalMatrix`` representing the passed matrix. + `HierarchicalMatrix` representing the passed matrix. """ cls._check_square(sym_mat) dim = sym_mat.shape[0] @@ -164,7 +201,7 @@ def __matmul__( Returns: Result of the multiplication. If a PyTorch tensor was passed as argument, the result will be a PyTorch tensor. If a hierarchial matrix was passed, - the result will be returned as a ``HierarchicalMatrixTemplate``. + the result will be returned as a `HierarchicalMatrixTemplate`. """ # parts of B that share columns with C, E B_C, B_E = self.B.split([self.diag_dim, self.K2], dim=1) @@ -243,7 +280,7 @@ def __mul__(self, other: float) -> HierarchicalMatrixTemplate: ) def rmatmat(self, mat: Tensor) -> Tensor: - """Multiply ``mat`` with the transpose of the structured matrix. + """Multiply `mat` with the transpose of the structured matrix. 
Args: mat: A matrix which will be multiplied by the transpose of the represented @@ -278,14 +315,14 @@ def rmatmat(self, mat: Tensor) -> Tensor: ############################################################################### def from_inner(self, X: Union[Tensor, None] = None) -> HierarchicalMatrixTemplate: - """Represent the hierarchical matrix of ``self.T @ X @ X^T @ self``. + """Represent the hierarchical matrix of `self.T @ X @ X^T @ self`. Args: - X: Optional arbitrary 2d tensor. If ``None``, ``X = I`` will be used. + X: Optional arbitrary 2d tensor. If `None`, `X = I` will be used. Returns: - A ``HierarchicalMatrix`` representing hierarchical matrix of - ``self.T @ X @ X^T @ self``. + A `HierarchicalMatrix` representing hierarchical matrix of + `self.T @ X @ X^T @ self`. """ if X is None: A_new = supported_matmul(self.A.T, self.A) @@ -397,13 +434,13 @@ def eye( @classmethod def _compute_block_dims(cls, dim: int) -> Tuple[int, int, int]: - """Compute the dimensions of ``A, C, E``. + """Compute the dimensions of `A, C, E`. Args: dim: Total dimension of the (square) matrix. Returns: - A tuple of the form ``(K1, diag_dim, K2)``. + A tuple of the form `(K1, diag_dim, K2)`. """ if dim <= cls.MAX_K1: K1, diag_dim, K2 = dim, 0, 0 @@ -415,14 +452,24 @@ def _compute_block_dims(cls, dim: int) -> Tuple[int, int, int]: class Hierarchical15_15Matrix(HierarchicalMatrixTemplate): - """Hierarchical matrix with ``K1=15`` and ``K2=15``.""" + """Hierarchical matrix with `K1=15` and `K2=15`. + + Note: + See the template class `HierarchicalMatrixTemplate` for a mathematical + description. + """ MAX_K1 = 15 MAX_K2 = 15 class Hierarchical3_2Matrix(HierarchicalMatrixTemplate): - """Hierarchical matrix with ``K1=3`` and ``K2=2``.""" + """Hierarchical matrix with `K1=3` and `K2=2`. + + Note: + See the template class `HierarchicalMatrixTemplate` for a mathematical + description. + """ MAX_K1 = 3 MAX_K2 = 2 diff --git a/singd/structures/recursive.py b/singd/structures/recursive.py index 4d2b83e..b01004b 100644 --- a/singd/structures/recursive.py +++ b/singd/structures/recursive.py @@ -10,7 +10,12 @@ class RecursiveTopRightMatrixTemplate(StructuredMatrix): - r"""Template to define recursive structured matrices with top right dense block. + r"""Template to define recursively structured matrices with top right dense block. + + Note: + This is a template class. To define an actual class, inherit from this class, + then specify the attributes `MAX_DIMS`, `CLS_A`, and `CLS_C`. See the example + below. This matrix is defined by @@ -24,33 +29,66 @@ class RecursiveTopRightMatrixTemplate(StructuredMatrix): - \(\mathbf{A}, \mathbf{C}\) are structured matrices (which can be recursive). - \(\mathbf{B}\) is a dense rectangular matrix. - Note: - This is a template class. To define an actual class, inherit from this class, - then specify the attributes `MAX_DIMS`, `CLS_A`, and `CLS_C`. Attributes: - MAX_DIMS: A tuple that contains integers and `float('inf')` which indicate - the maximum dimension of `A` and `C`. For example, `(10, float('inf'))` - means that `A` will be used for dimensions up to 10, and `C` will be used - in addition for larger dimensions. - CLS_A: Structured matrix class used for the top left block. - CLS_C: Structured matrix class used for the the bottom right block. + MAX_DIMS: A tuple that contains an integer and a `float('inf')` which indicate + the maximum dimensions of \(\mathbf{A}\) and \(\mathbf{C}\). 
For example, + `(10, float('inf'))` means that only \(\mathbf{A}\) will be used for + dimensions up to 10, and \(\mathbf{C}\) will be used in addition for larger + dimensions. + CLS_A: Structured matrix class used for the top left block \(\mathbf{A}\). + CLS_C: Structured matrix class used for the the bottom right block + \(\mathbf{B}\). + + Examples: + >>> from torch import ones + >>> from singd.structures.dense import DenseMatrix + >>> from singd.structures.diagonal import DiagonalMatrix + >>> + >>> class Dense3DiagonalTopRightMatrix(RecursiveTopRightMatrixTemplate): + ... '''Structured matrix with 3 dense rows upper and lower diagonal part.''' + ... MAX_DIMS = (3, float('inf')) + ... CLS_A = DenseMatrix + ... CLS_C = DiagonalMatrix + >>> + >>> # A 5x5 matrix with 3 dense rows in the upper and lower diagonal part + >>> A = DenseMatrix(ones(3, 3)) + >>> B = 2 * ones(3, 2) + >>> C = DiagonalMatrix(3 * ones(2)) + >>> mat = Dense3DiagonalTopRightMatrix(A, B, C) + >>> mat.to_dense() + tensor([[1., 1., 1., 2., 2.], + [1., 1., 1., 2., 2.], + [1., 1., 1., 2., 2.], + [0., 0., 0., 3., 0.], + [0., 0., 0., 0., 3.]]) """ MAX_DIMS: Tuple[Union[int, float], Union[int, float]] CLS_A: Type[StructuredMatrix] CLS_C: Type[StructuredMatrix] def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix): - """Store the matrix internally. + r"""Store the matrix internally. Args: - A: Top left block. - B: Top right block. - C: Bottom right block. + A: Structured matrix representing the top left block \(\mathbf{A}\). + B: Rectangular tensor representing the top right block \(\mathbf{B}\). + C: Structured matrix representing the bottom right block \(\mathbf{C}\). + + Note: + For performance reasons, symmetry is not checked internally and must + be ensured by the caller. Raises: - ValueError: If the dimensions of the blocks do not match. + ValueError: If the dimensions of the blocks do not match or the + structured matrices are of wrong type. """ + if not isinstance(A, self.CLS_A) or not isinstance(C, self.CLS_C): + raise ValueError( + f"Matrices A and C must be of type {self.CLS_A} and " + f"{self.CLS_C}, respectively. Got {type(A)} and {type(C)}." + ) + # TODO Add a `dim` property to make this cheaper dim_A, dim_C = A.to_dense().shape[0], C.to_dense().shape[0] if B.shape != (dim_A, dim_C): @@ -114,7 +152,12 @@ def to_dense(self) -> Tensor: class RecursiveBottomLeftMatrixTemplate(StructuredMatrix): - r"""Template to define recursive structured matrices with bottom left dense block. + r"""Template to define recursively structured matrices with bottom left dense block. + + Note: + This is a template class. To define an actual class, inherit from this class, + then specify the attributes `MAX_DIMS`, `CLS_A`, and `CLS_C`. See the example + below. This matrix is defined by @@ -128,33 +171,66 @@ class RecursiveBottomLeftMatrixTemplate(StructuredMatrix): - \(\mathbf{A}, \mathbf{C}\) are structured matrices (which can be recursive). - \(\mathbf{B}\) is a dense rectangular matrix. - Note: - This is a template class. To define an actual class, inherit from this class, - then specify the attributes `MAX_DIMS`, `CLS_A`, and `CLS_C`. Attributes: - MAX_DIMS: A tuple that contains integers and `float('inf')` which indicate - the maximum dimension of `A` and `C`. For example, `(10, float('inf'))` - means that `A` will be used for dimensions up to 10, and `C` will be used - in addition for larger dimensions. - CLS_A: Structured matrix class used for the top left block. 
- CLS_C: Structured matrix class used for the the bottom right block. + MAX_DIMS: A tuple that contains an integer and a `float('inf')` which indicate + the maximum dimensions of \(\mathbf{A}\) and \(\mathbf{C}\). For example, + `(10, float('inf'))` means that only \(\mathbf{A}\) will be used for + dimensions up to 10, and \(\mathbf{C}\) will be used in addition for larger + dimensions. + CLS_A: Structured matrix class used for the top left block \(\mathbf{A}\). + CLS_C: Structured matrix class used for the the bottom right block + \(\mathbf{C}\). + + Examples: + >>> from torch import ones + >>> from singd.structures.dense import DenseMatrix + >>> from singd.structures.diagonal import DiagonalMatrix + >>> + >>> class Dense3DiagonalBottomLeftMatrix(RecursiveBottomLeftMatrixTemplate): + ... '''Structured matrix with 3 left columns and right diagonal part.''' + ... MAX_DIMS = (3, float('inf')) + ... CLS_A = DenseMatrix + ... CLS_C = DiagonalMatrix + >>> + >>> # A 5x5 matrix with 3 left columns and right diagonal part + >>> A = DenseMatrix(ones(3, 3)) + >>> B = 2 * ones(2, 3) + >>> C = DiagonalMatrix(3 * ones(2)) + >>> mat = Dense3DiagonalBottomLeftMatrix(A, B, C) + >>> mat.to_dense() + tensor([[1., 1., 1., 0., 0.], + [1., 1., 1., 0., 0.], + [1., 1., 1., 0., 0.], + [2., 2., 2., 3., 0.], + [2., 2., 2., 0., 3.]]) """ MAX_DIMS: Tuple[Union[int, float], Union[int, float]] CLS_A: Type[StructuredMatrix] CLS_C: Type[StructuredMatrix] def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix): - """Store the matrix internally. + r"""Store the matrix internally. Args: - A: Top left block. - B: Bottom left block. - C: Bottom right block. + A: Structured matrix representing the top left block \(\mathbf{A}\). + B: Rectangular tensor representing the bottom left block \(\mathbf{B}\). + C: Structured matrix representing the bottom right block \(\mathbf{C}\). + + Note: + For performance reasons, symmetry is not checked internally and must + be ensured by the caller. Raises: - ValueError: If the dimensions of the blocks do not match. + ValueError: If the dimensions of the blocks do not match or the structured + matrices are of wrong type. """ + if not isinstance(A, self.CLS_A) or not isinstance(C, self.CLS_C): + raise ValueError( + f"Matrices A and C must be of type {self.CLS_A} and " + f"{self.CLS_C}, respectively. Got {type(A)} and {type(C)}." + ) + # TODO Add a `dim` property to make this cheaper dim_A, dim_C = A.to_dense().shape[0], C.to_dense().shape[0] if B.shape != (dim_C, dim_A): diff --git a/singd/structures/trilbottomrightdiag.py b/singd/structures/trilbottomrightdiag.py index 42ffd8c..a645501 100644 --- a/singd/structures/trilbottomrightdiag.py +++ b/singd/structures/trilbottomrightdiag.py @@ -6,17 +6,24 @@ class TrilBottomRightDiagonalMatrix(RecursiveBottomLeftMatrixTemplate): - """Sparse lower-triangular matrix with bottom right diagonal. + r"""Sparse lower-triangular matrix with bottom right diagonal. - `` - [[c1, 0], - [[c2, D]] - `` + This matrix is defined as follows: + + \( + \begin{pmatrix} + a & \mathbf{0} \\ + \mathbf{b} & \mathbf{C} \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} + \) where - - ``c1`` is a scalar, - - ``c2`` is a row vector, and - - ``D`` is a diagonal matrix. + + - \(a \in \mathbb{R}\) is a scalar, represented by a `DenseMatrix` + - \(\mathbf{b} \in \mathbb{R}^{K-1}\) is a row vector, represented as PyTorch + `Tensor`, and + - \(\mathbf{C} \in \mathbb{R}^{(K-1)\times (K-1)}\) is a diagonal matrix represented + as a `DiagonalMatrix`. 
""" MAX_DIMS = (1, float("inf")) diff --git a/singd/structures/triltoeplitz.py b/singd/structures/triltoeplitz.py index 650fd1b..ffaabe2 100644 --- a/singd/structures/triltoeplitz.py +++ b/singd/structures/triltoeplitz.py @@ -1,4 +1,4 @@ -"""Toeplitz matrix implemented in the ``StructuredMatrix`` interface.""" +"""Toeplitz matrix implemented in the `StructuredMatrix` interface.""" from __future__ import annotations @@ -18,10 +18,29 @@ class TrilToeplitzMatrix(StructuredMatrix): - """Lower-triangular Toeplitz-structured matrix in ``StructuredMatrix`` interface. - - We follow the representation of such matrices using the SciPy terminology, see - https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.toeplitz.html + r"""Class for lower-triangular Toeplitz-structured matrices. + + A lower-triangular Toeplitz matrix is defined by: + + \( + \begin{pmatrix} + d_1 & 0 & \cdots & 0 \\ + d_2 & d_1 & \ddots & \vdots \\ + \vdots & \ddots & \ddots & 0 \\ + d_K & \cdots & d_2 & d_1 \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} + \quad + \text{with} + \quad + \mathbf{d} + := + \begin{pmatrix} + d_1 \\ + d_2 \\ + \vdots \\ + d_K \\ + \end{pmatrix} \in \mathbb{R}^K\,. + \) """ WARN_NAIVE_EXCEPTIONS = { # hard to leverage structure for efficient implementation @@ -30,12 +49,11 @@ class TrilToeplitzMatrix(StructuredMatrix): } def __init__(self, diag_consts: Tensor) -> None: - """Store the lower-triangular Toeplitz matrix internally. + r"""Store the lower-triangular Toeplitz matrix internally. Args: - diag_consts: A vector containing the constants of all diagonals, i.e. - the first entry corresponds to the constant on the diagonal, the - second entry to the constant on the lower first off-diagonal, etc. + diag_consts: The vector \(\mathbf{d}\) containing the constants of all + lower diagonals, starting with the value on the main diagonal. """ self._mat_column = diag_consts @@ -56,10 +74,10 @@ def from_dense(cls, mat: Tensor) -> TrilToeplitzMatrix: Args: mat: A dense and symmetric square matrix which will be approximated by a - ``TrilToeplitzMatrix``. + `TrilToeplitzMatrix`. Returns: - ``TrilToeplitzMatrix`` approximating the passed matrix. + `TrilToeplitzMatrix` approximating the passed matrix. """ assert mat.shape[0] == mat.shape[1] traces = all_traces(mat) @@ -120,12 +138,12 @@ def __matmul__( Args: other: A matrix which will be multiplied onto. Can be represented by a - PyTorch tensor or another ``TrilToeplitzMatrix``. + PyTorch tensor or another `TrilToeplitzMatrix`. Returns: Result of the multiplication. If a PyTorch tensor was passed as argument, the result will be a PyTorch tensor. If a triu Toeplitz matrix was passed, - the result will be returned as a ``TrilToeplitzMatrix``. + the result will be returned as a `TrilToeplitzMatrix`. """ col = self._mat_column dim = col.shape[0] @@ -144,14 +162,14 @@ def __matmul__( return TrilToeplitzMatrix(mat_column) def rmatmat(self, mat: Tensor) -> Tensor: - """Multiply ``mat`` with the transpose of the structured matrix. + """Multiply `mat` with the transpose of the structured matrix. Args: mat: A matrix which will be multiplied by the transpose of the represented diagonal matrix. Returns: - The result of ``self.T @ mat``. + The result of `self.T @ mat`. 
""" col = self._mat_column dim = col.shape[0] diff --git a/singd/structures/triltopleftdiag.py b/singd/structures/triltopleftdiag.py index b1c4b6f..492440e 100644 --- a/singd/structures/triltopleftdiag.py +++ b/singd/structures/triltopleftdiag.py @@ -8,17 +8,22 @@ class TrilTopLeftDiagonalMatrix(RecursiveBottomLeftMatrixTemplate): r"""Sparse lower-triangular matrix with top left diagonal entries. + This matrix is defined as follows: + \( \begin{pmatrix} - \mathbf{D} & \mathbf{0} \\ - r_1 & \mathbf{r}_2 - \end{pmatrix} + \mathbf{A} & \mathbf{0} \\ + \mathbf{b} & c \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} \) where - - \(\mathbf{D}\) is a diagonal matrix, - - \(r_1\) is a scalar, and - - \(\mathbf{r}_2\) is a row vector. + + - \(\mathbf{A} \in \mathbb{R}^{(K-1)\times (K-1)}\) is a diagonal matrix represented + as a `DiagonalMatrix`. + - \(\mathbf{b} \in \mathbb{R}^{K-1}\) is a row vector, represented as PyTorch + `Tensor`, and + - \(c \in \mathbb{R}\) is a scalar, represented by a `DenseMatrix`. """ MAX_DIMS = (float("inf"), 1) diff --git a/singd/structures/triubottomrightdiag.py b/singd/structures/triubottomrightdiag.py index bd58680..cef50a6 100644 --- a/singd/structures/triubottomrightdiag.py +++ b/singd/structures/triubottomrightdiag.py @@ -6,17 +6,24 @@ class TriuBottomRightDiagonalMatrix(RecursiveTopRightMatrixTemplate): - """Sparse upper-triangular matrix with bottom right diagonal entries. + r"""Sparse upper-triangular matrix with bottom right diagonal entries. - `` - [[r1, r2], - [[0, D]] - `` + This matrix is defined as follows: + + \( + \begin{pmatrix} + a & \mathbf{b} \\ + \mathbf{0} & \mathbf{C} \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} + \) where - - ``r1`` is a scalar, - - ``r2`` is a row vector, and - - ``D`` is a diagonal matrix. + + - \(a \in \mathbb{R}\) is a scalar, represented by a `DenseMatrix` + - \(\mathbf{b} \in \mathbb{R}^{K-1}\) is a column vector, represented as PyTorch + `Tensor`, and + - \(\mathbf{C} \in \mathbb{R}^{(K-1)\times (K-1)}\) is a diagonal matrix represented + as a `DiagonalMatrix`. """ MAX_DIMS = (1, float("inf")) diff --git a/singd/structures/triutoeplitz.py b/singd/structures/triutoeplitz.py index a657fe8..6a3d860 100644 --- a/singd/structures/triutoeplitz.py +++ b/singd/structures/triutoeplitz.py @@ -18,10 +18,29 @@ class TriuToeplitzMatrix(StructuredMatrix): - """Upper-triangular Toeplitz-structured matrix. - - We follow the representation of such matrices using the SciPy terminology, see - https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.toeplitz.html + r"""Class for upper-triangular Toeplitz-structured matrices. + + An upper-triangular Toeplitz matrix is defined by: + + \( + \begin{pmatrix} + d_1 & d_2 & \cdots & d_K \\ + 0 & d_1 & \ddots & \vdots \\ + \vdots & \ddots & \ddots & d_2 \\ + 0 & \cdots & 0 & d_1 \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} + \quad + \text{with} + \quad + \mathbf{d} + := + \begin{pmatrix} + d_1 \\ + d_2 \\ + \vdots \\ + d_K \\ + \end{pmatrix} \in \mathbb{R}^K\,. + \) """ WARN_NAIVE_EXCEPTIONS = { # hard to leverage structure for efficient implementation @@ -30,12 +49,11 @@ class TriuToeplitzMatrix(StructuredMatrix): } def __init__(self, diag_consts: Tensor) -> None: - """Store the upper-triangular Toeplitz matrix internally. + r"""Store the upper-triangular Toeplitz matrix internally. Args: - diag_consts: A vector containing the constants of all diagonals, i.e. 
- the first entry corresponds to the constant on the diagonal, the - second entry to the constant on the upper first off-diagonal, etc. + diag_consts: A vector \(\mathbf{d}\) containing the constants of all + upper diagonals, starting with the main diagonal. """ self._mat_row = diag_consts diff --git a/singd/structures/triutopleftdiag.py b/singd/structures/triutopleftdiag.py index 2967603..78bfc93 100644 --- a/singd/structures/triutopleftdiag.py +++ b/singd/structures/triutopleftdiag.py @@ -6,17 +6,24 @@ class TriuTopLeftDiagonalMatrix(RecursiveTopRightMatrixTemplate): - """Sparse upper-triangular matrix with top left diagonal entries. + r"""Sparse upper-triangular matrix with top left diagonal entries. - `` - [[D, c1], - [[0, c2]] - `` + This matrix is defined as follows: + + \( + \begin{pmatrix} + \mathbf{A} & \mathbf{b} \\ + \mathbf{0} & c \\ + \end{pmatrix} \in \mathbb{R}^{K \times K} + \) where - - ``D`` is a diagonal matrix, - - ``c1`` is a row vector, and - - ``c2`` is a scalar. + + - \(\mathbf{A} \in \mathbb{R}^{(K-1)\times (K-1)}\) is a diagonal matrix represented + as a `DiagonalMatrix`. + - \(\mathbf{b} \in \mathbb{R}^{K-1}\) is a column vector, represented as PyTorch + `Tensor`, and + - \(c \in \mathbb{R}\) is a scalar, represented by a `DenseMatrix`. """ MAX_DIMS = (float("inf"), 1) diff --git a/singd/structures/utils.py b/singd/structures/utils.py index 6cf109c..cb18600 100644 --- a/singd/structures/utils.py +++ b/singd/structures/utils.py @@ -253,7 +253,7 @@ def diag_add_(mat: Tensor, value: float) -> Tensor: Returns: The input matrix with the value added to the main diagonal. """ - if mat.dim() != 2 or mat.shape[0] != mat.shape[1]: + if mat.ndim != 2 or mat.shape[0] != mat.shape[1]: raise ValueError(f"Expected square matrix, but got {mat.shape}.") dim = mat.shape[0]
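For reference, the in-place diagonal update performed by `diag_add_` can be sketched with plain PyTorch as follows (a simplified stand-in with the same validation, not necessarily the library's exact implementation):

```python
import torch
from torch import Tensor


def diag_add_sketch_(mat: Tensor, value: float) -> Tensor:
    """Add `value` to the main diagonal of a square matrix, in place."""
    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        raise ValueError(f"Expected square matrix, but got {mat.shape}.")
    # `diagonal()` returns a view, so the in-place add modifies `mat` itself.
    mat.diagonal().add_(value)
    return mat


mat = torch.zeros(3, 3)
diag_add_sketch_(mat, 2.0)  # mat is now twice the identity
```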