diff --git a/singd/structures/base.py b/singd/structures/base.py index 9ff6338..8d04bad 100644 --- a/singd/structures/base.py +++ b/singd/structures/base.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist from torch import Tensor, zeros +from torch.linalg import matrix_norm from singd.structures.utils import diag_add_, supported_eye, supported_matmul @@ -343,6 +344,15 @@ def infinity_vector_norm(self) -> Tensor: # NOTE `.max` can only be called on tensors with non-zero shape return max(t.abs().max() for _, t in self.named_tensors() if t.numel() > 0) + def frobenius_norm(self) -> Tensor: + """Compute the Frobenius norm of the represented matrix. + + Returns: + The Frobenius norm of the represented matrix. + """ + self._warn_naive_implementation("trace") + return matrix_norm(self.to_dense()) + ############################################################################### # Special initialization operations # ############################################################################### diff --git a/test/structures/utils.py b/test/structures/utils.py index f49d6ee..17e0cf6 100644 --- a/test/structures/utils.py +++ b/test/structures/utils.py @@ -11,7 +11,7 @@ from matplotlib import pyplot as plt from pytest import mark from torch import Tensor, device, manual_seed, rand, zeros -from torch.linalg import vector_norm +from torch.linalg import matrix_norm, vector_norm from singd.structures.base import StructuredMatrix from singd.structures.utils import is_half_precision, supported_eye, supported_matmul @@ -582,6 +582,22 @@ def test_infinity_vector_norm(self, dev: device, dtype: torch.dtype): structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat) report_nonclose(truth, structured.infinity_vector_norm()) + @mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS) + @mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) + def test_frobenius_norm(self, dev: device, dtype: torch.dtype): + """Test Frobenius norm of a structured matrix. + + Args: + dev: The device on which to run the test. + dtype: The data type of the matrices. + """ + for dim in self.DIMS: + manual_seed(0) + sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype)) + truth = matrix_norm(self.project(sym_mat)) + structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat) + report_nonclose(truth, structured.frobenius_norm()) + @mark.expensive def test_visual(self): """Create pictures and animations of the structure.