Commit

Remove CUDA synchronizations by slicing input tensor with int instead of CUDA tensors in `nn.LinearEmbeddingEncoder` (#432)

`start_idx` and `end_idx`, as used in `feat.values[:, start_idx:end_idx]`, were tensors living on the CUDA device; reading their values to build the slice forces a CUDA synchronization on every iteration. Slicing with plain Python ints avoids this.
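A minimal, CPU-runnable repro of the pattern being fixed (the tensor `offset` and the shapes below are illustrative, not taken from the library). When a slice bound is a 0-dim tensor, Python's slice protocol calls `Tensor.__index__()` to obtain an int; on a CUDA tensor that read is a blocking device-to-host copy, i.e. a synchronization. Plain ints carry no device memory, so the slice can be queued asynchronously:

```python
import torch

values = torch.arange(12.0).reshape(3, 4)

# Buggy path: slice bounds are 0-dim tensors (on CUDA this forces a
# host round-trip via Tensor.__index__() before the slice can run).
offset = torch.tensor([0, 2, 4])
sliced_tensor_idx = values[:, offset[0]:offset[1]]

# Fixed path: plain Python ints, no device memory involved.
start, end = 0, 2
sliced_int_idx = values[:, start:end]

# Both produce the same result; only the synchronization behavior differs.
assert torch.equal(sliced_tensor_idx, sliced_int_idx)
```

On a real GPU run, `torch.cuda.set_sync_debug_mode("warn")` can be used to surface such hidden synchronizations.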

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
akihironitta and pre-commit-ci[bot] authored Aug 12, 2024
1 parent 34ccf7d commit 1f4c4b8
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md

```diff
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Removed CUDA synchronizations in `nn.LinearEmbeddingEncoder` ([#432](https://github.com/pyg-team/pytorch-frame/pull/432))
+
 ## [0.2.3] - 2024-07-08
 
 ### Added
```
14 changes: 8 additions & 6 deletions torch_frame/nn/encoder/stype_encoder.py

```diff
@@ -716,10 +716,12 @@ def __init__(
     def init_modules(self) -> None:
         super().init_modules()
         num_cols = len(self.stats_list)
-        emb_dim_list = [stats[StatType.EMB_DIM] for stats in self.stats_list]
+        self.emb_dim_list = [
+            stats[StatType.EMB_DIM] for stats in self.stats_list
+        ]
         self.weight_list = ParameterList([
             Parameter(torch.empty(emb_dim, self.out_channels))
-            for emb_dim in emb_dim_list
+            for emb_dim in self.emb_dim_list
         ])
         self.biases = Parameter(torch.empty(num_cols, self.out_channels))
         self.reset_parameters()
@@ -736,12 +738,12 @@ def encode_forward(
         col_names: list[str] | None = None,
     ) -> Tensor:
         x_lins: list[Tensor] = []
-        for start_idx, end_idx, weight in zip(feat.offset[:-1],
-                                              feat.offset[1:],
-                                              self.weight_list):
+        start_idx = 0
+        for idx, col_dim in enumerate(self.emb_dim_list):
+            end_idx = start_idx + col_dim
             # [batch_size, emb_dim] * [emb_dim, out_channels]
             # -> [batch_size, out_channels]
-            x_lin = torch.matmul(feat.values[:, start_idx:end_idx], weight)
+            x_lin = feat.values[:, start_idx:end_idx] @ self.weight_list[idx]
             x_lins.append(x_lin)
         # [batch_size, num_cols, out_channels]
         x = torch.stack(x_lins, dim=1)
```
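The new scheme can be sketched standalone as below. The names (`emb_dim_list`, `weight_list`, `values`) stand in for the module's attributes and the shapes are made up for illustration; the key point is that the running offset is plain int arithmetic on the host, advanced at the end of each iteration, so no slice bound ever lives on the GPU:

```python
import torch

batch_size, out_channels = 4, 8
emb_dim_list = [3, 5, 2]                  # per-column embedding widths
weight_list = [torch.randn(d, out_channels) for d in emb_dim_list]
values = torch.randn(batch_size, sum(emb_dim_list))  # columns concatenated

x_lins = []
start_idx = 0                             # Python int, not a device tensor
for idx, col_dim in enumerate(emb_dim_list):
    end_idx = start_idx + col_dim
    # [batch_size, col_dim] @ [col_dim, out_channels] -> [batch_size, out_channels]
    x_lins.append(values[:, start_idx:end_idx] @ weight_list[idx])
    start_idx = end_idx                   # advance the offset on the host

# [batch_size, num_cols, out_channels]
x = torch.stack(x_lins, dim=1)
assert x.shape == (batch_size, len(emb_dim_list), out_channels)
```

Compared with the old `zip(feat.offset[:-1], feat.offset[1:], ...)` loop, the per-column widths are fixed at `init_modules` time, so the offsets can be precomputed as ints instead of being read out of a device-resident `offset` tensor at every forward pass.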
