Skip to content

Commit

Permalink
Fix numpy arrays warning
Browse files Browse the repository at this point in the history
  • Loading branch information
milad2073 committed Nov 6, 2024
1 parent 0d9878c commit 009a50c
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/learning_with_hrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchhd import embeddings, HRRTensor
import torchhd.tensors
from scipy.sparse import vstack, lil_matrix
import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -36,7 +37,7 @@ def sparse_batch_collate(batch:list):
data_batch, targets_batch = zip(*batch)

data_batch = vstack(data_batch).tocoo()
data_batch = torch.sparse_coo_tensor(data_batch.nonzero(), data_batch.data, data_batch.shape)
data_batch = torch.sparse_coo_tensor(np.array(data_batch.nonzero()), data_batch.data, data_batch.shape)

targets_batch = torch.stack(targets_batch)

Expand Down Expand Up @@ -67,7 +68,7 @@ def __getitem__(self, idx):
if DATASET_NAME == "Wiki10-31K": # Because of this issue https://github.com/mwydmuch/napkinXC/issues/18
X_train = lil_matrix(X_train[:,:-1])

N_freatures = X_train.shape[1]
N_features = X_train.shape[1]
N_classes = max(max(classes) for classes in Y_train if classes != []) + 1

train_dataset = multilabel_dataset(X_train,Y_train,N_classes)
Expand All @@ -77,7 +78,7 @@ def __getitem__(self, idx):


print("Traning on \033[1m {} \033[0m. It has {} features, and {} classes."
.format(DATASET_NAME,N_freatures,N_classes))
.format(DATASET_NAME,N_features,N_classes))


# Fully Connected model for the baseline comparision
Expand Down Expand Up @@ -168,10 +169,10 @@ def loss(self,out,target):



hrr_model = FCHRR(N_freatures,N_classes,DIMENSIONS)
hrr_model = FCHRR(N_features,N_classes,DIMENSIONS)
hrr_model = hrr_model.to(device)

baseline_model = FC(N_freatures,N_classes)
baseline_model = FC(N_features,N_classes)
baseline_model = baseline_model.to(device)


Expand Down

0 comments on commit 009a50c

Please sign in to comment.