diff --git a/CHANGELOG.md b/CHANGELOG.md
index e0c56a9e..5d57fc9b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Removed CUDA synchronizations in `nn.LinearEmbeddingEncoder` ([#432](https://github.com/pyg-team/pytorch-frame/pull/432))
+
 ## [0.2.3] - 2024-07-08
 
 ### Added
diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py
index 3ae34696..7312a7dd 100644
--- a/torch_frame/nn/encoder/stype_encoder.py
+++ b/torch_frame/nn/encoder/stype_encoder.py
@@ -716,10 +716,12 @@ def __init__(
     def init_modules(self) -> None:
         super().init_modules()
         num_cols = len(self.stats_list)
-        emb_dim_list = [stats[StatType.EMB_DIM] for stats in self.stats_list]
+        self.emb_dim_list = [
+            stats[StatType.EMB_DIM] for stats in self.stats_list
+        ]
         self.weight_list = ParameterList([
             Parameter(torch.empty(emb_dim, self.out_channels))
-            for emb_dim in emb_dim_list
+            for emb_dim in self.emb_dim_list
         ])
         self.biases = Parameter(torch.empty(num_cols, self.out_channels))
         self.reset_parameters()
@@ -736,12 +738,13 @@ def encode_forward(
         col_names: list[str] | None = None,
     ) -> Tensor:
         x_lins: list[Tensor] = []
-        for start_idx, end_idx, weight in zip(feat.offset[:-1],
-                                              feat.offset[1:],
-                                              self.weight_list):
+        start_idx = 0
+        for idx, col_dim in enumerate(self.emb_dim_list):
+            end_idx = start_idx + col_dim
             # [batch_size, emb_dim] * [emb_dim, out_channels]
             # -> [batch_size, out_channels]
-            x_lin = torch.matmul(feat.values[:, start_idx:end_idx], weight)
+            x_lin = feat.values[:, start_idx:end_idx] @ self.weight_list[idx]
             x_lins.append(x_lin)
+            start_idx = end_idx
         # [batch_size, num_cols, out_channels]
         x = torch.stack(x_lins, dim=1)
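
The sketch below illustrates the synchronization pattern this patch avoids; it is a minimal, standalone example, not the library code, and the tensor shapes and variable names are illustrative. When the column boundaries live in a GPU tensor (like `feat.offset` here), using its elements as Python slice bounds pulls each scalar to the host and forces a CUDA synchronization per column; precomputing the per-column embedding dims as plain Python ints (as `self.emb_dim_list` does above) keeps the slicing loop host-only.

```python
import torch

values = torch.randn(4, 5)            # [batch_size, sum(emb_dims)]
offset = torch.tensor([0, 2, 5])      # column boundaries; a CUDA tensor in practice
emb_dim_list = [2, 3]                 # per-column dims, precomputed as Python ints
weights = [torch.randn(d, 8) for d in emb_dim_list]

# Sync-heavy pattern (avoided): slice bounds taken from a (GPU) tensor, e.g.
#   values[:, offset[i]:offset[i + 1]]
# which copies each boundary scalar to the host before slicing.

# Sync-free pattern: slice bounds are plain Python ints.
x_lins, start = [], 0
for idx, col_dim in enumerate(emb_dim_list):
    end = start + col_dim
    x_lins.append(values[:, start:end] @ weights[idx])
    start = end
x = torch.stack(x_lins, dim=1)        # [batch_size, num_cols, out_channels]
print(x.shape)                        # torch.Size([4, 2, 8])
```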