diff --git a/sdgx/data_models/metadata.py b/sdgx/data_models/metadata.py index 757a3263..075a907d 100644 --- a/sdgx/data_models/metadata.py +++ b/sdgx/data_models/metadata.py @@ -95,7 +95,7 @@ def check_column_list(cls, value) -> Any: """ def get_column_encoder_by_categorical_threshold( - self, num_categories: int + self, num_categories: int ) -> Union[CategoricalEncoderType, None]: encoder_type = None if self.categorical_threshold is None: @@ -140,16 +140,16 @@ def __eq__(self, other): if not isinstance(other, Metadata): return super().__eq__(other) return ( - set(self.tag_fields) == set(other.tag_fields) - and all( - self.get(key) == other.get(key) - for key in set(chain(self.tag_fields, other.tag_fields)) - ) - and all( - self.get(key) == other.get(key) - for key in set(chain(self.format_fields, other.format_fields)) - ) - and self.version == other.version + set(self.tag_fields) == set(other.tag_fields) + and all( + self.get(key) == other.get(key) + for key in set(chain(self.tag_fields, other.tag_fields)) + ) + and all( + self.get(key) == other.get(key) + for key in set(chain(self.format_fields, other.format_fields)) + ) + and self.version == other.version ) def query(self, field: str) -> Iterable[str]: @@ -209,9 +209,9 @@ def set(self, key: str, value: Any): old_value = self.get(key) if ( - key in self.model_fields - and key not in self.tag_fields - and key not in self.format_fields + key in self.model_fields + and key not in self.tag_fields + and key not in self.format_fields ): raise MetadataInitError( f"Set {key} not in tag_fields, try set it directly as m.{key} = value" @@ -307,14 +307,14 @@ def update(self, attributes: dict[str, Any]): @classmethod def from_dataloader( - cls, - dataloader: DataLoader, - max_chunk: int = 10, - primary_keys: Set[str] = None, - include_inspectors: Iterable[str] | None = None, - exclude_inspectors: Iterable[str] | None = None, - inspector_init_kwargs: dict[str, Any] | None = None, - check: bool = False, + cls, + dataloader: DataLoader, + max_chunk: int = 10, + primary_keys: Set[str] = None, + include_inspectors: Iterable[str] | None = None, + exclude_inspectors: Iterable[str] | None = None, + inspector_init_kwargs: dict[str, Any] | None = None, + check: bool = False, ) -> "Metadata": """Initialize a metadata from DataLoader and Inspectors @@ -374,12 +374,12 @@ def from_dataloader( @classmethod def from_dataframe( - cls, - df: pd.DataFrame, - include_inspectors: list[str] | None = None, - exclude_inspectors: list[str] | None = None, - inspector_init_kwargs: dict[str, Any] | None = None, - check: bool = False, + cls, + df: pd.DataFrame, + include_inspectors: list[str] | None = None, + exclude_inspectors: list[str] | None = None, + inspector_init_kwargs: dict[str, Any] | None = None, + check: bool = False, ) -> "Metadata": """Initialize a metadata from DataFrame and Inspectors @@ -518,19 +518,24 @@ def check(self): if self.categorical_encoder is not None: for i in self.categorical_encoder.keys(): if not isinstance(i, str) or i not in self.discrete_columns: - raise MetadataInvalidError(f"categorical_encoder key {i} is invalid, it should be an str and is a discrete column name.") + raise MetadataInvalidError( + f"categorical_encoder key {i} is invalid, it should be an str and is a discrete column name." + ) if self.categorical_encoder.values() not in CategoricalEncoderType: - raise MetadataInvalidError(f"In categorical_encoder values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}.") + raise MetadataInvalidError( + f"In categorical_encoder values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}." + ) if self.categorical_threshold is not None: for i in self.categorical_threshold.keys(): if not isinstance(i, int) or i < 0: - raise MetadataInvalidError(f"categorical threshold {i} is invalid, it should be an positive int.") + raise MetadataInvalidError( + f"categorical threshold {i} is invalid, it should be an positive int." + ) if self.categorical_threshold.values() not in CategoricalEncoderType: raise MetadataInvalidError( - f"In categorical_threshold values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}.") - - + f"In categorical_threshold values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}." + ) logger.debug("Metadata check succeed.") @@ -580,10 +585,10 @@ def get_column_data_type(self, column_name: str): # find the dtype who has most high inspector level for each_key in list(self.model_fields.keys()) + list(self._extend.keys()): if ( - each_key != "pii_columns" - and each_key.endswith("_columns") - and column_name in self.get(each_key) - and current_level < self.column_inspect_level[each_key] + each_key != "pii_columns" + and each_key.endswith("_columns") + and column_name in self.get(each_key) + and current_level < self.column_inspect_level[each_key] ): current_level = self.column_inspect_level[each_key] current_type = each_key @@ -605,7 +610,7 @@ def get_column_pii(self, column_name: str): return False def change_column_type( - self, column_names: str | List[str], column_original_type: str, column_new_type: str + self, column_names: str | List[str], column_original_type: str, column_new_type: str ): """Change the type of column.""" if not column_names: diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py index 0593ce89..ed09e764 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Tuple, NamedTuple, Dict, Any, Callable, Type +from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Type import numpy as np import pandas as pd @@ -26,27 +26,26 @@ ) from sdgx.utils import logger -CategoricalEncoderParams = NamedTuple("CategoricalEncoderParams", ( - ("encoder", Callable[[], CategoricalEncoderInstanceType]), - ("categories_caculator", Callable[[CategoricalEncoderInstanceType], int]), - ("activate_fn", ActivationFuncType) -)) +CategoricalEncoderParams = NamedTuple( + "CategoricalEncoderParams", + ( + ("encoder", Callable[[], CategoricalEncoderInstanceType]), + ("categories_caculator", Callable[[CategoricalEncoderInstanceType], int]), + ("activate_fn", ActivationFuncType), + ), +) CategoricalEncoderMapper: Dict[CategoricalEncoderType, CategoricalEncoderParams] = { CategoricalEncoderType.ONEHOT: CategoricalEncoderParams( - lambda: OneHotEncoder(), - lambda encoder: len(encoder.dummies), - ActivationFuncType.SOFTMAX + lambda: OneHotEncoder(), lambda encoder: len(encoder.dummies), ActivationFuncType.SOFTMAX ), CategoricalEncoderType.LABEL: CategoricalEncoderParams( lambda: NormalizedLabelEncoder(order_by="alphabetical"), lambda encoder: 1, - ActivationFuncType.LINEAR + ActivationFuncType.LINEAR, ), CategoricalEncoderType.FREQUENCY: CategoricalEncoderParams( - lambda: NormalizedFrequencyEncoder(), - lambda encoder: 1, - ActivationFuncType.LINEAR - ) + lambda: NormalizedFrequencyEncoder(), lambda encoder: 1, ActivationFuncType.LINEAR + ), } @@ -71,7 +70,7 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, metadata=None): self._weight_threshold = weight_threshold def _fit_categorical_encoder( - self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType + self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType ) -> Tuple[CategoricalEncoderInstanceType, int, ActivationFuncType]: if encoder_type not in CategoricalEncoderMapper.keys(): raise ValueError("Unsupported encoder type {0}.".format(encoder_type)) @@ -132,19 +131,23 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None): # if the encoder is onehot, or not be specified. num_categories = -1 # if zero may cause crash to onehot. if encoder_type == "onehot": - encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type) + encoder, num_categories, activate_fn = self._fit_categorical_encoder( + column_name, data, encoder_type + ) # if selected_encoder_type is not specified and using onehot num_categories > threshold, change the encoder. if not selected_encoder_type and self.metadata and num_categories != -1: encoder_type = ( - self.metadata.get_column_encoder_by_categorical_threshold(num_categories) - or encoder_type + self.metadata.get_column_encoder_by_categorical_threshold(num_categories) + or encoder_type ) if encoder_type == "onehot": pass else: - encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type) + encoder, num_categories, activate_fn = self._fit_categorical_encoder( + column_name, data, encoder_type + ) assert encoder and activate_fn @@ -242,7 +245,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLo loader = NDArrayLoader.get_auto_save(raw_data) for ndarray in tqdm.tqdm( - p(processes), desc="Transforming data", total=len(processes), delay=3 + p(processes), desc="Transforming data", total=len(processes), delay=3 ): loader.store(ndarray.astype(float)) return loader @@ -288,10 +291,10 @@ def inverse_transform(self, data, sigmas=None): column_names = [] for column_transform_info in tqdm.tqdm( - self._column_transform_info_list, desc="Inverse transforming", delay=3 + self._column_transform_info_list, desc="Inverse transforming", delay=3 ): dim = column_transform_info.output_dimensions - column_data = data[:, st: st + dim] + column_data = data[:, st : st + dim] if column_transform_info.column_type == "continuous": recovered_column_data = self._inverse_transform_continuous( column_transform_info, column_data, sigmas, st diff --git a/sdgx/models/components/optimize/sdv_ctgan/types.py b/sdgx/models/components/optimize/sdv_ctgan/types.py index 2de43368..945c1b28 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/types.py +++ b/sdgx/models/components/optimize/sdv_ctgan/types.py @@ -41,12 +41,12 @@ def __init__(self, dim: int, activation_fn: ActivationFuncType | str): class ColumnTransformInfo: def __init__( - self, - column_name: str, - column_type: ColumnTransformType | str, - transform: TransformerEncoderInstanceType, - output_info: List[SpanInfo], - output_dimensions: int, + self, + column_name: str, + column_type: ColumnTransformType | str, + transform: TransformerEncoderInstanceType, + output_info: List[SpanInfo], + output_dimensions: int, ): self.column_name: str = column_name self.column_type: ColumnTransformType = ColumnTransformType(column_type) diff --git a/sdgx/models/components/utils.py b/sdgx/models/components/utils.py index f22b059b..c386c1f5 100644 --- a/sdgx/models/components/utils.py +++ b/sdgx/models/components/utils.py @@ -1,6 +1,6 @@ from enum import Enum, EnumMeta from functools import cached_property -from typing import List, Iterable +from typing import Iterable, List import numpy as np @@ -22,6 +22,7 @@ def __contains__(cls, item): class StrValuedBaseEnum(Enum, metaclass=StrValuedEnumMeta): def __hash__(self): return hash(self.value) + @property def value(self): return str(super().value) diff --git a/tests/data_models/test_enums.py b/tests/data_models/test_enums.py index 4398197f..9fb17931 100644 --- a/tests/data_models/test_enums.py +++ b/tests/data_models/test_enums.py @@ -1,9 +1,13 @@ import pytest from sdgx.models.components.utils import StrValuedBaseEnum as SE + + class A(SE): a = "1" b = "2" + + def test_se(): a = A("1") b = A("2") @@ -11,9 +15,9 @@ def test_se(): assert b.value == "2" assert A.values == {"1", "2"} assert a != b - assert ['1', '2', '3'] not in A - assert '1' in A - assert ['2'] in A - assert ['1', '2'] in A - assert '1' == a - assert 1 != a \ No newline at end of file + assert ["1", "2", "3"] not in A + assert "1" in A + assert ["2"] in A + assert ["1", "2"] in A + assert "1" == a + assert 1 != a diff --git a/tests/data_models/test_metadata.py b/tests/data_models/test_metadata.py index 8e6662d5..85916d19 100644 --- a/tests/data_models/test_metadata.py +++ b/tests/data_models/test_metadata.py @@ -6,7 +6,7 @@ from sdgx.data_connectors.csv_connector import CsvConnector from sdgx.data_loader import DataLoader -from sdgx.data_models.metadata import Metadata, CategoricalEncoderType +from sdgx.data_models.metadata import CategoricalEncoderType, Metadata from sdgx.exceptions import MetadataInvalidError @@ -138,11 +138,10 @@ def test_demo_multi_table_data_metadata_child(demo_multi_data_child_matadata): # check dump assert "column_data_type" in demo_multi_data_child_matadata.dump().keys() + def test_meta_encoder(metadata: Metadata): metadata = metadata.model_copy() - metadata.categorical_threshold = { - 1: "aaa" - } + metadata.categorical_threshold = {1: "aaa"} with pytest.raises(MetadataInvalidError): metadata.check() metadata.categorical_threshold[1] = CategoricalEncoderType.ONEHOT @@ -164,5 +163,6 @@ def test_meta_encoder(metadata: Metadata): with pytest.raises(MetadataInvalidError): metadata.check() + if __name__ == "__main__": pytest.main(["-vv", "-s", __file__])