Skip to content

Commit

Permalink
Ensure hamming similarity inputs are vsa tensors (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes authored Sep 26, 2023
1 parent 83fdf34 commit 2de23fd
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,9 @@ def hamming_similarity(input: VSATensor, others: VSATensor) -> LongTensor:
[5, 3, 6]])
"""
input = ensure_vsa_tensor(input)
others = ensure_vsa_tensor(others)

if input.dim() > 1 and others.dim() > 1:
equals = input.unsqueeze(-2) == others.unsqueeze(-3)
return torch.sum(equals, dim=-1, dtype=torch.long)
Expand Down

0 comments on commit 2de23fd

Please sign in to comment.