From f0b9232b5b10b26cc455d7e72f38e57fbac8a2fd Mon Sep 17 00:00:00 2001
From: Mike Heddes
Date: Tue, 5 Mar 2024 21:18:51 -0800
Subject: [PATCH] Update tests

---
 torchhd/tests/test_embeddings.py   | 22 +++++++++++-----------
 torchhd/tests/test_encodings.py    |  8 --------
 torchhd/tests/test_similarities.py |  6 +++---
 3 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/torchhd/tests/test_embeddings.py b/torchhd/tests/test_embeddings.py
index a9abda34..17b6362f 100644
--- a/torchhd/tests/test_embeddings.py
+++ b/torchhd/tests/test_embeddings.py
@@ -74,7 +74,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(idx).dtype == torch.bool
         elif vsa == "MAP" or vsa == "HRR":
-            assert emb(idx).dtype == torch.float
+            assert emb(idx).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert (
                 emb(idx).dtype == torch.complex64 or emb(idx).dtype == torch.complex32
@@ -142,7 +142,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(idx).dtype == torch.bool
         elif vsa in {"MAP", "HRR", "VTB"}:
-            assert emb(idx).dtype == torch.float
+            assert emb(idx).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert emb(idx).dtype in {torch.complex64, torch.complex32}

@@ -244,7 +244,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(idx).dtype == torch.bool
         elif vsa in {"MAP", "HRR", "VTB"}:
-            assert emb(idx).dtype == torch.float
+            assert emb(idx).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert emb(idx).dtype in {torch.complex64, torch.complex32}

@@ -295,7 +295,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(idx).dtype == torch.bool
         elif vsa in {"MAP", "HRR", "VTB"}:
-            assert emb(idx).dtype == torch.float
+            assert emb(idx).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert emb(idx).dtype in {torch.complex64, torch.complex32}

@@ -365,7 +365,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(angle).dtype == torch.bool
         elif vsa == "MAP":
-            assert emb(angle).dtype == torch.float
+            assert emb(angle).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert (
                 emb(angle).dtype == torch.complex64
@@ -441,7 +441,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(angle).dtype == torch.bool
         elif vsa == "MAP":
-            assert emb(angle).dtype == torch.float
+            assert emb(angle).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert (
                 emb(angle).dtype == torch.complex64
@@ -504,7 +504,7 @@ def test_dtype(self, vsa):
         emb = embeddings.Projection(in_features, out_features, vsa=vsa)
         x = torch.randn(1, in_features)
         if vsa == "MAP" or vsa == "HRR":
-            assert emb(x).dtype == torch.float
+            assert emb(x).dtype == torch.get_default_dtype()
         else:
             return

@@ -549,7 +549,7 @@ def test_dtype(self, vsa):
         emb = embeddings.Sinusoid(in_features, out_features, vsa=vsa)
         x = torch.randn(1, in_features)
         if vsa == "MAP" or vsa == "HRR":
-            assert emb(x).dtype == torch.float
+            assert emb(x).dtype == torch.get_default_dtype()
         else:
             return

@@ -611,7 +611,7 @@ def test_dtype(self, vsa):
         if vsa == "BSC":
             assert emb(x).dtype == torch.bool
         elif vsa == "MAP":
-            assert emb(x).dtype == torch.float
+            assert emb(x).dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
             assert emb(x).dtype == torch.complex64 or emb(x).dtype == torch.complex32
         else:
@@ -664,9 +664,9 @@ def test_default_dtype(self, vsa):
         assert y.shape == (2, dimensions)

         if vsa == "HRR":
-            assert y.dtype == torch.float32
+            assert y.dtype == torch.get_default_dtype()
         elif vsa == "FHRR":
-            assert y.dtype == torch.complex64
+            assert fhrr_type_conversion[y.dtype] == torch.get_default_dtype()
         else:
             return

diff --git a/torchhd/tests/test_encodings.py b/torchhd/tests/test_encodings.py
index 927993b3..af205bb9 100644
--- a/torchhd/tests/test_encodings.py
+++ b/torchhd/tests/test_encodings.py
@@ -141,10 +141,6 @@ def test_dtype(self, dtype):
         hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor)

         if dtype in {torch.float16}:
-            # torch.product is not implemented on CPU for these dtypes
-            with pytest.raises(RuntimeError):
-                functional.multibind(hv)
-
             return

         res = functional.multibind(hv)
@@ -288,10 +284,6 @@ def test_dtype(self, dtype):
         hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor)

         if dtype in {torch.float16}:
-            # torch.product is not implemented on CPU for these dtypes
-            with pytest.raises(RuntimeError):
-                functional.multibind(hv)
-
             return

         res = functional.bind_sequence(hv)

diff --git a/torchhd/tests/test_similarities.py b/torchhd/tests/test_similarities.py
index eb104885..d33c0b48 100644
--- a/torchhd/tests/test_similarities.py
+++ b/torchhd/tests/test_similarities.py
@@ -118,7 +118,7 @@ def test_value(self, vsa, dtype):
             ).as_subclass(BSCTensor)

             res = functional.dot_similarity(hv, hv)
-            exp = torch.tensor([[10, 4], [4, 10]], dtype=torch.long)
+            exp = torch.tensor([[10, 4], [4, 10]], dtype=res.dtype)
             assert torch.all(res == exp).item()

         elif vsa == "FHRR":
@@ -339,7 +339,7 @@ def test_value(self, vsa, dtype):
             ).as_subclass(BSCTensor)

             res = functional.cosine_similarity(hv, hv)
-            exp = torch.tensor([[1, 0.4], [0.4, 1]], dtype=torch.float)
+            exp = torch.tensor([[1, 0.4], [0.4, 1]], dtype=res.dtype)
             assert torch.allclose(res, exp)

         elif vsa == "FHRR":
@@ -529,7 +529,7 @@ def test_value(self, vsa, dtype):
             ).as_subclass(BSCTensor)

             res = functional.hamming_similarity(hv, hv)
-            exp = torch.tensor([[10, 7], [7, 10]], dtype=torch.long)
+            exp = torch.tensor([[10, 7], [7, 10]], dtype=res.dtype)
             assert torch.all(res == exp).item()

         elif vsa == "FHRR":
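
Note (reviewer sketch, not part of the patch): the test_embeddings.py hunks
replace the hard-coded torch.float assertions with torch.get_default_dtype()
because the real-valued VSA models (MAP, HRR, VTB) are expected to follow
PyTorch's global default dtype. A minimal illustration, assuming the
accompanying library change makes embeddings.Random honor the default dtype
as these updated tests require:

    import torch
    from torchhd import embeddings

    torch.set_default_dtype(torch.float64)
    emb = embeddings.Random(10, 1000, vsa="MAP")
    idx = torch.arange(10)
    # Under a float64 default, the output is float64, not float32,
    # so the old `== torch.float` assertion would fail here.
    assert emb(idx).dtype == torch.get_default_dtype()

The fhrr_type_conversion lookup in the last test_embeddings.py hunk is assumed
to map each complex dtype to its real counterpart (e.g. torch.complex64 ->
torch.float32, torch.complex128 -> torch.float64), reducing the FHRR check to
the same default-dtype comparison.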
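
Likewise for test_similarities.py (reviewer sketch, not part of the patch):
building the expected tensor with dtype=res.dtype keeps the value checks valid
whatever dtype the similarity kernels return under a non-default global dtype.
A small self-contained example of the same pattern, using inferred BSC dot
semantics (agreements minus disagreements, consistent with the values in the
patched tests):

    import torch
    from torchhd import functional, BSCTensor

    hv = torch.tensor(
        [[True, False, True], [True, True, True]]
    ).as_subclass(BSCTensor)
    res = functional.dot_similarity(hv, hv)
    # Derive the expected dtype from the result instead of assuming torch.long:
    # self-similarity is 3 (all dims agree), cross is 2 - 1 = 1.
    exp = torch.tensor([[3, 1], [1, 3]], dtype=res.dtype)
    assert torch.all(res == exp).item()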