diff --git a/example/sdgx_example_metadata.ipynb b/example/sdgx_example_metadata.ipynb
index 3143c96d..86b35f44 100644
--- a/example/sdgx_example_metadata.ipynb
+++ b/example/sdgx_example_metadata.ipynb
@@ -980,7 +980,7 @@
     " \"industry\": \"frequency\",\n",
     " \"work_year\": \"label\"\n",
     " # \"work_type\" using default encoder, we are not specified its encoder.\n",
-    " # \"issue_date\" and \"earlies_credit_mon\" are datetime columns. We are not specified its encoder.\n",
+    " # \"issue_date\" and \"earlies_credit_mon\" are datetime columns. We don't need to specify encoders for them when DatetimeFormatter is used in training, because they are transformed to floats.\n",
     "}"
    ],
    "outputs": [],
diff --git a/sdgx/data_models/metadata.py b/sdgx/data_models/metadata.py
index cad4caae..9efc824c 100644
--- a/sdgx/data_models/metadata.py
+++ b/sdgx/data_models/metadata.py
@@ -90,7 +90,7 @@ def check_column_list(cls, value) -> Any:
         """

     def get_column_encoder_by_categorical_threshold(
-        self, num_categories: int
+        self, num_categories: int
     ) -> Union[CategoricalEncoderType, None]:
         encoder_type = None
         if self.categorical_threshold is None:
@@ -135,16 +135,16 @@ def __eq__(self, other):
         if not isinstance(other, Metadata):
             return super().__eq__(other)
         return (
-            set(self.tag_fields) == set(other.tag_fields)
-            and all(
-                self.get(key) == other.get(key)
-                for key in set(chain(self.tag_fields, other.tag_fields))
-            )
-            and all(
-                self.get(key) == other.get(key)
-                for key in set(chain(self.format_fields, other.format_fields))
-            )
-            and self.version == other.version
+            set(self.tag_fields) == set(other.tag_fields)
+            and all(
+                self.get(key) == other.get(key)
+                for key in set(chain(self.tag_fields, other.tag_fields))
+            )
+            and all(
+                self.get(key) == other.get(key)
+                for key in set(chain(self.format_fields, other.format_fields))
+            )
+            and self.version == other.version
         )

     def query(self, field: str) -> Iterable[str]:
@@ -204,9 +204,9 @@ def set(self, key: str, value: Any):
         old_value = self.get(key)

         if (
-            key in self.model_fields
-            and key not in self.tag_fields
-            and key not in self.format_fields
+            key in self.model_fields
+            and key not in self.tag_fields
+            and key not in self.format_fields
         ):
             raise MetadataInitError(
                 f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
             )

         if isinstance(old_value, Iterable) and not isinstance(old_value, str):
             # list, set, tuple...
            value = value if isinstance(value, Iterable) and not isinstance(value, str) else [value]
-            value = type(old_value)(value)
+            try:
+                value = type(old_value)(value)
+            except TypeError:
+                if type(old_value) is defaultdict:
+                    value = dict(value)
+                else:
+                    raise

         if key in self.model_fields:
             setattr(self, key, value)
@@ -296,14 +302,14 @@ def update(self, attributes: dict[str, Any]):

     @classmethod
     def from_dataloader(
-        cls,
-        dataloader: DataLoader,
-        max_chunk: int = 10,
-        primary_keys: Set[str] = None,
-        include_inspectors: Iterable[str] | None = None,
-        exclude_inspectors: Iterable[str] | None = None,
-        inspector_init_kwargs: dict[str, Any] | None = None,
-        check: bool = False,
+        cls,
+        dataloader: DataLoader,
+        max_chunk: int = 10,
+        primary_keys: Set[str] = None,
+        include_inspectors: Iterable[str] | None = None,
+        exclude_inspectors: Iterable[str] | None = None,
+        inspector_init_kwargs: dict[str, Any] | None = None,
+        check: bool = False,
     ) -> "Metadata":
         """Initialize a metadata from DataLoader and Inspectors
@@ -363,12 +369,12 @@

     @classmethod
     def from_dataframe(
-        cls,
-        df: pd.DataFrame,
-        include_inspectors: list[str] | None = None,
-        exclude_inspectors: list[str] | None = None,
-        inspector_init_kwargs: dict[str, Any] | None = None,
-        check: bool = False,
+        cls,
+        df: pd.DataFrame,
+        include_inspectors: list[str] | None = None,
+        exclude_inspectors: list[str] | None = None,
+        inspector_init_kwargs: dict[str, Any] | None = None,
+        check: bool = False,
     ) -> "Metadata":
         """Initialize a metadata from DataFrame and Inspectors
@@ -552,10 +558,10 @@ def get_column_data_type(self, column_name: str):
         # find the dtype who has most high inspector level
         for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
             if (
-                each_key != "pii_columns"
-                and each_key.endswith("_columns")
-                and column_name in self.get(each_key)
-                and current_level < self.column_inspect_level[each_key]
+                each_key != "pii_columns"
+                and each_key.endswith("_columns")
+                and column_name in self.get(each_key)
+                and current_level < self.column_inspect_level[each_key]
             ):
                 current_level = self.column_inspect_level[each_key]
                 current_type = each_key
@@ -575,3 +581,77 @@ def get_column_pii(self, column_name: str):
         if column_name in self.pii_columns:
             return True
         return False
+
+    def change_column_type(
+        self, column_names: str | List[str], column_original_type: str, column_new_type: str
+    ):
+        """Change the type of the given columns, moving them from one tag field to another."""
+        if not column_names:
+            return
+        if isinstance(column_names, str):
+            column_names = [column_names]
+        all_fields = list(self.tag_fields)
+        original_type = f"{column_original_type}_columns"
+        new_type = f"{column_new_type}_columns"
+        if original_type not in all_fields:
+            raise MetadataInvalidError(f"Column type {column_original_type} does not exist in metadata.")
+        if new_type not in all_fields:
+            raise MetadataInvalidError(f"Column type {column_new_type} does not exist in metadata.")
+        type_columns = self.get(original_type)
+        diff = set(column_names).difference(type_columns)
+        if diff:
+            raise MetadataInvalidError(f"Columns {diff} do not exist in {original_type}.")
+        self.add(new_type, column_names)
+        type_columns = type_columns.difference(column_names)
+        self.set(original_type, type_columns)
+
+    def remove_column(self, column_names: List[str] | str):
+        """
+        Remove columns from all column type tags.
+
+        Args:
+            column_names (List[str] | str): Names of the columns to remove.
+ """ + if not column_names: + return + if isinstance(column_names, str): + column_names = [column_names] + column_names = frozenset(column_names) + inter = column_names.intersection(self.column_list) + if not inter: + raise MetadataInvalidError(f"Columns {inter} not exist in metadata.") + + def do_remove_columns(key, get=True, to_removes=column_names): + obj = self + if get: + target = obj.get(key) + else: + target = getattr(obj, key) + res = None + if isinstance(target, list): + res = [item for item in target if item not in to_removes] + elif isinstance(target, dict): + if key == "numeric_format": + obj.set(key, { + k: { + v2 for v2 in v if v2 not in to_removes + } for k, v in target.items() + }) + else: + res = { + k: v for k, v in target.items() + if k not in to_removes + } + elif isinstance(target, set): + res = target.difference(to_removes) + + if res is not None: + if get: + obj.set(key, res) + else: + setattr(obj, key, res) + + to_remove_attribute = list(self.tag_fields) + to_remove_attribute.extend(list(self.format_fields)) + for attr in to_remove_attribute: + do_remove_columns(attr) + for attr in ["column_list", "primary_keys"]: + do_remove_columns(attr, False) + self.check() diff --git a/sdgx/data_processors/formatters/datetime.py b/sdgx/data_processors/formatters/datetime.py index 8ccd8241..76735d41 100644 --- a/sdgx/data_processors/formatters/datetime.py +++ b/sdgx/data_processors/formatters/datetime.py @@ -78,6 +78,12 @@ def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]): f"Column {each_col} has no datetime_format, DatetimeFormatter will REMOVE this column!" ) + # Remove successful formatted datetime columns from metadata.discrete_columns + if not (set(datetime_columns) - set(metadata.discrete_columns)): + metadata.change_column_type(datetime_columns, "discrete", "datetime") + # Remove dead_columns from metadata + metadata.remove_column(dead_columns) + self.datetime_columns = datetime_columns self.dead_columns = dead_columns @@ -189,25 +195,17 @@ def convert_timestamp_to_datetime(timestamp_column_list, format_dict, processed_ Returns: - result_data (pd.DataFrame): DataFrame with timestamp columns converted to datetime format. - """ - - def convert_single_column_timestamp_to_str(column_data: pd.Series, datetime_format: str): - """ - convert each single column timestamp(int) to datetime string. - """ - res = [] - for each_stamp in column_data: - try: - each_str = datetime.fromtimestamp(each_stamp).strftime(datetime_format) - except Exception as e: - logger.debug(f"An error occured when convert timestamp to str {e}.") - each_str = "No Datetime" - - res.append(each_str) - res = pd.Series(res) - res = res.astype(str) - return res + TODO: + if the value <0, the result will be `No Datetime`, try to fix it. 
+ """ + def column_timestamp_formatter(each_stamp: int, timestamp_format: str) -> str: + try: + each_str = datetime.fromtimestamp(each_stamp).strftime(timestamp_format) + except Exception as e: + logger.debug(f"An error occured when convert timestamp to str {e}.") + each_str = "No Datetime" + return each_str # Copy the processed data to result_data result_data = processed_data.copy() @@ -217,8 +215,9 @@ def convert_single_column_timestamp_to_str(column_data: pd.Series, datetime_form # Check if the column is in the DataFrame if column in result_data.columns: # Convert the timestamp to datetime format using the format provided in datetime_column_dict - result_data[column] = convert_single_column_timestamp_to_str( - result_data[column], format_dict[column] + result_data[column] = result_data[column].apply( + column_timestamp_formatter, + timestamp_format=format_dict[column] ) else: logger.error(f"Column {column} not in processed data's column list!") diff --git a/sdgx/models/components/optimize/ndarray_loader.py b/sdgx/models/components/optimize/ndarray_loader.py index 7a537e8f..7d34be89 100644 --- a/sdgx/models/components/optimize/ndarray_loader.py +++ b/sdgx/models/components/optimize/ndarray_loader.py @@ -98,8 +98,12 @@ def get_all(self) -> ndarray: return np.concatenate([array for array in self.iter()], axis=1) @cached_property + def __shape_0(self): + return self.load(0).shape[0] + + @property def shape(self) -> tuple[int, int]: - return self.load(0).shape[0], self.store_index + return self.__shape_0, self.store_index def __len__(self): return self.shape[0] diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py index 86377b5b..2263e8df 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py @@ -86,6 +86,7 @@ def is_onehot_encoding_column(column_info: SpanInfo): st = ed else: st += sum([span_info.dim for span_info in column_info]) + assert st == data.shape[1] def _random_choice_prob_index(self, discrete_column_id): probs = self._discrete_column_category_prob[discrete_column_id] diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py index 5ccd05f7..6bf917ef 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import namedtuple +from typing import List import numpy as np import pandas as pd @@ -102,8 +103,8 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None): encoder.fit(data, column_name) num_categories = len(encoder.dummies) activate_fn = "softmax" - # if selected_encoder_type is not specified or using onehot num_categories > threshold, change the encoder. - if not selected_encoder_type or self.metadata and num_categories != -1: + # if selected_encoder_type is not specified and using onehot num_categories > threshold, change the encoder. + if not selected_encoder_type and self.metadata and num_categories != -1: encoder_type = ( self.metadata.get_column_encoder_by_categorical_threshold(num_categories) or encoder_type @@ -145,12 +146,12 @@ def fit(self, data_loader: DataLoader, discrete_columns=()): This step also counts the #columns in matrix data and span information. 
""" - self.output_info_list = [] - self.output_dimensions = 0 - self.dataframe = True + self.output_info_list: List[SpanInfo] = [] + self.output_dimensions: int = 0 + self.dataframe: bool = True self._column_raw_dtypes = data_loader[: data_loader.chunksize].infer_objects().dtypes - self._column_transform_info_list = [] + self._column_transform_info_list: List[ColumnTransformInfo] = [] for column_name in tqdm.tqdm(data_loader.columns(), desc="Preparing data", delay=3): if column_name in discrete_columns: # or column_name in self.metadata.label_columns @@ -183,8 +184,8 @@ def _transform_continuous(self, column_transform_info, data): def _transform_discrete(self, column_transform_info, data): logger.debug(f"Transforming discrete column {column_transform_info.column_name}...") - ohe = column_transform_info.transform - return ohe.transform(data).to_numpy() + encoder = column_transform_info.transform + return encoder.transform(data).to_numpy() def _synchronous_transform(self, raw_data, column_transform_info_list) -> NDArrayLoader: """Take a Pandas DataFrame and transform columns synchronous. diff --git a/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py b/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py index 0cc52fd9..542bb5ec 100644 --- a/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py +++ b/sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py @@ -263,8 +263,7 @@ def _cond_loss(self, data, c, m): for column_info in self._transformer.output_info_list: for span_info in column_info: if len(column_info) != 1 or span_info.activation_fn != "softmax": - # TODO may need to revise for label encoder? - # not discrete column + # not onehot column st += span_info.dim else: ed = st + span_info.dim diff --git a/sdgx/models/components/sdv_rdt/transformers/categorical.py b/sdgx/models/components/sdv_rdt/transformers/categorical.py index e8045799..c132b90a 100644 --- a/sdgx/models/components/sdv_rdt/transformers/categorical.py +++ b/sdgx/models/components/sdv_rdt/transformers/categorical.py @@ -231,7 +231,7 @@ def _reverse_transform_by_row(self, data): """Reverse transform the data by iterating over each row.""" return data.apply(self._get_category_from_start).astype(self.dtype) - def _reverse_transform(self, data): + def _reverse_transform(self, data, normalize=False): """Convert float values back to the original categorical values. 
         Args:
@@ -241,7 +241,7 @@ def _reverse_transform(self, data):
         Returns:
             pandas.Series
         """
-        data = data.clip(0, 1)
+        data = data.clip(-1 if normalize else 0, 1)
         num_rows = len(data)
         num_categories = len(self.means)
@@ -675,3 +675,6 @@ def _fit(self, data):
         """
         self.dtype = data.dtype
         self.intervals, self.means, self.starts = self._get_intervals(data, normalized=True)
+
+    def _reverse_transform(self, data):
+        return super()._reverse_transform(data, normalize=True)
diff --git a/sdgx/models/ml/single_table/ctgan.py b/sdgx/models/ml/single_table/ctgan.py
index 6e92814f..76ea1b4e 100644
--- a/sdgx/models/ml/single_table/ctgan.py
+++ b/sdgx/models/ml/single_table/ctgan.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import math
 import time
 from pathlib import Path
 from typing import List
@@ -27,7 +28,7 @@
 from sdgx.models.components.optimize.sdv_ctgan.data_sampler import DataSampler
 from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer
 from sdgx.models.components.sdv_ctgan.synthesizers.base import (
-    BaseSynthesizer as SDVBaseSynthesizer,
+    BaseSynthesizer as SDVBaseSynthesizer, BatchedSynthesizer,
 )
 from sdgx.models.components.sdv_ctgan.synthesizers.base import random_state
 from sdgx.models.extension import hookimpl
@@ -117,7 +118,7 @@ def forward(self, input_):
         return data


-class CTGANSynthesizerModel(MLSynthesizerModel, SDVBaseSynthesizer):
+class CTGANSynthesizerModel(MLSynthesizerModel, BatchedSynthesizer):
     """
     Modified from ``sdgx.models.components.sdv_ctgan.synthesizers.ctgan.CTGANSynthesizer``.
     A CTGANSynthesizer but provided :ref:`SynthesizerModel` interface with chunked fit.
@@ -182,7 +183,7 @@ def __init__(
         device="cuda" if torch.cuda.is_available() else "cpu",
     ):
         assert batch_size % 2 == 0
-
+        BatchedSynthesizer.__init__(self, batch_size=batch_size)
         self._embedding_dim = embedding_dim
         self._generator_dim = generator_dim
         self._discriminator_dim = discriminator_dim
@@ -192,7 +193,6 @@
         self._discriminator_lr = discriminator_lr
         self._discriminator_decay = discriminator_decay

-        self._batch_size = batch_size
         self._discriminator_steps = discriminator_steps
         self._log_frequency = log_frequency
         self._epochs = epochs
@@ -201,11 +201,11 @@
         self._device = torch.device(device)

         # Following components are initialized in `_pre_fit`
-        self._transformer = None
-        self._data_sampler = None
+        self._transformer: DataTransformer = None
+        self._data_sampler: DataSampler = None
         self._generator = None
-        self._ndarry_loader = None
-        self.data_dim = None
+        self._ndarry_loader: NDArrayLoader = None
+        self.data_dim: int = None

     def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **kwargs):
         # In the future, sdgx use `sdgx.data_processor.transformers.discrete` to handle discrete_columns
@@ -224,7 +224,7 @@
     def _pre_fit(
         self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata: Metadata = None
-    ) -> NDArrayLoader:
+    ):
         if not discrete_columns:
             discrete_columns = []
@@ -408,7 +408,7 @@ def _sample(self, n, condition_column=None, condition_value=None, drop_more=True
         else:
             global_condition_vec = None

-        steps = n // self._batch_size + 1
+        steps = math.ceil(n / self._batch_size)
         data = []
         for _ in tqdm.tqdm(range(steps), desc="Sampling batches", delay=3):
             mean = torch.zeros(self._batch_size, self._embedding_dim)
diff --git a/sdgx/synthesizer.py b/sdgx/synthesizer.py
index 5776c3c9..55718263 100644
--- a/sdgx/synthesizer.py
+++ b/sdgx/synthesizer.py
@@ -18,6 +18,7 @@
 from sdgx.models.base import SynthesizerModel
 from sdgx.models.components.sdv_ctgan.synthesizers.base import BatchedSynthesizer
 from sdgx.models.manager import ModelManager
+from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
 from sdgx.models.statistics.single_table.base import StatisticSynthesizerModel
 from sdgx.utils import logger
@@ -291,7 +292,8 @@ def fit(
                     inspector_init_kwargs=inspector_init_kwargs,
                 )
             )
-        self.metadata = metadata  # Ensure update metadata
+        # Some processors may update the metadata before model fitting, so keep a copy.
+        self.metadata = metadata.model_copy()  # Ensure update metadata

         logger.info("Fitting data processors...")
         if not self.dataloader:
@@ -403,6 +405,10 @@ def _sample_once(
         if isinstance(self.model, BatchedSynthesizer):
             batch_size = self.model.get_batch_size()
             multiply_factor = 1.2
+        if isinstance(self.model, CTGANSynthesizerModel):
+            model_sample_args = {
+                "drop_more": False
+            }

         while missing_count > 0 and max_trails > 0:
             sample_data = self.model.sample(
diff --git a/tests/data_models/test_metadata.py b/tests/data_models/test_metadata.py
index f3a82adb..009f39c3 100644
--- a/tests/data_models/test_metadata.py
+++ b/tests/data_models/test_metadata.py
@@ -30,10 +30,28 @@ def test_metadata(metadata: Metadata):
     assert metadata.float_columns == metadata.get("float_columns")

     metadata.set("a", "something")
-    assert metadata.get("a") == set(["something"])
+    assert metadata.get("a") == {"something"}

     assert metadata._dump_json()


+def test_change_metadata(metadata: Metadata):
+    metadata = metadata.model_copy()
+    col = "age"
+    assert col in metadata.int_columns
+    assert col not in metadata.datetime_columns
+    metadata.change_column_type(col, "int", "datetime")
+    assert col in metadata.datetime_columns
+    assert col not in metadata.int_columns
+    metadata.change_column_type(col, "datetime", "int")
+    assert col in metadata.int_columns
+    assert col not in metadata.datetime_columns
+
+
+def test_remove_metadata(metadata: Metadata):
+    metadata = metadata.model_copy()
+    col = "age"
+    assert col in metadata.int_columns
+    metadata.remove_column([col])
+    assert col not in metadata.int_columns
+
+
 def test_metadata_save_load(metadata: Metadata, tmp_path: Path):
     test_path = tmp_path / "metadata_path_test.json"
@@ -48,7 +66,7 @@ def test_metadata_primary_key(metadata: Metadata):
     metadata.add("id_columns", "fnlwgt")
     # set fnlwgt as primary key
     metadata.update_primary_key(["fnlwgt"])
-    assert metadata.primary_keys == set(["fnlwgt"])
+    assert metadata.primary_keys == {"fnlwgt"}


 def test_metadata_primary_query_filed_tags():
diff --git a/tests/models/test_nfrequencyencoder.py b/tests/models/test_nfrequencyencoder.py
new file mode 100644
index 00000000..41a3255e
--- /dev/null
+++ b/tests/models/test_nfrequencyencoder.py
@@ -0,0 +1,37 @@
+import pandas as pd
+import pytest
+
+from sdgx.models.components.sdv_rdt.transformers.categorical import (
+    NormalizedFrequencyEncoder,
+)
+
+
+@pytest.fixture(scope="module")
+def data_test():
+    return pd.DataFrame(
+        {
+            "x": [str(i) for i in range(100)],
+            "y": [str(-i) for i in range(50)] * 2,
+            "z": [str(i) for i in range(25)] * 4,
+        },
+        columns=["x", "y", "z"],
+    )
+
+
+def test_encoder(data_test: pd.DataFrame):
+    for col in ["x", "y", "z"]:
+        nfreq_encoder = NormalizedFrequencyEncoder()
+        nfreq_encoder.fit(data_test, col)
+        td = nfreq_encoder.transform(data_test.copy())
+        rd = nfreq_encoder.reverse_transform(td.copy())
+        td.rename(columns={f"{col}.value": f"{col}"}, inplace=True)
+        assert (rd[col].sort_values().values == data_test[col].sort_values().values).all()
+        assert (td[col] >= -1).all()
+        assert (td[col] <= 1).all()
+        assert td[col].shape == data_test[col].shape
+        assert len(td[col].unique()) == len(data_test[col].unique())
+
+
+if __name__ == "__main__":
+    pytest.main(["-vv", "-s", __file__])
diff --git a/tests/models/test_nlabelencoder.py b/tests/models/test_nlabelencoder.py
index ca01ca1f..0ec0c5ae 100644
--- a/tests/models/test_nlabelencoder.py
+++ b/tests/models/test_nlabelencoder.py
@@ -28,9 +28,9 @@ def test_encoder(data_test: pd.DataFrame):
         td = nlabel_encoder.transform(data_test.copy())
         rd = nlabel_encoder.reverse_transform(td.copy())
         td.rename(columns={f"{col}.value": f"{col}"}, inplace=True)
-        assert (rd[col].sort_values().values == data_test[col].sort_values().values).any()
-        assert (td[col] >= 0).any()
-        assert (td[col] <= 1).any()
+        assert (rd[col].sort_values().values == data_test[col].sort_values().values).all()
+        assert (td[col] >= -1).all()
+        assert (td[col] <= 1).all()
         assert td[col].shape == data_test[col].shape
         assert len(td[col].unique()) == len(data_test[col].unique())
diff --git a/tests/test_ctgan_synthesizer.py b/tests/test_ctgan_synthesizer.py
index 798ab2da..adeb8cdc 100644
--- a/tests/test_ctgan_synthesizer.py
+++ b/tests/test_ctgan_synthesizer.py
@@ -1,96 +1,133 @@
+from typing import List
+
+import faker
 import numpy as np
 import pandas as pd
 import pytest

-from sdgx.data_connectors.csv_connector import CsvConnector
-from sdgx.data_loader import DataLoader
+from sdgx.data_connectors.dataframe_connector import DataFrameConnector
 from sdgx.data_models.metadata import Metadata
+from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer, SpanInfo
+from sdgx.models.components.sdv_rdt.transformers.categorical import (
+    NormalizedFrequencyEncoder,
+    NormalizedLabelEncoder,
+    OneHotEncoder,
+)
+from sdgx.models.components.sdv_rdt.transformers.numerical import ClusterBasedNormalizer
 from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
 from sdgx.synthesizer import Synthesizer


 @pytest.fixture
 def demo_single_table_data_pos_neg():
-    row_cnt = 1000
-    header = ["int_id", "pos_int", "neg_int", "pos_float", "neg_float", "mixed_int", "mixed_float"]
-
+    row_cnt = 1000  # must be a multiple of 200 because of the encoder settings below
     np.random.seed(42)
-    int_id = list(range(row_cnt))
-    pos_int = np.random.randint(1, 100, size=row_cnt)
-    neg_int = np.random.randint(-100, 0, size=row_cnt)
-    pos_float = np.random.uniform(0, 100, size=row_cnt)
-    neg_float = np.random.uniform(-100, 0, size=row_cnt)
-    mixed_int = np.random.randint(-50, 50, size=row_cnt)
-    mixed_float = np.random.uniform(-50, 50, size=row_cnt)
-
-    X = [
-        [
-            int_id[i],
-            pos_int[i],
-            neg_int[i],
-            pos_float[i],
-            neg_float[i],
-            mixed_int[i],
-            mixed_float[i],
-        ]
-        for i in range(row_cnt)
-    ]
-
-    yield pd.DataFrame(X, columns=header)
+    faker.Faker.seed(42)
+    fake = faker.Faker()
+    X = {
+        "int_id": list(range(row_cnt)),
+        "pos_int": np.random.randint(1, 100, size=row_cnt),
+        "neg_int": np.random.randint(-100, 0, size=row_cnt),
+        "pos_float": np.random.uniform(0, 100, size=row_cnt),
+        "neg_float": np.random.uniform(-100, 0, size=row_cnt),
+        "mixed_int": np.random.randint(-50, 50, size=row_cnt),
+        "mixed_float": np.random.uniform(-50, 50, size=row_cnt),
+        "cat_onehot": [str(i) for i in range(row_cnt)],
+        "cat_label": [str(i) for i in range(row_cnt)],
+        "cat_date": [fake.date() for _ in range(row_cnt)],
+        "cat_freq": [str(i) for i in range(row_cnt)],
+ "cat_thres_freq": [str(i) for i in range(100)] * (row_cnt // 100), + "cat_thres_label": [str(i) for i in range(200)] * (row_cnt // 200) + } + header = X.keys() + yield pd.DataFrame(X, columns=list(header)) @pytest.fixture def demo_single_table_data_pos_neg_metadata(demo_single_table_data_pos_neg): - yield Metadata.from_dataframe(demo_single_table_data_pos_neg.copy(), check=True) - - -@pytest.fixture -def demo_single_table_data_pos_neg_path(tmp_path, demo_single_table_data_pos_neg): - df = demo_single_table_data_pos_neg - save_path = tmp_path / "dummy_demo_single_table_data_pos_neg.csv" - df.to_csv(save_path, index=False, header=True) - yield save_path - save_path.unlink() + metadata = Metadata.from_dataframe(demo_single_table_data_pos_neg.copy(), check=True) + metadata.categorical_encoder = { + "cat_onehot": "onehot", + "cat_label": "label", + "cat_freq": "frequency" + } + metadata.datetime_format = { + "cat_date": "%Y-%m-%d" + } + metadata.categorical_threshold = { + 99: "frequency", + 199: "label" + } + yield metadata @pytest.fixture -def demo_single_table_data_pos_neg_connector(demo_single_table_data_pos_neg_path): - yield CsvConnector( - path=demo_single_table_data_pos_neg_path, +def demo_single_table_data_pos_neg_connector(demo_single_table_data_pos_neg): + yield DataFrameConnector( + df=demo_single_table_data_pos_neg ) @pytest.fixture -def demo_single_table_data_pos_neg_loader(demo_single_table_data_pos_neg_connector, cacher_kwargs): - d = DataLoader(demo_single_table_data_pos_neg_connector, cacher_kwargs=cacher_kwargs) - yield d - d.finalize() - - -@pytest.fixture -def ctgan_synthesizer(demo_single_table_data_pos_neg_connector): +def ctgan_synthesizer( + demo_single_table_data_pos_neg_connector, + demo_single_table_data_pos_neg_metadata +): yield Synthesizer( + metadata=demo_single_table_data_pos_neg_metadata, model=CTGANSynthesizerModel(epochs=1), data_connector=demo_single_table_data_pos_neg_connector, ) def test_ctgan_synthesizer_with_pos_neg( - ctgan_synthesizer: Synthesizer, - demo_single_table_data_pos_neg_metadata, - demo_single_table_data_pos_neg_loader, - demo_single_table_data_pos_neg, + ctgan_synthesizer: Synthesizer, + demo_single_table_data_pos_neg ): original_data = demo_single_table_data_pos_neg # Train the CTGAN model - ctgan_synthesizer.fit(demo_single_table_data_pos_neg_metadata) + ctgan_synthesizer.fit() + ctgan: CTGANSynthesizerModel = ctgan_synthesizer.model + transformer: DataTransformer = ctgan._transformer + transform_list = transformer._column_transform_info_list + transformed_data = ctgan._ndarry_loader.get_all() + + current_dim = 0 + for item in transform_list: + span_info: List[SpanInfo] = item.output_info + col_dim = item.output_dimensions + current_data = transformed_data[:, current_dim:current_dim + col_dim] + current_dim += col_dim + col = item.column_name + if col in ["cat_freq", "cat_thres_freq"]: + assert isinstance(item.transform, NormalizedFrequencyEncoder) + assert col_dim == 1 + assert len(span_info) == 1 + assert span_info[0].activation_fn == "liner" + assert len(item.transform.intervals) == original_data[col].nunique(dropna=False) + assert (current_data >= -1).all() and (current_data <= 1).all() + elif col in ["cat_thres_label", "cat_label"]: + assert isinstance(item.transform, NormalizedLabelEncoder) + assert col_dim == 1 + assert len(span_info) == 1 + assert span_info[0].activation_fn == "liner" + assert len(item.transform.categories_to_values.keys()) == original_data[col].nunique(dropna=False) + assert (current_data >= -1).all() and 
+        elif col in ["cat_onehot"]:
+            assert isinstance(item.transform, OneHotEncoder)
+            nunique = original_data[col].nunique(dropna=False)
+            assert col_dim == nunique
+            assert len(span_info) == 1
+            assert span_info[0].activation_fn == "softmax"
+            assert len(item.transform.dummies) == nunique
+            assert np.all((current_data == 0) | (current_data == 1))
+        else:
+            assert isinstance(item.transform, ClusterBasedNormalizer)
+
     sampled_data = ctgan_synthesizer.sample(1000)

     # Check each column
     for column in original_data.columns:
         # Skip columns that are identifiers or not intended for positivity checks
-        if column == "int_id":
+        if column in ["int_id"] or column.startswith("cat_"):
             continue

         is_all_positive = (original_data[column] >= 0).all()
@@ -99,12 +136,12 @@ def test_ctgan_synthesizer_with_pos_neg(
         if is_all_positive:
             # Assert that the sampled_data column is also all positive
             assert (
-                sampled_data[column] >= 0
+                sampled_data[column] >= 0
             ).all(), f"Column '{column}' in sampled data should be all positive."
         elif is_all_negative:
             # Assert that the sampled_data column is also all negative
             assert (
-                sampled_data[column] <= 0
+                sampled_data[column] <= 0
             ).all(), f"Column '{column}' in sampled data should be all negative."
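
A minimal usage sketch of the two new `Metadata` APIs introduced in this diff (`change_column_type` and `remove_column`), mirroring the added tests in `tests/data_models/test_metadata.py`. The toy DataFrame is illustrative only (the tests run against the demo adult-census metadata fixture) and assumes the inspectors tag `age` as an int column:

```python
import pandas as pd

from sdgx.data_models.metadata import Metadata

# Toy data (an assumption for illustration); the inspectors are expected
# to classify "age" as an int column, as they do for the demo data.
df = pd.DataFrame({"age": [21, 35, 44, 52], "salary": [3.5, 4.2, 5.0, 6.1]})
metadata = Metadata.from_dataframe(df)

# Move "age" from the int tag field to the datetime tag field and back.
# Raises MetadataInvalidError if either type is unknown to the metadata
# or "age" is not currently tagged with the source type.
metadata.change_column_type("age", "int", "datetime")
assert "age" in metadata.datetime_columns and "age" not in metadata.int_columns
metadata.change_column_type("age", "datetime", "int")

# Drop "age" from every tag/format field plus column_list and primary_keys;
# remove_column() re-runs metadata.check() afterwards.
metadata.remove_column(["age"])
assert "age" not in metadata.int_columns
```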