From ac95d3f464d2e19389f29de09f12423546a514d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 02:34:15 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 1 + torch_frame/data/dataset.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 770fd3a5..60d39093 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## \[Unreleased\] ### Added + - Added support for materializing dataset for train and test dataframe separately([#470](https://github.com/pyg-team/pytorch-frame/issues/470)) - Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464)) - Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444)) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 8b84de0b..ddf31051 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -571,7 +571,7 @@ def materialize( :obj:`path`. If :obj:`path` is :obj:`None`, this will materialize the dataset without caching. (default: :obj:`None`) - col_stats (Dict[str, Dict[StatType, Any]], optional): optional + col_stats (Dict[str, Dict[StatType, Any]], optional): optional col_stats provided by the user. If not provided, the statistics is calculated from the dataframe itself. (default: :obj:`None`) @@ -604,9 +604,9 @@ def materialize( sep=self.col_to_sep.get(col, None), time_format=self.col_to_time_format.get(col, None), ) - # For a target column, sort categories lexicographically such that - # we do not accidentally swap labels in binary classification - # tasks. + # For a target column, sort categories lexicographically such that + # we do not accidentally swap labels in binary classification + # tasks. if col == self.target_col and stype == torch_frame.categorical: index, value = self._col_stats[col][StatType.COUNT] if len(index) == 2: