Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes committed Mar 6, 2024
1 parent 4f7135d commit 0ac1add
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
1 change: 0 additions & 1 deletion torchhd/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
self.weight.index_add_(0, target, alpha1 * input, alpha=lr)
self.weight.index_add_(0, pred, alpha2 * input, alpha=lr)

@torch.no_grad()
def normalize(self, eps=1e-12) -> None:
"""Transforms all the class prototype vectors into unit vectors.
Expand Down
10 changes: 10 additions & 0 deletions torchhd/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def test_add_online(self):
logits = model(samples)
assert logits.shape == (10, 3)

def test_add_adapt(self):
samples = torch.randn(10, 12)
targets = torch.randint(0, 3, (10,))

model = models.Centroid(12, 3)
model.add_adapt(samples, targets)

logits = model(samples)
assert logits.shape == (10, 3)


class TestIntRVFL:
@pytest.mark.parametrize("dtype", torch_dtypes)
Expand Down

0 comments on commit 0ac1add

Please sign in to comment.