From 1f4c4b868ba888866041183aef36620c3ca932a1 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Mon, 12 Aug 2024 18:02:52 +0900
Subject: [PATCH] Remove CUDA synchronizations by slicing input tensor with
 `int` instead of CUDA tensors in `nn.LinearEmbeddingEncoder` (#432)

`start_idx` and `end_idx` used in `feat.values[:, start_idx:end_idx]` are on
device, which leads to CUDA synchronizations.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                            |  2 ++
 torch_frame/nn/encoder/stype_encoder.py | 15 +++++++++------
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e0c56a9e..5d57fc9b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Removed CUDA synchronizations in `nn.LinearEmbeddingEncoder` ([#432](https://github.com/pyg-team/pytorch-frame/pull/432))
+
 ## [0.2.3] - 2024-07-08
 
 ### Added
diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py
index 3ae34696..7312a7dd 100644
--- a/torch_frame/nn/encoder/stype_encoder.py
+++ b/torch_frame/nn/encoder/stype_encoder.py
@@ -716,10 +716,12 @@ def __init__(
     def init_modules(self) -> None:
         super().init_modules()
         num_cols = len(self.stats_list)
-        emb_dim_list = [stats[StatType.EMB_DIM] for stats in self.stats_list]
+        self.emb_dim_list = [
+            stats[StatType.EMB_DIM] for stats in self.stats_list
+        ]
         self.weight_list = ParameterList([
             Parameter(torch.empty(emb_dim, self.out_channels))
-            for emb_dim in emb_dim_list
+            for emb_dim in self.emb_dim_list
         ])
         self.biases = Parameter(torch.empty(num_cols, self.out_channels))
         self.reset_parameters()
@@ -736,12 +738,13 @@ def encode_forward(
         col_names: list[str] | None = None,
     ) -> Tensor:
         x_lins: list[Tensor] = []
-        for start_idx, end_idx, weight in zip(feat.offset[:-1],
-                                              feat.offset[1:],
-                                              self.weight_list):
+        start_idx = 0
+        for idx, col_dim in enumerate(self.emb_dim_list):
+            end_idx = start_idx + col_dim
             # [batch_size, emb_dim] * [emb_dim, out_channels]
             # -> [batch_size, out_channels]
-            x_lin = torch.matmul(feat.values[:, start_idx:end_idx], weight)
+            x_lin = feat.values[:, start_idx:end_idx] @ self.weight_list[idx]
             x_lins.append(x_lin)
+            start_idx = end_idx
         # [batch_size, num_cols, out_channels]
         x = torch.stack(x_lins, dim=1)
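
A minimal standalone sketch (separate from the patch, with made-up shapes) of the synchronization this change avoids: slicing with 0-dim CUDA tensors forces their values onto the host via `Tensor.__index__()` / `item()`, while plain Python `int` bounds are resolved on the host and keep kernel launches asynchronous. `torch.cuda.set_sync_debug_mode` is used here only to surface the sync as a warning.

```python
# Sketch only: illustrates why slicing with CUDA tensors synchronizes.
# Assumes a CUDA-capable machine; shapes are arbitrary.
import torch

if torch.cuda.is_available():
    # Warn whenever an operation blocks on the GPU.
    torch.cuda.set_sync_debug_mode("warn")

    values = torch.randn(32, 10, device="cuda")       # [batch_size, sum(emb_dims)]
    offset = torch.tensor([0, 4, 10], device="cuda")  # column boundaries on device

    # Old pattern: slice bounds are 0-dim CUDA tensors. Python needs their
    # values on the host, so every slice triggers a device-to-host copy (a sync).
    col0_synced = values[:, offset[0]:offset[1]]

    # New pattern: slice bounds are plain Python ints (as in emb_dim_list),
    # so no device-to-host transfer is needed to build the slice.
    emb_dim_list = [4, 6]
    start_idx, end_idx = 0, emb_dim_list[0]
    col0_async = values[:, start_idx:end_idx]

    torch.cuda.set_sync_debug_mode("default")
```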