From e958136cb2d9e356b9793efe0f73a16d05696b7d Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 11 Aug 2024 14:51:36 +0000 Subject: [PATCH] update --- .github/workflows/testing.yml | 1 + torch_frame/nn/encoder/stype_encoder.py | 93 ++++++++++--------------- 2 files changed, 38 insertions(+), 56 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index f7eede60e..e88201f06 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -4,6 +4,7 @@ on: # yamllint disable-line rule:truthy push: branches: - master + - aki-rm-cuda-sync0 pull_request: jobs: diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py index 7312a7ddf..a07ab4527 100644 --- a/torch_frame/nn/encoder/stype_encoder.py +++ b/torch_frame/nn/encoder/stype_encoder.py @@ -109,6 +109,28 @@ def init_modules(self): f"can be used on {self.stype} columns, but " f"{self.na_strategy} is given.") + fill_values = [] + for col in range(len(self.stats_list)): + if self.na_strategy == NAStrategy.MOST_FREQUENT: + # Categorical index is sorted based on count, + # so 0-th index is always the most frequent. + fill_value = 0 + elif self.na_strategy == NAStrategy.MEAN: + fill_value = self.stats_list[col][StatType.MEAN] + elif self.na_strategy == NAStrategy.ZEROS: + fill_value = 0 + elif self.na_strategy == NAStrategy.NEWEST_TIMESTAMP: + fill_value = self.stats_list[col][StatType.NEWEST_TIME] + elif self.na_strategy == NAStrategy.OLDEST_TIMESTAMP: + fill_value = self.stats_list[col][StatType.OLDEST_TIME] + elif self.na_strategy == NAStrategy.MEDIAN_TIMESTAMP: + fill_value = self.stats_list[col][StatType.MEDIAN_TIME] + else: + raise ValueError(f"Unsupported NA strategy {self.na_strategy}") + fill_values.append(fill_value) + + self.register_buffer("fill_values", torch.tensor(fill_values)) + @abstractmethod def reset_parameters(self): r"""Initialize the parameters of `post_module`.""" @@ -190,79 +212,38 @@ def na_forward(self, feat: TensorData) -> TensorData: if isinstance(feat, Tensor): # cache for future use na_mask = get_na_mask(feat) - if na_mask.any(): - feat = feat.clone() - else: - return feat + feat = feat.clone() elif isinstance(feat, MultiEmbeddingTensor): - if get_na_mask(feat.values).any(): - feat = MultiEmbeddingTensor(num_rows=feat.num_rows, - num_cols=feat.num_cols, - values=feat.values.clone(), - offset=feat.offset) - else: - return feat + feat = MultiEmbeddingTensor(num_rows=feat.num_rows, + num_cols=feat.num_cols, + values=feat.values.clone(), + offset=feat.offset) elif isinstance(feat, MultiNestedTensor): - if get_na_mask(feat.values).any(): - feat = MultiNestedTensor(num_rows=feat.num_rows, - num_cols=feat.num_cols, - values=feat.values.clone(), - offset=feat.offset) - else: - return feat + feat = MultiNestedTensor(num_rows=feat.num_rows, + num_cols=feat.num_cols, + values=feat.values.clone(), + offset=feat.offset) else: raise ValueError(f"Unrecognized type {type(feat)} in na_forward.") - fill_values = [] - for col in range(feat.size(1)): - if self.na_strategy == NAStrategy.MOST_FREQUENT: - # Categorical index is sorted based on count, - # so 0-th index is always the most frequent. - fill_value = 0 - elif self.na_strategy == NAStrategy.MEAN: - fill_value = self.stats_list[col][StatType.MEAN] - elif self.na_strategy == NAStrategy.ZEROS: - fill_value = 0 - elif self.na_strategy == NAStrategy.NEWEST_TIMESTAMP: - fill_value = self.stats_list[col][StatType.NEWEST_TIME].to( - feat.device) - elif self.na_strategy == NAStrategy.OLDEST_TIMESTAMP: - fill_value = self.stats_list[col][StatType.OLDEST_TIME].to( - feat.device) - elif self.na_strategy == NAStrategy.MEDIAN_TIMESTAMP: - fill_value = self.stats_list[col][StatType.MEDIAN_TIME].to( - feat.device) - else: - raise ValueError(f"Unsupported NA strategy {self.na_strategy}") - fill_values.append(fill_value) - if isinstance(feat, _MultiTensor): - for col, fill_value in enumerate(fill_values): + for col, fill_value in enumerate(self.fill_values): feat.fillna_col(col, fill_value) else: if na_mask.ndim == 3: # when feat is 3D, it is faster to iterate over columns - for col, fill_value in enumerate(fill_values): + for col, fill_value in enumerate(self.fill_values): col_data = feat[:, col] col_na_mask = na_mask[:, col].any(dim=-1) col_data[col_na_mask] = fill_value else: # na_mask.ndim == 2 - fill_values = torch.tensor(fill_values, device=feat.device) - assert feat.size(-1) == fill_values.size(-1) - feat = torch.where(na_mask, fill_values, feat) - # Add better safeguard here to make sure nans are actually - # replaced, expecially when nans are represented as -1's. They are - # very hard to catch as they won't error out. - filled_values = feat - if isinstance(feat, _MultiTensor): - filled_values = feat.values - if filled_values.is_floating_point(): - assert not torch.isnan(filled_values).any() - else: - assert not (filled_values == -1).any() + assert feat.size(-1) == self.fill_values.size(-1) + feat = torch.where(na_mask, self.fill_values, feat) + return feat + class EmbeddingEncoder(StypeEncoder): r"""An embedding look-up based encoder for categorical features. It applies :class:`torch.nn.Embedding` for each categorical feature and