From 3d9af8f8350616a59210544b5c9181c307dbb80e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Nov 2024 14:47:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- sdgx/data_models/metadata.py | 2 +- sdgx/models/components/optimize/ndarray_loader.py | 3 ++- .../optimize/sdv_ctgan/data_transformer.py | 13 ++++++++----- sdgx/models/ml/single_table/ctgan.py | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/sdgx/data_models/metadata.py b/sdgx/data_models/metadata.py index 6ea4c9d3..5f8ffa41 100644 --- a/sdgx/data_models/metadata.py +++ b/sdgx/data_models/metadata.py @@ -5,7 +5,7 @@ from collections.abc import Iterable from itertools import chain from pathlib import Path -from typing import Any, Dict, List, Set, Literal, Union +from typing import Any, Dict, List, Literal, Set, Union import pandas as pd from pydantic import BaseModel, Field, field_validator diff --git a/sdgx/models/components/optimize/ndarray_loader.py b/sdgx/models/components/optimize/ndarray_loader.py index 1fe2f28a..7a537e8f 100644 --- a/sdgx/models/components/optimize/ndarray_loader.py +++ b/sdgx/models/components/optimize/ndarray_loader.py @@ -37,7 +37,8 @@ def __init__(self, cache_root: str | Path = DEFAULT_CACHE_ROOT, save_to_file=Tru def get_auto_save(raw_data) -> NDArrayLoader: save_to_file = True if isinstance(raw_data, pd.DataFrame) or ( - isinstance(raw_data, DataLoader) and isinstance(raw_data.data_connector, DataFrameConnector) + isinstance(raw_data, DataLoader) + and isinstance(raw_data.data_connector, DataFrameConnector) ): save_to_file = False return NDArrayLoader(save_to_file=save_to_file) diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py index 25371ec1..ebf9953e 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py @@ -16,7 +16,9 @@ ClusterBasedNormalizer, OneHotEncoder, ) -from sdgx.models.components.sdv_rdt.transformers.categorical import NormalizedLabelEncoder +from sdgx.models.components.sdv_rdt.transformers.categorical import ( + NormalizedLabelEncoder, +) from sdgx.utils import logger SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"]) @@ -126,11 +128,12 @@ def fit(self, data_loader: DataLoader, discrete_columns=()): # or column_name in self.metadata.label_columns logger.debug(f"Fitting discrete column {column_name}...") encoder_type = None - if self.metadata.categorical_encoder and column_name in self.metadata.categorical_encoder: + if ( + self.metadata.categorical_encoder + and column_name in self.metadata.categorical_encoder + ): encoder_type = self.metadata.categorical_encoder[column_name] - column_transform_info = self._fit_discrete( - data_loader[[column_name]], encoder_type - ) + column_transform_info = self._fit_discrete(data_loader[[column_name]], encoder_type) else: logger.debug(f"Fitting continuous column {column_name}...") column_transform_info = self._fit_continuous(data_loader[[column_name]]) diff --git a/sdgx/models/ml/single_table/ctgan.py b/sdgx/models/ml/single_table/ctgan.py index e8245aff..6e92814f 100644 --- a/sdgx/models/ml/single_table/ctgan.py +++ b/sdgx/models/ml/single_table/ctgan.py @@ -442,7 +442,7 @@ def save(self, save_dir: str | Path): return SDVBaseSynthesizer.save(self, save_dir / self.MODEL_SAVE_NAME) @classmethod - def load(cls, save_dir: str | Path, device: str=None) -> "CTGANSynthesizerModel": + def load(cls, save_dir: str | Path, device: str = None) -> "CTGANSynthesizerModel": return SDVBaseSynthesizer.load(save_dir / cls.MODEL_SAVE_NAME, device) @staticmethod