From 1813c2e4294b67a93be2e65a626e1a84ae3929d7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Dec 2024 07:09:41 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- sdgx/data_connectors/dataframe_connector.py | 14 ++++++---- .../optimize/sdv_copulas/data_transformer.py | 2 +- .../optimize/sdv_ctgan/data_sampler.py | 7 ++++- .../optimize/sdv_ctgan/data_transformer.py | 27 +++++++++--------- .../components/optimize/sdv_ctgan/types.py | 28 +++++++++++-------- .../sdv_rdt/transformers/__init__.py | 4 +-- tests/test_ctgan_synthesizer.py | 6 ++-- 7 files changed, 49 insertions(+), 39 deletions(-) diff --git a/sdgx/data_connectors/dataframe_connector.py b/sdgx/data_connectors/dataframe_connector.py index ca88b161..374c5ad2 100644 --- a/sdgx/data_connectors/dataframe_connector.py +++ b/sdgx/data_connectors/dataframe_connector.py @@ -3,6 +3,7 @@ import os from functools import cached_property from typing import Callable, Generator + import pandas as pd from sdgx.data_connectors.base import DataConnector @@ -28,10 +29,10 @@ class DataFrameConnector(DataConnector): """ def __init__( - self, - df: pd.DataFrame, - *args, - **kwargs, + self, + df: pd.DataFrame, + *args, + **kwargs, ): super().__init__(*args, **kwargs) self.df: pd.DataFrame = df @@ -41,7 +42,7 @@ def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | Non if offset >= length: return None limit = limit or length - return self.df.iloc[offset: min(offset + limit, length)] + return self.df.iloc[offset : min(offset + limit, length)] def _columns(self) -> list[str]: return list(self.df.columns) @@ -52,7 +53,7 @@ def generator() -> Generator[pd.DataFrame, None, None]: if offset < length: current = offset while current < length: - yield self.df.iloc[current: min(current + chunksize, length)] + yield self.df.iloc[current : min(current + chunksize, length)] current += chunksize return generator() @@ -60,6 +61,7 @@ def generator() -> Generator[pd.DataFrame, None, None]: from sdgx.data_connectors.extension import hookimpl + @hookimpl def register(manager): manager.register("DataFrameConnector", DataFrameConnector) diff --git a/sdgx/models/components/optimize/sdv_copulas/data_transformer.py b/sdgx/models/components/optimize/sdv_copulas/data_transformer.py index 039d985e..95b64965 100644 --- a/sdgx/models/components/optimize/sdv_copulas/data_transformer.py +++ b/sdgx/models/components/optimize/sdv_copulas/data_transformer.py @@ -8,7 +8,7 @@ ) from sdgx.models.components.sdv_rdt.transformers import ( ClusterBasedNormalizer, - FrequencyEncoder + FrequencyEncoder, ) # TODO(Enhance) - Use different type of Encoder for discrete, like ordered columns, high cardinality columns... diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py index bc8ddd3d..3c3e81b4 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py @@ -14,7 +14,12 @@ class DataSampler(object): """DataSampler samples the conditional vector and corresponding data for CTGAN.""" - def __init__(self, data: NDArrayLoader | np.ndarray, output_info: List[List[SpanInfo]], log_frequency: bool): + def __init__( + self, + data: NDArrayLoader | np.ndarray, + output_info: List[List[SpanInfo]], + log_frequency: bool, + ): self._data: NDArrayLoader | np.ndarray = data def is_onehot_encoding_column(column_info: List[SpanInfo]): diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py index 2164ac6b..b3e074eb 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py @@ -13,16 +13,16 @@ from sdgx.data_models.metadata import CategoricalEncoderType, Metadata from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader from sdgx.models.components.optimize.sdv_ctgan.types import ( - CategoricalEncoderInstanceType, ActivationFuncType, + CategoricalEncoderInstanceType, + ColumnTransformInfo, SpanInfo, - ColumnTransformInfo ) from sdgx.models.components.sdv_rdt.transformers import ( ClusterBasedNormalizer, - OneHotEncoder, NormalizedFrequencyEncoder, - NormalizedLabelEncoder + NormalizedLabelEncoder, + OneHotEncoder, ) from sdgx.utils import logger @@ -48,10 +48,7 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, metadata=None): self._weight_threshold = weight_threshold def _fit_categorical_encoder( - self, - column_name: str, - data: pd.DataFrame, - encoder_type: CategoricalEncoderType | None + self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType | None ) -> Tuple[CategoricalEncoderInstanceType, int, ActivationFuncType]: selected_encoder_type = None # Load encoder from metadata @@ -74,8 +71,8 @@ def _fit_categorical_encoder( # if selected_encoder_type is not specified and using onehot num_categories > threshold, change the encoder. if not selected_encoder_type and self.metadata and num_categories != -1: encoder_type = ( - self.metadata.get_column_encoder_by_categorical_threshold(num_categories) - or encoder_type + self.metadata.get_column_encoder_by_categorical_threshold(num_categories) + or encoder_type ) if encoder_type == "onehot": @@ -135,7 +132,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None): """ column_name = data.columns[0] - encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type) + encoder, num_categories, activate_fn = self._fit_categorical_encoder( + column_name, data, encoder_type + ) assert encoder and activate_fn return ColumnTransformInfo( @@ -232,7 +231,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLo loader = NDArrayLoader.get_auto_save(raw_data) for ndarray in tqdm.tqdm( - p(processes), desc="Transforming data", total=len(processes), delay=3 + p(processes), desc="Transforming data", total=len(processes), delay=3 ): loader.store(ndarray.astype(float)) return loader @@ -278,10 +277,10 @@ def inverse_transform(self, data, sigmas=None): column_names = [] for column_transform_info in tqdm.tqdm( - self._column_transform_info_list, desc="Inverse transforming", delay=3 + self._column_transform_info_list, desc="Inverse transforming", delay=3 ): dim = column_transform_info.output_dimensions - column_data = data[:, st: st + dim] + column_data = data[:, st : st + dim] if column_transform_info.column_type == "continuous": recovered_column_data = self._inverse_transform_continuous( column_transform_info, column_data, sigmas, st diff --git a/sdgx/models/components/optimize/sdv_ctgan/types.py b/sdgx/models/components/optimize/sdv_ctgan/types.py index e782b0fa..2aa2c1cd 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/types.py +++ b/sdgx/models/components/optimize/sdv_ctgan/types.py @@ -1,17 +1,21 @@ from __future__ import annotations -from typing import Union, Literal, List +from typing import List, Literal, Union from sdgx.models.components.sdv_rdt.transformers import ( - OneHotEncoder, ClusterBasedNormalizer, NormalizedFrequencyEncoder, - NormalizedLabelEncoder + NormalizedLabelEncoder, + OneHotEncoder, ) -CategoricalEncoderInstanceType = Union[OneHotEncoder, NormalizedFrequencyEncoder, NormalizedLabelEncoder] +CategoricalEncoderInstanceType = Union[ + OneHotEncoder, NormalizedFrequencyEncoder, NormalizedLabelEncoder +] ContinuousEncoderInstanceType = Union[ClusterBasedNormalizer] -TransformerEncoderInstanceType = Union[CategoricalEncoderInstanceType, ContinuousEncoderInstanceType] +TransformerEncoderInstanceType = Union[ + CategoricalEncoderInstanceType, ContinuousEncoderInstanceType +] ActivationFuncType = Literal["softmax", "tanh", "linear"] ColumnTransformType = Literal["discrete", "continuous"] @@ -23,12 +27,14 @@ def __init__(self, dim: int, activation_fn: ActivationFuncType): class ColumnTransformInfo: - def __init__(self, - column_name: str, - column_type: ColumnTransformType, - transform: TransformerEncoderInstanceType, - output_info: List[SpanInfo], - output_dimensions: int): + def __init__( + self, + column_name: str, + column_type: ColumnTransformType, + transform: TransformerEncoderInstanceType, + output_info: List[SpanInfo], + output_dimensions: int, + ): self.column_name: str = column_name self.column_type: str = column_type self.transform: TransformerEncoderInstanceType = transform diff --git a/sdgx/models/components/sdv_rdt/transformers/__init__.py b/sdgx/models/components/sdv_rdt/transformers/__init__.py index e011593e..a15bc928 100644 --- a/sdgx/models/components/sdv_rdt/transformers/__init__.py +++ b/sdgx/models/components/sdv_rdt/transformers/__init__.py @@ -12,9 +12,9 @@ CustomLabelEncoder, FrequencyEncoder, LabelEncoder, - OneHotEncoder, + NormalizedFrequencyEncoder, NormalizedLabelEncoder, - NormalizedFrequencyEncoder + OneHotEncoder, ) from sdgx.models.components.sdv_rdt.transformers.datetime import ( OptimizedTimestampEncoder, diff --git a/tests/test_ctgan_synthesizer.py b/tests/test_ctgan_synthesizer.py index 0a3a29d6..8557a82f 100644 --- a/tests/test_ctgan_synthesizer.py +++ b/tests/test_ctgan_synthesizer.py @@ -7,15 +7,13 @@ from sdgx.data_connectors.dataframe_connector import DataFrameConnector from sdgx.data_models.metadata import Metadata -from sdgx.models.components.optimize.sdv_ctgan.data_transformer import ( - DataTransformer, -) +from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer from sdgx.models.components.optimize.sdv_ctgan.types import SpanInfo from sdgx.models.components.sdv_rdt.transformers import ( + ClusterBasedNormalizer, NormalizedFrequencyEncoder, NormalizedLabelEncoder, OneHotEncoder, - ClusterBasedNormalizer ) from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel from sdgx.synthesizer import Synthesizer