Skip to content

Commit

Permalink
Fix offset in LinearEmbeddingEncoder (#455)
Browse files Browse the repository at this point in the history
This PR fixes an issue in the `LinearEmbeddingEncoder` class. The
forward function does not increment the `start_idx` variable, so only
the first one is used correctly if multiple embedding columns are
present.

This probably affects the results in the
[RelBench](https://arxiv.org/pdf/2407.20060) paper, so I suggest
double-checking those.

---------

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
  • Loading branch information
toenshoff and akihironitta authored Sep 20, 2024
1 parent a3b73c4 commit 63931ba
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions torch_frame/nn/encoder/stype_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ def encode_forward(
# -> [batch_size, out_channels]
x_lin = feat.values[:, start_idx:end_idx] @ self.weight_list[idx]
x_lins.append(x_lin)
start_idx = end_idx
# [batch_size, num_cols, out_channels]
x = torch.stack(x_lins, dim=1)
# [batch_size, num_cols, out_channels] + [num_cols, out_channels]
Expand Down

0 comments on commit 63931ba

Please sign in to comment.