Skip to content

Commit

Permalink
hell yeah last one
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Oct 26, 2023
1 parent d0d7c98 commit 8124b44
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions cornac/models/vbpr/recom_vbpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def fit(self, train_set, val_set=None):
raise CornacException("item_image modality is required but None.")

# Item visual feature from CNN
train_features = train_set.item_image.features[: self.train_set.total_items]
train_features = train_set.item_image.features[: train_set.total_items]
train_features = train_features.astype(np.float32)
self._init(
n_users=train_set.total_users,
Expand All @@ -163,11 +163,11 @@ def fit(self, train_set, val_set=None):
)

if self.trainable:
self._fit_torch(train_features)
self._fit_torch(train_set, train_features)

return self

def _fit_torch(self, train_features):
def _fit_torch(self, train_set, train_features):
import torch

def _l2_loss(*tensors):
Expand Down Expand Up @@ -213,11 +213,11 @@ def _inner(a, b):
sum_loss = 0.0
count = 0
progress_bar = tqdm(
total=self.train_set.num_batches(self.batch_size),
total=train_set.num_batches(self.batch_size),
desc="Epoch {}/{}".format(epoch, self.n_epochs),
disable=not self.verbose,
)
for batch_u, batch_i, batch_j in self.train_set.uij_iter(
for batch_u, batch_i, batch_j in train_set.uij_iter(
self.batch_size, shuffle=True
):
gamma_u = Gu[batch_u]
Expand Down

0 comments on commit 8124b44

Please sign in to comment.