Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes committed Mar 6, 2024
1 parent 5a75491 commit f0b9232
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 22 deletions.
22 changes: 11 additions & 11 deletions torchhd/tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(idx).dtype == torch.bool
elif vsa == "MAP" or vsa == "HRR":
assert emb(idx).dtype == torch.float
assert emb(idx).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert (
emb(idx).dtype == torch.complex64 or emb(idx).dtype == torch.complex32
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(idx).dtype == torch.bool
elif vsa in {"MAP", "HRR", "VTB"}:
assert emb(idx).dtype == torch.float
assert emb(idx).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert emb(idx).dtype in {torch.complex64, torch.complex32}

Expand Down Expand Up @@ -244,7 +244,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(idx).dtype == torch.bool
elif vsa in {"MAP", "HRR", "VTB"}:
assert emb(idx).dtype == torch.float
assert emb(idx).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert emb(idx).dtype in {torch.complex64, torch.complex32}

Expand Down Expand Up @@ -295,7 +295,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(idx).dtype == torch.bool
elif vsa in {"MAP", "HRR", "VTB"}:
assert emb(idx).dtype == torch.float
assert emb(idx).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert emb(idx).dtype in {torch.complex64, torch.complex32}

Expand Down Expand Up @@ -365,7 +365,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(angle).dtype == torch.bool
elif vsa == "MAP":
assert emb(angle).dtype == torch.float
assert emb(angle).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert (
emb(angle).dtype == torch.complex64
Expand Down Expand Up @@ -441,7 +441,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(angle).dtype == torch.bool
elif vsa == "MAP":
assert emb(angle).dtype == torch.float
assert emb(angle).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert (
emb(angle).dtype == torch.complex64
Expand Down Expand Up @@ -504,7 +504,7 @@ def test_dtype(self, vsa):
emb = embeddings.Projection(in_features, out_features, vsa=vsa)
x = torch.randn(1, in_features)
if vsa == "MAP" or vsa == "HRR":
assert emb(x).dtype == torch.float
assert emb(x).dtype == torch.get_default_dtype()
else:
return

Expand Down Expand Up @@ -549,7 +549,7 @@ def test_dtype(self, vsa):
emb = embeddings.Sinusoid(in_features, out_features, vsa=vsa)
x = torch.randn(1, in_features)
if vsa == "MAP" or vsa == "HRR":
assert emb(x).dtype == torch.float
assert emb(x).dtype == torch.get_default_dtype()
else:
return

Expand Down Expand Up @@ -611,7 +611,7 @@ def test_dtype(self, vsa):
if vsa == "BSC":
assert emb(x).dtype == torch.bool
elif vsa == "MAP":
assert emb(x).dtype == torch.float
assert emb(x).dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert emb(x).dtype == torch.complex64 or emb(x).dtype == torch.complex32
else:
Expand Down Expand Up @@ -664,9 +664,9 @@ def test_default_dtype(self, vsa):
assert y.shape == (2, dimensions)

if vsa == "HRR":
assert y.dtype == torch.float32
assert y.dtype == torch.get_default_dtype()
elif vsa == "FHRR":
assert y.dtype == torch.complex64
assert fhrr_type_conversion[y.dtype] == torch.get_default_dtype()
else:
return

Expand Down
8 changes: 0 additions & 8 deletions torchhd/tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,6 @@ def test_dtype(self, dtype):
hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor)

if dtype in {torch.float16}:
# torch.product is not implemented on CPU for these dtypes
with pytest.raises(RuntimeError):
functional.multibind(hv)

return

res = functional.multibind(hv)
Expand Down Expand Up @@ -288,10 +284,6 @@ def test_dtype(self, dtype):
hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor)

if dtype in {torch.float16}:
# torch.product is not implemented on CPU for these dtypes
with pytest.raises(RuntimeError):
functional.multibind(hv)

return

res = functional.bind_sequence(hv)
Expand Down
6 changes: 3 additions & 3 deletions torchhd/tests/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_value(self, vsa, dtype):
).as_subclass(BSCTensor)

res = functional.dot_similarity(hv, hv)
exp = torch.tensor([[10, 4], [4, 10]], dtype=torch.long)
exp = torch.tensor([[10, 4], [4, 10]], dtype=res.dtype)
assert torch.all(res == exp).item()

elif vsa == "FHRR":
Expand Down Expand Up @@ -339,7 +339,7 @@ def test_value(self, vsa, dtype):
).as_subclass(BSCTensor)

res = functional.cosine_similarity(hv, hv)
exp = torch.tensor([[1, 0.4], [0.4, 1]], dtype=torch.float)
exp = torch.tensor([[1, 0.4], [0.4, 1]], dtype=res.dtype)
assert torch.allclose(res, exp)

elif vsa == "FHRR":
Expand Down Expand Up @@ -529,7 +529,7 @@ def test_value(self, vsa, dtype):
).as_subclass(BSCTensor)

res = functional.hamming_similarity(hv, hv)
exp = torch.tensor([[10, 7], [7, 10]], dtype=torch.long)
exp = torch.tensor([[10, 7], [7, 10]], dtype=res.dtype)
assert torch.all(res == exp).item()

elif vsa == "FHRR":
Expand Down

0 comments on commit f0b9232

Please sign in to comment.