Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
weihua916 committed May 1, 2024
1 parent 2f1a876 commit 7177e92
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
5 changes: 5 additions & 0 deletions test/transforms/test_cat_to_num_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def test_cat_to_num_transform_on_categorical_only_dataset(with_nan):
out.col_names_dict[stype.numerical]) == ((dataset.num_classes - 1) *
total_cols))

tensor_frame.feat_dict[stype.categorical] += 1
with pytest.raises(RuntimeError, match="contains new category"):
# Raise informative error when input tensor frame contains new category
out = transform(tensor_frame)


@pytest.mark.parametrize('task_type', [
TaskType.MULTICLASS_CLASSIFICATION, TaskType.REGRESSION,
Expand Down
8 changes: 6 additions & 2 deletions torch_frame/transforms/cat_to_num_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,12 @@ def _forward(self, tf: TensorFrame) -> TensorFrame:
count = torch.tensor(self.col_stats[col_name][StatType.COUNT][1],
device=tf.device)
feat = tensor[:, i]
v = torch.index_select(count, 0, feat).unsqueeze(1).repeat(
1, num_classes - 1)
max_cat = feat.max()
if max_cat >= len(count):
raise RuntimeError(
f"{col_name} contains new category {max_cat} not seen "
f"during fit stage.")
v = count[feat].unsqueeze(1).repeat(1, num_classes - 1)
transformed_tensor[:, i * (num_classes - 1):(i + 1) *
(num_classes - 1)] = ((v + target_mean) /
(self.data_size + 1))
Expand Down

0 comments on commit 7177e92

Please sign in to comment.