From 63931ba28d12fea2296759788f1ed9b1ddd0198e Mon Sep 17 00:00:00 2001 From: toenshoff Date: Sat, 21 Sep 2024 01:36:49 +0200 Subject: [PATCH] Fix offset in LinearEmbeddingEncoder (#455) This PR fixes an issue in the `LinearEmbeddingEncoder` class. The `encode_forward` function does not increment the `start_idx` variable, so when multiple embedding columns are present, only the first column is encoded correctly. This probably affects the results in the [RelBench](https://arxiv.org/pdf/2407.20060) paper, so I suggest double-checking those. --------- Co-authored-by: Akihiro Nitta --- torch_frame/nn/encoder/stype_encoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py index 88a0e5be9..1b12c1891 100644 --- a/torch_frame/nn/encoder/stype_encoder.py +++ b/torch_frame/nn/encoder/stype_encoder.py @@ -732,6 +732,7 @@ def encode_forward( # -> [batch_size, out_channels] x_lin = feat.values[:, start_idx:end_idx] @ self.weight_list[idx] x_lins.append(x_lin) + start_idx = end_idx # [batch_size, num_cols, out_channels] x = torch.stack(x_lins, dim=1) # [batch_size, num_cols, out_channels] + [num_cols, out_channels]