From 0ac1add477c5c1a94bc111cc1f7650f8b8890651 Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 20:14:29 -0800 Subject: [PATCH] Add test --- torchhd/models.py | 1 - torchhd/tests/test_models.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torchhd/models.py b/torchhd/models.py index d806f657..af6d7b2b 100644 --- a/torchhd/models.py +++ b/torchhd/models.py @@ -161,7 +161,6 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: self.weight.index_add_(0, target, alpha1 * input, alpha=lr) self.weight.index_add_(0, pred, alpha2 * input, alpha=lr) - @torch.no_grad() def normalize(self, eps=1e-12) -> None: """Transforms all the class prototype vectors into unit vectors. diff --git a/torchhd/tests/test_models.py b/torchhd/tests/test_models.py index e4226f50..93a9eca7 100644 --- a/torchhd/tests/test_models.py +++ b/torchhd/tests/test_models.py @@ -82,6 +82,16 @@ def test_add_online(self): logits = model(samples) assert logits.shape == (10, 3) + def test_add_adapt(self): + samples = torch.randn(10, 12) + targets = torch.randint(0, 3, (10,)) + + model = models.Centroid(12, 3) + model.add_adapt(samples, targets) + + logits = model(samples) + assert logits.shape == (10, 3) + class TestIntRVFL: @pytest.mark.parametrize("dtype", torch_dtypes)