From 8124b446448171c77d4b810ab51cc827fe4c2fd5 Mon Sep 17 00:00:00 2001 From: tqtg Date: Thu, 26 Oct 2023 06:19:32 +0000 Subject: [PATCH] hell yeah last one --- cornac/models/vbpr/recom_vbpr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cornac/models/vbpr/recom_vbpr.py b/cornac/models/vbpr/recom_vbpr.py index 19c38d405..f8c8ea358 100644 --- a/cornac/models/vbpr/recom_vbpr.py +++ b/cornac/models/vbpr/recom_vbpr.py @@ -154,7 +154,7 @@ def fit(self, train_set, val_set=None): raise CornacException("item_image modality is required but None.") # Item visual feature from CNN - train_features = train_set.item_image.features[: self.train_set.total_items] + train_features = train_set.item_image.features[: train_set.total_items] train_features = train_features.astype(np.float32) self._init( n_users=train_set.total_users, @@ -163,11 +163,11 @@ def fit(self, train_set, val_set=None): ) if self.trainable: - self._fit_torch(train_features) + self._fit_torch(train_set, train_features) return self - def _fit_torch(self, train_features): + def _fit_torch(self, train_set, train_features): import torch def _l2_loss(*tensors): @@ -213,11 +213,11 @@ def _inner(a, b): sum_loss = 0.0 count = 0 progress_bar = tqdm( - total=self.train_set.num_batches(self.batch_size), + total=train_set.num_batches(self.batch_size), desc="Epoch {}/{}".format(epoch, self.n_epochs), disable=not self.verbose, ) - for batch_u, batch_i, batch_j in self.train_set.uij_iter( + for batch_u, batch_i, batch_j in train_set.uij_iter( self.batch_size, shuffle=True ): gamma_u = Gu[batch_u]