From 226ddc55209e6195092825bb7ac0d507174ae394 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 26 Nov 2024 14:27:48 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 sdgx/data_connectors/dataframe_connector.py   | 12 +--
 sdgx/data_models/metadata.py                  | 73 ++++++++++---------
 sdgx/data_processors/formatters/datetime.py   |  4 +-
 .../components/optimize/ndarray_loader.py     |  2 +-
 .../optimize/sdv_ctgan/data_sampler.py        | 12 ++-
 .../optimize/sdv_ctgan/data_transformer.py    | 30 +++++---
 .../components/sdv_ctgan/synthesizers/base.py | 16 ++--
 .../sdv_ctgan/synthesizers/ctgan.py           |  5 +-
 .../components/sdv_ctgan/synthesizers/tvae.py | 24 +++---
 .../sdv_rdt/transformers/categorical.py       |  1 +
 sdgx/models/ml/single_table/ctgan.py          |  7 +-
 sdgx/synthesizer.py                           | 11 ++-
 12 files changed, 116 insertions(+), 81 deletions(-)

diff --git a/sdgx/data_connectors/dataframe_connector.py b/sdgx/data_connectors/dataframe_connector.py
index cd28bbba..0729fec1 100644
--- a/sdgx/data_connectors/dataframe_connector.py
+++ b/sdgx/data_connectors/dataframe_connector.py
@@ -11,10 +11,10 @@ class DataFrameConnector(DataConnector):
 
     def __init__(
-            self,
-            df: pd.DataFrame,
-            *args,
-            **kwargs,
+        self,
+        df: pd.DataFrame,
+        *args,
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.df: pd.DataFrame = df
@@ -24,7 +24,7 @@ def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | Non
         if offset >= length:
             return None
         limit = limit or length
-        return self.df.iloc[offset:min(offset + limit, length)]
+        return self.df.iloc[offset : min(offset + limit, length)]
 
     def _columns(self) -> list[str]:
         return list(self.df.columns)
@@ -35,7 +35,7 @@ def generator() -> Generator[pd.DataFrame, None, None]:
             if offset < length:
                 current = offset
                 while current < length:
-                    yield self.df.iloc[current:min(current + chunksize, length - 1)]
+                    yield self.df.iloc[current : min(current + chunksize, length - 1)]
                     current += chunksize
 
         return generator()
diff --git a/sdgx/data_models/metadata.py b/sdgx/data_models/metadata.py
index 1cd0e38a..32c6dac5 100644
--- a/sdgx/data_models/metadata.py
+++ b/sdgx/data_models/metadata.py
@@ -5,10 +5,11 @@
 from collections.abc import Iterable
 from itertools import chain
 from pathlib import Path
-from typing import Any, Dict, List, Set, Literal
+from typing import Any, Dict, List, Literal, Set
 
 import pandas as pd
 from pydantic import BaseModel, Field, field_validator
+
 from sdgx.data_loader import DataLoader
 from sdgx.data_models.inspectors.base import RelationshipInspector
 from sdgx.data_models.inspectors.manager import InspectorManager
@@ -16,6 +17,8 @@
 from sdgx.utils import logger
 
 CategoricalEncoderType = Literal["onehot", "label"]
+
+
 class Metadata(BaseModel):
     """
     Metadata is mainly used to describe the data types of all columns in a single data table.
@@ -87,7 +90,11 @@ def check_column_list(cls, value) -> Any:
     """
 
     def check_categorical_threshold(self, num_categories):
-        return num_categories > self.categorical_threshold if self.categorical_threshold is not None else False
+        return (
+            num_categories > self.categorical_threshold
+            if self.categorical_threshold is not None
+            else False
+        )
 
     @property
     def tag_fields(self) -> Iterable[str]:
@@ -115,16 +122,16 @@ def __eq__(self, other):
         if not isinstance(other, Metadata):
             return super().__eq__(other)
         return (
-                set(self.tag_fields) == set(other.tag_fields)
-                and all(
-                    self.get(key) == other.get(key)
-                    for key in set(chain(self.tag_fields, other.tag_fields))
-                )
-                and all(
-                    self.get(key) == other.get(key)
-                    for key in set(chain(self.format_fields, other.format_fields))
-                )
-                and self.version == other.version
+            set(self.tag_fields) == set(other.tag_fields)
+            and all(
+                self.get(key) == other.get(key)
+                for key in set(chain(self.tag_fields, other.tag_fields))
+            )
+            and all(
+                self.get(key) == other.get(key)
+                for key in set(chain(self.format_fields, other.format_fields))
+            )
+            and self.version == other.version
         )
 
     def query(self, field: str) -> Iterable[str]:
@@ -184,9 +191,9 @@ def set(self, key: str, value: Any):
         old_value = self.get(key)
 
         if (
-                key in self.model_fields
-                and key not in self.tag_fields
-                and key not in self.format_fields
+            key in self.model_fields
+            and key not in self.tag_fields
+            and key not in self.format_fields
         ):
             raise MetadataInitError(
                 f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
             )
@@ -276,14 +283,14 @@ def update(self, attributes: dict[str, Any]):
 
     @classmethod
     def from_dataloader(
-            cls,
-            dataloader: DataLoader,
-            max_chunk: int = 10,
-            primary_keys: Set[str] = None,
-            include_inspectors: Iterable[str] | None = None,
-            exclude_inspectors: Iterable[str] | None = None,
-            inspector_init_kwargs: dict[str, Any] | None = None,
-            check: bool = False,
+        cls,
+        dataloader: DataLoader,
+        max_chunk: int = 10,
+        primary_keys: Set[str] = None,
+        include_inspectors: Iterable[str] | None = None,
+        exclude_inspectors: Iterable[str] | None = None,
+        inspector_init_kwargs: dict[str, Any] | None = None,
+        check: bool = False,
     ) -> "Metadata":
         """Initialize a metadata from DataLoader and Inspectors
@@ -343,12 +350,12 @@
     @classmethod
     def from_dataframe(
-            cls,
-            df: pd.DataFrame,
-            include_inspectors: list[str] | None = None,
-            exclude_inspectors: list[str] | None = None,
-            inspector_init_kwargs: dict[str, Any] | None = None,
-            check: bool = False,
+        cls,
+        df: pd.DataFrame,
+        include_inspectors: list[str] | None = None,
+        exclude_inspectors: list[str] | None = None,
+        inspector_init_kwargs: dict[str, Any] | None = None,
+        check: bool = False,
     ) -> "Metadata":
         """Initialize a metadata from DataFrame and Inspectors
@@ -532,10 +539,10 @@ def get_column_data_type(self, column_name: str):
         # find the dtype who has most high inspector level
         for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
             if (
-                    each_key != "pii_columns"
-                    and each_key.endswith("_columns")
-                    and column_name in self.get(each_key)
-                    and current_level < self.column_inspect_level[each_key]
+                each_key != "pii_columns"
+                and each_key.endswith("_columns")
+                and column_name in self.get(each_key)
+                and current_level < self.column_inspect_level[each_key]
             ):
                 current_level = self.column_inspect_level[each_key]
                 current_type = each_key
diff --git a/sdgx/data_processors/formatters/datetime.py b/sdgx/data_processors/formatters/datetime.py
index e14cf468..8ccd8241 100644
--- a/sdgx/data_processors/formatters/datetime.py
+++ b/sdgx/data_processors/formatters/datetime.py
@@ -134,7 +134,9 @@ def datetime_formatter(each_value, datetime_format):
             datetime_obj = datetime.strptime(str(each_value), datetime_format)
             each_stamp = datetime.timestamp(datetime_obj)
         except Exception as e:
-            logger.warning(f"An error occurred when converting str to timestamp: {e}, we set it as the mean.")
+            logger.warning(
+                f"An error occurred when converting str to timestamp: {e}, we set it as the mean."
+            )
             logger.warning(f"Input parameters: ({str(each_value)}, {datetime_format})")
             logger.warning(f"Input type: ({type(each_value)}, {type(datetime_format)})")
             each_stamp = np.nan
diff --git a/sdgx/models/components/optimize/ndarray_loader.py b/sdgx/models/components/optimize/ndarray_loader.py
index 0cac9d4c..7c46e910 100644
--- a/sdgx/models/components/optimize/ndarray_loader.py
+++ b/sdgx/models/components/optimize/ndarray_loader.py
@@ -55,7 +55,7 @@ def store(self, ndarray: ndarray):
             np.save(self._get_cache_filename(self.store_index), ndarray)
             self.store_index += 1
         else:
-                for ndarray in np.split(ndarray, indices_or_sections=ndarray.shape[1], axis=1):
+            for ndarray in np.split(ndarray, indices_or_sections=ndarray.shape[1], axis=1):
                 self.ndarray_list.append(ndarray)
                 self.store_index += 1
diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
index 88287ac7..86377b5b 100644
--- a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
+++ b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
@@ -47,7 +47,11 @@ def is_onehot_encoding_column(column_info: SpanInfo):
 
         # Prepare an interval matrix for efficiently sample conditional vector
         max_category = max(
-            [column_info[0].dim for column_info in output_info if is_onehot_encoding_column(column_info)],
+            [
+                column_info[0].dim
+                for column_info in output_info
+                if is_onehot_encoding_column(column_info)
+            ],
             default=0,
         )
 
@@ -56,7 +60,11 @@ def is_onehot_encoding_column(column_info: SpanInfo):
         self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
         self._n_discrete_columns = n_discrete_columns
         self._n_categories = sum(
-            [column_info[0].dim for column_info in output_info if is_onehot_encoding_column(column_info)]
+            [
+                column_info[0].dim
+                for column_info in output_info
+                if is_onehot_encoding_column(column_info)
+            ]
         )
 
         st = 0
diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
index 69d9795c..9f776bd6 100644
--- a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
+++ b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
@@ -10,13 +10,16 @@
 from tqdm import autonotebook as tqdm
 
 from sdgx.data_loader import DataLoader
-from sdgx.data_models.metadata import Metadata, CategoricalEncoderType
+from sdgx.data_models.metadata import CategoricalEncoderType, Metadata
 from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
 from sdgx.models.components.sdv_rdt.transformers import (
     ClusterBasedNormalizer,
     OneHotEncoder,
 )
-from sdgx.models.components.sdv_rdt.transformers.categorical import LabelEncoder, NormalizedLabelEncoder
+from sdgx.models.components.sdv_rdt.transformers.categorical import (
+    LabelEncoder,
+    NormalizedLabelEncoder,
+)
 from sdgx.utils import logger
 
 SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])
@@ -66,7 +69,7 @@ def _fit_continuous(self, data):
             column_name=column_name,
             column_type="continuous",
             transform=gm,
-            output_info=[SpanInfo(1, "tanh"), SpanInfo(num_components, "softmax")],
+            output_info=[SpanInfo(1, "tanh"), SpanInfo(num_components, "softmax")],
             output_dimensions=1 + num_components,
         )
 
@@ -89,9 +92,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
         activate_fn = "softmax"
         checked = self.metadata.check_categorical_threshold(num_categories)
-        if encoder_type == 'onehot' or not checked:
+        if encoder_type == "onehot" or not checked:
             pass
-        elif encoder_type == 'label':
+        elif encoder_type == "label":
             encoder = NormalizedLabelEncoder(order_by="alphabetical")
             encoder.fit(data, column_name)
             num_categories = 1
@@ -126,9 +129,14 @@ def fit(self, data_loader: DataLoader, discrete_columns=()):
                 # or column_name in self.metadata.label_columns
                 logger.debug(f"Fitting discrete column {column_name}...")
-                column_transform_info = self._fit_discrete(data_loader[[column_name]],
-                                                           self.metadata.categorical_encoder[
-                                                               column_name] if column_name in self.metadata.categorical_encoder else 'onehot')
+                column_transform_info = self._fit_discrete(
+                    data_loader[[column_name]],
+                    (
+                        self.metadata.categorical_encoder[column_name]
+                        if column_name in self.metadata.categorical_encoder
+                        else "onehot"
+                    ),
+                )
             else:
                 logger.debug(f"Fitting continuous column {column_name}...")
                 column_transform_info = self._fit_continuous(data_loader[[column_name]])
@@ -198,7 +206,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLo
         loader = NDArrayLoader(save_to_file=False)
         for ndarray in tqdm.tqdm(
-                p(processes), desc="Transforming data", total=len(processes), delay=3
+            p(processes), desc="Transforming data", total=len(processes), delay=3
         ):
             loader.store(ndarray.astype(float))
         return loader
@@ -244,10 +252,10 @@ def inverse_transform(self, data, sigmas=None):
         column_names = []
 
         for column_transform_info in tqdm.tqdm(
-                self._column_transform_info_list, desc="Inverse transforming", delay=3
+            self._column_transform_info_list, desc="Inverse transforming", delay=3
         ):
             dim = column_transform_info.output_dimensions
-            column_data = data[:, st: st + dim]
+            column_data = data[:, st : st + dim]
             if column_transform_info.column_type == "continuous":
                 recovered_column_data = self._inverse_transform_continuous(
                     column_transform_info, column_data, sigmas, st
diff --git a/sdgx/models/components/sdv_ctgan/synthesizers/base.py b/sdgx/models/components/sdv_ctgan/synthesizers/base.py
index 8ebf4f32..4ddfafae 100644
--- a/sdgx/models/components/sdv_ctgan/synthesizers/base.py
+++ b/sdgx/models/components/sdv_ctgan/synthesizers/base.py
@@ -76,9 +76,9 @@ def __getstate__(self):
         random_states = self.random_states
 
         if (
-                isinstance(random_states, tuple)
-                and isinstance(random_states[0], np.random.RandomState)
-                and isinstance(random_states[1], torch.Generator)
+            isinstance(random_states, tuple)
+            and isinstance(random_states[0], np.random.RandomState)
+            and isinstance(random_states[1], torch.Generator)
         ):
             state["_numpy_random_state"] = random_states[0].get_state()
             state["_torch_random_state"] = random_states[1].get_state()
@@ -112,7 +112,9 @@ def save(self, path):
         self.set_device(device_backup)
 
     @classmethod
-    def load(cls, path: Union[str, Path], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+    def load(
+        cls, path: Union[str, Path], device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    ):
         """Load the model stored in the passed arg `path`."""
         with open(path, "rb") as f:
             model = cloudpickle.load(f)
@@ -135,9 +137,9 @@ def set_random_state(self, random_state):
                 torch.Generator().manual_seed(random_state),
             )
         elif (
-                isinstance(random_state, tuple)
-                and isinstance(random_state[0], np.random.RandomState)
-                and isinstance(random_state[1], torch.Generator)
+            isinstance(random_state, tuple)
+            and isinstance(random_state[0], np.random.RandomState)
+            and isinstance(random_state[1], torch.Generator)
         ):
             self.random_states = random_state
         else:
diff --git a/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py b/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py
index 3629942f..6bb1d8b3 100644
--- a/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py
+++ b/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py
@@ -22,7 +22,8 @@
 from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
 from sdgx.models.components.sdv_ctgan.synthesizers.base import (
     BaseSynthesizer,
-    random_state, BatchedSynthesizer,
+    BatchedSynthesizer,
+    random_state,
 )
 
@@ -262,7 +263,7 @@ def _cond_loss(self, data, c, m):
         st_c = 0
         for column_info in self._transformer.output_info_list:
             for span_info in column_info:
-                if len(column_info) != 1 or span_info.activation_fn != "softmax": # todo to be revised
+                if len(column_info) != 1 or span_info.activation_fn != "softmax":  # todo to be revised
                     # not discrete column
                     st += span_info.dim
                 else:
diff --git a/sdgx/models/components/sdv_ctgan/synthesizers/tvae.py b/sdgx/models/components/sdv_ctgan/synthesizers/tvae.py
index 2f129514..997a0201 100644
--- a/sdgx/models/components/sdv_ctgan/synthesizers/tvae.py
+++ b/sdgx/models/components/sdv_ctgan/synthesizers/tvae.py
@@ -9,8 +9,8 @@
 from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
 from sdgx.models.components.sdv_ctgan.synthesizers.base import (
+    BatchedSynthesizer,
     random_state,
-    BatchedSynthesizer
 )
 
@@ -85,7 +85,7 @@ def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
                 ed = st + span_info.dim
                 std = sigmas[st]
                 eq = x[:, st] - torch.tanh(recon_x[:, st])
-                loss.append((eq ** 2 / 2 / (std ** 2)).sum())
+                loss.append((eq**2 / 2 / (std**2)).sum())
                 loss.append(torch.log(std) * x.size()[0])
                 st = ed
 
@@ -99,7 +99,7 @@ def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
                 st = ed
 
     assert st == recon_x.size()[1]
-    KLD = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
+    KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
     return sum(loss) * factor / x.size()[0], KLD / x.size()[0]
 
@@ -107,15 +107,15 @@ class TVAE(BatchedSynthesizer):
     """TVAE."""
 
     def __init__(
-            self,
-            embedding_dim=128,
-            compress_dims=(128, 128),
-            decompress_dims=(128, 128),
-            l2scale=1e-5,
-            batch_size=500,
-            epochs=300,
-            loss_factor=2,
-            cuda=True,
+        self,
+        embedding_dim=128,
+        compress_dims=(128, 128),
+        decompress_dims=(128, 128),
+        l2scale=1e-5,
+        batch_size=500,
+        epochs=300,
+        loss_factor=2,
+        cuda=True,
     ):
         super().__init__(batch_size)
         self.embedding_dim = embedding_dim
diff --git a/sdgx/models/components/sdv_rdt/transformers/categorical.py b/sdgx/models/components/sdv_rdt/transformers/categorical.py
index ab0ca54e..2d8e1c6c 100644
--- a/sdgx/models/components/sdv_rdt/transformers/categorical.py
+++ b/sdgx/models/components/sdv_rdt/transformers/categorical.py
@@ -1,4 +1,5 @@
 """Transformers for categorical data."""
+
 import math
 import warnings
 
diff --git a/sdgx/models/ml/single_table/ctgan.py b/sdgx/models/ml/single_table/ctgan.py
index 168189b9..88068c3b 100644
--- a/sdgx/models/ml/single_table/ctgan.py
+++ b/sdgx/models/ml/single_table/ctgan.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from tqdm import autonotebook as tqdm
+
 import time
 from pathlib import Path
 from typing import List
@@ -19,6 +19,7 @@
     Sequential,
     functional,
 )
+from tqdm import autonotebook as tqdm
 
 from sdgx.data_loader import DataLoader
 from sdgx.data_models.metadata import Metadata
@@ -221,7 +222,9 @@ def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **
         self._fit(len(self._ndarry_loader))
         logger.info("CTGAN training finished.")
 
-    def _pre_fit(self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata: Metadata = None) -> NDArrayLoader:
+    def _pre_fit(
+        self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata: Metadata = None
+    ) -> NDArrayLoader:
         if not discrete_columns:
             discrete_columns = []
diff --git a/sdgx/synthesizer.py b/sdgx/synthesizer.py
index ff8006d0..62b1a009 100644
--- a/sdgx/synthesizer.py
+++ b/sdgx/synthesizer.py
@@ -3,8 +3,9 @@
 import time
 from pathlib import Path
 from typing import Any, Generator
-from tqdm import autonotebook as tqdm
+
 import pandas as pd
+from tqdm import autonotebook as tqdm
 
 from sdgx.data_connectors.base import DataConnector
 from sdgx.data_connectors.generator_connector import GeneratorConnector
@@ -190,7 +191,7 @@ def load(
         processed_data_loaders_kwargs: None | dict[str, Any] = None,
         data_processors: None | list[str | DataProcessor | type[DataProcessor]] = None,
         data_processors_kwargs: None | dict[str, dict[str, Any]] = None,
-        model_kwargs = None
+        model_kwargs=None,
     ) -> "Synthesizer":
         """
         Load metadata and model, allow rebuilding Synthesizer for finetuning or other use cases.
@@ -321,7 +322,7 @@ def chunk_generator() -> Generator[pd.DataFrame, None, None]:
             logger.info(f"Initialized processed data loader in {time.time() - start_time}s")
         try:
             logger.info("Model fit Started...")
-            self.model.fit(metadata, processed_dataloader,**(model_fit_kwargs or {}))
+            self.model.fit(metadata, processed_dataloader, **(model_fit_kwargs or {}))
             logger.info("Model fit... Finished")
         finally:
             processed_dataloader.finalize(clear_cache=True)
@@ -402,7 +403,9 @@ def _sample_once(
         batch_size = self.model.get_batch_size()
 
         while missing_count > 0 and max_trails > 0:
-            sample_data = self.model.sample(max(int(missing_count * 1.2), batch_size), **model_sample_args)
+            sample_data = self.model.sample(
+                max(int(missing_count * 1.2), batch_size), **model_sample_args
+            )
             # TODO table separated parallel process
             for d in self.data_processors:
                 sample_data = d.reverse_convert(sample_data)
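
Note: the re-wrapped hunks in data_transformer.py above are formatting-only; the encoder-selection behavior is unchanged. Each discrete column's encoder is looked up in Metadata.categorical_encoder (defaulting to one-hot), and a label encoder is used only when Metadata.check_categorical_threshold reports that the column's category count exceeds the configured threshold. A minimal standalone sketch of that rule, for readers skimming the diff; `pick_encoder` and its parameters are illustrative names, not part of the sdgx API:

    from typing import Dict, Literal, Optional

    EncoderType = Literal["onehot", "label"]

    def pick_encoder(
        column: str,
        num_categories: int,
        categorical_encoder: Dict[str, EncoderType],
        categorical_threshold: Optional[int] = None,
    ) -> EncoderType:
        # Per-column override, defaulting to one-hot; mirrors the
        # `categorical_encoder[column_name] if ... else "onehot"` expression
        # passed to _fit_discrete in DataTransformer.fit.
        requested = categorical_encoder.get(column, "onehot")
        # Mirrors Metadata.check_categorical_threshold: with no threshold
        # configured, label encoding is never enabled.
        over_threshold = (
            num_categories > categorical_threshold
            if categorical_threshold is not None
            else False
        )
        return "label" if requested == "label" and over_threshold else "onehot"

    print(pick_encoder("city", 500, {"city": "label"}, categorical_threshold=100))  # label
    print(pick_encoder("sex", 2, {"sex": "label"}, categorical_threshold=100))      # onehot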