From 32e7de285e9c0270284c0c741217a566491d6f0e Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Fri, 30 Aug 2024 15:10:45 -0700 Subject: [PATCH] Add third python version --- .github/workflows/test.yml | 2 +- torchhd/memory.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 66c230a5..67b1ba77 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: timeout-minutes: 20 strategy: matrix: - python-version: ['3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12'] os: [ubuntu-latest, windows-latest, macos-latest] steps: diff --git a/torchhd/memory.py b/torchhd/memory.py index 544c3d5d..9d7e7fe0 100644 --- a/torchhd/memory.py +++ b/torchhd/memory.py @@ -121,13 +121,12 @@ def read(self, query: Tensor) -> VSATensor: """ # first dims from query, last dim from value - out_shape = (*query.shape[:-1], self.value_dim) + out_shape = tuple(query.shape[:-1]) + (self.value_dim,) if query.dim() == 1: query = query.unsqueeze(0) - # make sure to have at least two dimension for index_add_ - intermediate_shape = (*query.shape[:-1], self.value_dim) + intermediate_shape = tuple(query.shape[:-1]) + (self.value_dim,) similarity = query @ self.keys.T is_active = similarity >= self.threshold @@ -135,7 +134,7 @@ def read(self, query: Tensor) -> VSATensor: # sparse matrix-vector multiplication r_indices, v_indices = is_active.nonzero().T read = query.new_zeros(intermediate_shape) - read.index_add_(0, r_indices, self.values[v_indices]) + read = read.index_add(0, r_indices, self.values[v_indices]) return read.view(out_shape) @torch.no_grad()