Skip to content

Commit

Permalink
Add third python version
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes committed Aug 30, 2024
1 parent cd7b40a commit 32e7de2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
timeout-minutes: 20
strategy:
matrix:
python-version: ['3.11', '3.12']
python-version: ['3.10', '3.11', '3.12']
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
Expand Down
7 changes: 3 additions & 4 deletions torchhd/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,20 @@ def read(self, query: Tensor) -> VSATensor:
"""
# first dims from query, last dim from value
out_shape = (*query.shape[:-1], self.value_dim)
out_shape = tuple(query.shape[:-1]) + (self.value_dim,)

if query.dim() == 1:
query = query.unsqueeze(0)

# make sure to have at least two dimension for index_add_
intermediate_shape = (*query.shape[:-1], self.value_dim)
intermediate_shape = tuple(query.shape[:-1]) + (self.value_dim,)

similarity = query @ self.keys.T
is_active = similarity >= self.threshold

# sparse matrix-vector multiplication
r_indices, v_indices = is_active.nonzero().T
read = query.new_zeros(intermediate_shape)
read.index_add_(0, r_indices, self.values[v_indices])
read = read.index_add(0, r_indices, self.values[v_indices])
return read.view(out_shape)

@torch.no_grad()
Expand Down

0 comments on commit 32e7de2

Please sign in to comment.