Skip to content

Commit

Permalink
[ADD] Frobenius norm in structured matrix interface
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed May 30, 2024
1 parent dcf1cbc commit ec80317
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
10 changes: 10 additions & 0 deletions singd/structures/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.distributed as dist
from torch import Tensor, zeros
from torch.linalg import matrix_norm

from singd.structures.utils import diag_add_, supported_eye, supported_matmul

Expand Down Expand Up @@ -343,6 +344,15 @@ def infinity_vector_norm(self) -> Tensor:
# NOTE `.max` can only be called on tensors with non-zero shape
return max(t.abs().max() for _, t in self.named_tensors() if t.numel() > 0)

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.
Returns:
The Frobenius norm of the represented matrix.
"""
self._warn_naive_implementation("trace")
return matrix_norm(self.to_dense())

###############################################################################
# Special initialization operations #
###############################################################################
Expand Down
18 changes: 17 additions & 1 deletion test/structures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from matplotlib import pyplot as plt
from pytest import mark
from torch import Tensor, device, manual_seed, rand, zeros
from torch.linalg import vector_norm
from torch.linalg import matrix_norm, vector_norm

from singd.structures.base import StructuredMatrix
from singd.structures.utils import is_half_precision, supported_eye, supported_matmul
Expand Down Expand Up @@ -582,6 +582,22 @@ def test_infinity_vector_norm(self, dev: device, dtype: torch.dtype):
structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
report_nonclose(truth, structured.infinity_vector_norm())

@mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
def test_frobenius_norm(self, dev: device, dtype: torch.dtype):
"""Test Frobenius norm of a structured matrix.
Args:
dev: The device on which to run the test.
dtype: The data type of the matrices.
"""
for dim in self.DIMS:
manual_seed(0)
sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype))
truth = matrix_norm(self.project(sym_mat))
structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
report_nonclose(truth, structured.frobenius_norm())

@mark.expensive
def test_visual(self):
"""Create pictures and animations of the structure.
Expand Down

0 comments on commit ec80317

Please sign in to comment.