Skip to content

Commit

Permalink
try without tensor subclass
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes committed Aug 31, 2024
1 parent 8a0ff41 commit b602f1c
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions torchhd/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,33 @@


class TestSparseDistributed:
# def test_shape(self):
# mem = memory.SparseDistributed(1000, 67, 123)
def test_shape(self):
mem = memory.SparseDistributed(1000, 67, 123)

# keys = torchhd.random(1, 67).squeeze(0)
# values = torchhd.random(1, 123).squeeze(0)
keys = torch.randn(1, 67).squeeze(0).sign()
values = torch.randn(1, 123).squeeze(0).sign()

# mem.write(keys, values)
mem.write(keys, values)

# read = mem.read(keys).sign()
read = mem.read(keys).sign()

# assert read.shape == values.shape
assert read.shape == values.shape

# if torch.allclose(read, values):
# pass
# elif torch.allclose(read, torch.zeros_like(values)):
# pass
# else:
# assert False, "must be either the value or zero"
if torch.allclose(read, values):
pass
elif torch.allclose(read, torch.zeros_like(values)):
pass
else:
assert False, "must be either the value or zero"

def test_device(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mem = memory.SparseDistributed(1000, 35, 74, kappa=3)
mem = mem.to(device)

keys = torchhd.random(5, 35, device=device)
values = torchhd.random(5, 74, device=device)
keys = torch.randn(5, 35, device=device).sign()
values = torch.randn(5, 74, device=device).sign()

mem.write(keys, values)

Expand Down

0 comments on commit b602f1c

Please sign in to comment.