Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 23, 2024
1 parent 830ec23 commit ac95d3f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## \[Unreleased\]

### Added

- Added support for materializing dataset for train and test dataframe separately([#470](https://github.com/pyg-team/pytorch-frame/issues/470))
- Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464))
- Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444))
Expand Down
8 changes: 4 additions & 4 deletions torch_frame/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def materialize(
:obj:`path`. If :obj:`path` is :obj:`None`, this will
materialize the dataset without caching.
(default: :obj:`None`)
col_stats (Dict[str, Dict[StatType, Any]], optional): optional
col_stats (Dict[str, Dict[StatType, Any]], optional): optional
col_stats provided by the user. If not provided, the statistics
is calculated from the dataframe itself. (default: :obj:`None`)
Expand Down Expand Up @@ -604,9 +604,9 @@ def materialize(
sep=self.col_to_sep.get(col, None),
time_format=self.col_to_time_format.get(col, None),
)
# For a target column, sort categories lexicographically such that
# we do not accidentally swap labels in binary classification
# tasks.
# For a target column, sort categories lexicographically such that
# we do not accidentally swap labels in binary classification
# tasks.
if col == self.target_col and stype == torch_frame.categorical:
index, value = self._col_stats[col][StatType.COUNT]
if len(index) == 2:
Expand Down

0 comments on commit ac95d3f

Please sign in to comment.