Skip to content

Commit

Permalink
Allow custom dtype for some VSA similarity functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes committed Sep 26, 2023
1 parent 83fdf34 commit 91bf63f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 17 deletions.
8 changes: 4 additions & 4 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def hard_quantize(input: Tensor):
return torch.where(input > 0, positive, negative)


def dot_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
def dot_similarity(input: VSATensor, others: VSATensor, **kwargs) -> VSATensor:
"""Dot product between the input vector and each vector in others.
Aliased as ``torchhd.dot``.
Expand Down Expand Up @@ -938,13 +938,13 @@ def dot_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
"""
input = ensure_vsa_tensor(input)
others = ensure_vsa_tensor(others)
return input.dot_similarity(others)
return input.dot_similarity(others, **kwargs)


dot = dot_similarity


def cosine_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
def cosine_similarity(input: VSATensor, others: VSATensor, **kwargs) -> VSATensor:
"""Cosine similarity between the input vector and each vector in others.
Aliased as ``torchhd.cos``.
Expand Down Expand Up @@ -987,7 +987,7 @@ def cosine_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
"""
input = ensure_vsa_tensor(input)
others = ensure_vsa_tensor(others)
return input.cosine_similarity(others)
return input.cosine_similarity(others, **kwargs)


cos = cosine_similarity
Expand Down
9 changes: 5 additions & 4 deletions torchhd/tensors/bsbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,20 +334,21 @@ def permute(self, shifts: int = 1) -> "BSBCTensor":
"""
return torch.roll(self, shifts=shifts, dims=-1)

def dot_similarity(self, others: "BSBCTensor") -> Tensor:
def dot_similarity(self, others: "BSBCTensor", *, dtype=None) -> Tensor:
"""Inner product with other hypervectors"""
dtype = torch.get_default_dtype()
if dtype is None:
dtype = torch.get_default_dtype()

if self.dim() > 1 and others.dim() > 1:
equals = self.unsqueeze(-2) == others.unsqueeze(-3)
return torch.sum(equals, dim=-1, dtype=dtype)

return torch.sum(self == others, dim=-1, dtype=dtype)

def cosine_similarity(self, others: "BSBCTensor") -> Tensor:
def cosine_similarity(self, others: "BSBCTensor", *, dtype=None) -> Tensor:
"""Cosine similarity with other hypervectors"""
magnitude = self.size(-1)
return self.dot_similarity(others) / magnitude
return self.dot_similarity(others, dtype=dtype) / magnitude

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
Expand Down
9 changes: 5 additions & 4 deletions torchhd/tensors/bsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,11 @@ def permute(self, shifts: int = 1) -> "BSCTensor":
"""
return super().roll(shifts=shifts, dims=-1)

def dot_similarity(self, others: "BSCTensor") -> Tensor:
def dot_similarity(self, others: "BSCTensor", *, dtype=None) -> Tensor:
"""Inner product with other hypervectors."""
dtype = torch.get_default_dtype()
device = self.device
if dtype is None:
dtype = torch.get_default_dtype()

min_one = torch.tensor(-1.0, dtype=dtype, device=device)
plus_one = torch.tensor(1.0, dtype=dtype, device=device)
Expand All @@ -441,7 +442,7 @@ def dot_similarity(self, others: "BSCTensor") -> Tensor:
others_as_bipolar = others_as_bipolar.transpose(-2, -1)
return torch.matmul(self_as_bipolar, others_as_bipolar)

def cosine_similarity(self, others: "BSCTensor") -> Tensor:
def cosine_similarity(self, others: "BSCTensor", *, dtype=None) -> Tensor:
"""Cosine similarity with other hypervectors."""
d = self.size(-1)
return self.dot_similarity(others) / d
return self.dot_similarity(others, dtype=dtype) / d
1 change: 1 addition & 0 deletions torchhd/tensors/fhrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def dot_similarity(self, others: "FHRRTensor") -> Tensor:
"""Inner product with other hypervectors"""
if others.dim() >= 2:
others = others.transpose(-2, -1)

return torch.real(torch.matmul(self, torch.conj(others)))

def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:
Expand Down
1 change: 1 addition & 0 deletions torchhd/tensors/hrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def dot_similarity(self, others: "HRRTensor") -> Tensor:
"""Inner product with other hypervectors"""
if others.dim() >= 2:
others = others.transpose(-2, -1)

return torch.matmul(self, others)

def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:
Expand Down
14 changes: 9 additions & 5 deletions torchhd/tensors/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,20 @@ def clipping(self, kappa) -> "MAPTensor":

return torch.clamp(self, min=-kappa, max=kappa)

def dot_similarity(self, others: "MAPTensor") -> Tensor:
def dot_similarity(self, others: "MAPTensor", *, dtype=None) -> Tensor:
"""Inner product with other hypervectors"""
dtype = torch.get_default_dtype()
if dtype is None:
dtype = torch.get_default_dtype()

if others.dim() >= 2:
others = others.transpose(-2, -1)

return torch.matmul(self.to(dtype), others.to(dtype))

def cosine_similarity(self, others: "MAPTensor", *, eps=1e-08) -> Tensor:
def cosine_similarity(self, others: "MAPTensor", *, dtype=None, eps=1e-08) -> Tensor:
"""Cosine similarity with other hypervectors"""
dtype = torch.get_default_dtype()
if dtype is None:
dtype = torch.get_default_dtype()

self_dot = torch.sum(self * self, dim=-1, dtype=dtype)
self_mag = torch.sqrt(self_dot)
Expand All @@ -363,4 +367,4 @@ def cosine_similarity(self, others: "MAPTensor", *, eps=1e-08) -> Tensor:
magnitude = self_mag * others_mag

magnitude = torch.clamp(magnitude, min=eps)
return self.dot_similarity(others) / magnitude
return self.dot_similarity(others, dtype=dtype) / magnitude
32 changes: 32 additions & 0 deletions torchhd/tests/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,38 @@ def test_dtype(self, vsa, dtype):
else:
assert similarity.dtype == torch.get_default_dtype()

def test_custom_dtype(self):
hv = functional.random(3, 100, "BSBC", block_size=1024)
similarity = functional.dot_similarity(hv, hv)
assert similarity.dtype == torch.get_default_dtype()

similarity = functional.dot_similarity(hv, hv, dtype=torch.float64)
assert similarity.dtype == torch.float64

similarity = functional.dot_similarity(hv, hv, dtype=torch.int16)
assert similarity.dtype == torch.int16

hv = functional.random(3, 100, "MAP")
similarity = functional.dot_similarity(hv, hv)
assert similarity.dtype == torch.get_default_dtype()

similarity = functional.dot_similarity(hv, hv, dtype=torch.float64)
assert similarity.dtype == torch.float64

similarity = functional.dot_similarity(hv, hv, dtype=torch.int16)
assert similarity.dtype == torch.int16

hv = functional.random(3, 100, "BSC")
similarity = functional.dot_similarity(hv, hv)
assert similarity.dtype == torch.get_default_dtype()

similarity = functional.dot_similarity(hv, hv, dtype=torch.float64)
assert similarity.dtype == torch.float64

similarity = functional.dot_similarity(hv, hv, dtype=torch.int16)
assert similarity.dtype == torch.int16


@pytest.mark.parametrize("vsa", vsa_tensors)
@pytest.mark.parametrize("dtype", torch_dtypes)
def test_device(self, vsa, dtype):
Expand Down

0 comments on commit 91bf63f

Please sign in to comment.