
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 1, 2024
1 parent 0822085 commit 1ea834a
Showing 6 changed files with 91 additions and 78 deletions.
83 changes: 44 additions & 39 deletions sdgx/data_models/metadata.py
@@ -95,7 +95,7 @@ def check_column_list(cls, value) -> Any:
"""

def get_column_encoder_by_categorical_threshold(
self, num_categories: int
self, num_categories: int
) -> Union[CategoricalEncoderType, None]:
encoder_type = None
if self.categorical_threshold is None:
@@ -140,16 +140,16 @@ def __eq__(self, other):
if not isinstance(other, Metadata):
return super().__eq__(other)
return (
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
)

def query(self, field: str) -> Iterable[str]:
@@ -209,9 +209,9 @@ def set(self, key: str, value: Any):

old_value = self.get(key)
if (
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
):
raise MetadataInitError(
f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
@@ -307,14 +307,14 @@ def update(self, attributes: dict[str, Any]):

@classmethod
def from_dataloader(
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
@@ -374,12 +374,12 @@ def from_dataloader(

@classmethod
def from_dataframe(
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
@@ -518,19 +518,24 @@ def check(self):
if self.categorical_encoder is not None:
for i in self.categorical_encoder.keys():
if not isinstance(i, str) or i not in self.discrete_columns:
raise MetadataInvalidError(f"categorical_encoder key {i} is invalid, it should be an str and is a discrete column name.")
raise MetadataInvalidError(
f"categorical_encoder key {i} is invalid, it should be an str and is a discrete column name."
)
if self.categorical_encoder.values() not in CategoricalEncoderType:
raise MetadataInvalidError(f"In categorical_encoder values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}.")
raise MetadataInvalidError(
f"In categorical_encoder values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}."
)

if self.categorical_threshold is not None:
for i in self.categorical_threshold.keys():
if not isinstance(i, int) or i < 0:
raise MetadataInvalidError(f"categorical threshold {i} is invalid, it should be an positive int.")
raise MetadataInvalidError(
f"categorical threshold {i} is invalid, it should be an positive int."
)
if self.categorical_threshold.values() not in CategoricalEncoderType:
raise MetadataInvalidError(
f"In categorical_threshold values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}.")


f"In categorical_threshold values, categorical encoder type invalid, now supports {list(CategoricalEncoderType)}."
)

logger.debug("Metadata check succeed.")

@@ -580,10 +585,10 @@ def get_column_data_type(self, column_name: str):
# find the dtype who has most high inspector level
for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
if (
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
):
current_level = self.column_inspect_level[each_key]
current_type = each_key
@@ -605,7 +610,7 @@ def get_column_pii(self, column_name: str):
return False

def change_column_type(
self, column_names: str | List[str], column_original_type: str, column_new_type: str
self, column_names: str | List[str], column_original_type: str, column_new_type: str
):
"""Change the type of column."""
if not column_names:
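
For context on the API reformatted above, a minimal usage sketch of the Metadata pieces touched in this file. The DataFrame and column names are invented for illustration; the method and field names are taken from the diff itself, and the exact comparison semantics of the threshold lookup live in its body, which this change only reformats.

import pandas as pd

from sdgx.data_models.metadata import CategoricalEncoderType, Metadata

# Hypothetical data, purely for illustration.
df = pd.DataFrame({"gender": ["M", "F", "F"], "age": [30, 41, 25]})

# Build metadata from a DataFrame with the default inspectors.
metadata = Metadata.from_dataframe(df)

# categorical_threshold maps an integer threshold to an encoder type; the
# helper reformatted above resolves a category count to one of those types.
metadata.categorical_threshold = {50: CategoricalEncoderType.LABEL}
encoder_type = metadata.get_column_encoder_by_categorical_threshold(120)

# check() validates categorical_encoder / categorical_threshold as shown above.
metadata.check()
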
47 changes: 25 additions & 22 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import List, Tuple, NamedTuple, Dict, Any, Callable, Type
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Type

import numpy as np
import pandas as pd
@@ -26,27 +26,26 @@
)
from sdgx.utils import logger

CategoricalEncoderParams = NamedTuple("CategoricalEncoderParams", (
("encoder", Callable[[], CategoricalEncoderInstanceType]),
("categories_caculator", Callable[[CategoricalEncoderInstanceType], int]),
("activate_fn", ActivationFuncType)
))
CategoricalEncoderParams = NamedTuple(
"CategoricalEncoderParams",
(
("encoder", Callable[[], CategoricalEncoderInstanceType]),
("categories_caculator", Callable[[CategoricalEncoderInstanceType], int]),
("activate_fn", ActivationFuncType),
),
)
CategoricalEncoderMapper: Dict[CategoricalEncoderType, CategoricalEncoderParams] = {
CategoricalEncoderType.ONEHOT: CategoricalEncoderParams(
lambda: OneHotEncoder(),
lambda encoder: len(encoder.dummies),
ActivationFuncType.SOFTMAX
lambda: OneHotEncoder(), lambda encoder: len(encoder.dummies), ActivationFuncType.SOFTMAX
),
CategoricalEncoderType.LABEL: CategoricalEncoderParams(
lambda: NormalizedLabelEncoder(order_by="alphabetical"),
lambda encoder: 1,
ActivationFuncType.LINEAR
ActivationFuncType.LINEAR,
),
CategoricalEncoderType.FREQUENCY: CategoricalEncoderParams(
lambda: NormalizedFrequencyEncoder(),
lambda encoder: 1,
ActivationFuncType.LINEAR
)
lambda: NormalizedFrequencyEncoder(), lambda encoder: 1, ActivationFuncType.LINEAR
),
}


@@ -71,7 +70,7 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, metadata=None):
self._weight_threshold = weight_threshold

def _fit_categorical_encoder(
self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType
self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType
) -> Tuple[CategoricalEncoderInstanceType, int, ActivationFuncType]:
if encoder_type not in CategoricalEncoderMapper.keys():
raise ValueError("Unsupported encoder type {0}.".format(encoder_type))
@@ -132,19 +131,23 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
# if the encoder is onehot, or not be specified.
num_categories = -1 # if zero may cause crash to onehot.
if encoder_type == "onehot":
encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type)
encoder, num_categories, activate_fn = self._fit_categorical_encoder(
column_name, data, encoder_type
)

# if selected_encoder_type is not specified and using onehot num_categories > threshold, change the encoder.
if not selected_encoder_type and self.metadata and num_categories != -1:
encoder_type = (
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
)

if encoder_type == "onehot":
pass
else:
encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type)
encoder, num_categories, activate_fn = self._fit_categorical_encoder(
column_name, data, encoder_type
)

assert encoder and activate_fn

@@ -242,7 +245,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLoader:

loader = NDArrayLoader.get_auto_save(raw_data)
for ndarray in tqdm.tqdm(
p(processes), desc="Transforming data", total=len(processes), delay=3
p(processes), desc="Transforming data", total=len(processes), delay=3
):
loader.store(ndarray.astype(float))
return loader
@@ -288,10 +291,10 @@ def inverse_transform(self, data, sigmas=None):
column_names = []

for column_transform_info in tqdm.tqdm(
self._column_transform_info_list, desc="Inverse transforming", delay=3
self._column_transform_info_list, desc="Inverse transforming", delay=3
):
dim = column_transform_info.output_dimensions
column_data = data[:, st: st + dim]
column_data = data[:, st : st + dim]
if column_transform_info.column_type == "continuous":
recovered_column_data = self._inverse_transform_continuous(
column_transform_info, column_data, sigmas, st
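
CategoricalEncoderMapper, reformatted above, bundles each encoder type with an encoder factory, a category counter, and an activation function. A rough sketch of how a lookup resolves, using only names visible in this diff; the subsequent fit call is omitted because its signature is not part of this change.

from sdgx.data_models.metadata import CategoricalEncoderType
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import (
    CategoricalEncoderMapper,
)

params = CategoricalEncoderMapper[CategoricalEncoderType.LABEL]
encoder = params.encoder()  # NormalizedLabelEncoder(order_by="alphabetical")
num_categories = params.categories_caculator(encoder)  # 1 column for label/frequency encoders
activate_fn = params.activate_fn  # ActivationFuncType.LINEAR
# ("categories_caculator" is the field name as spelled in the source.)
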
12 changes: 6 additions & 6 deletions sdgx/models/components/optimize/sdv_ctgan/types.py
@@ -41,12 +41,12 @@ def __init__(self, dim: int, activation_fn: ActivationFuncType | str):

class ColumnTransformInfo:
def __init__(
self,
column_name: str,
column_type: ColumnTransformType | str,
transform: TransformerEncoderInstanceType,
output_info: List[SpanInfo],
output_dimensions: int,
self,
column_name: str,
column_type: ColumnTransformType | str,
transform: TransformerEncoderInstanceType,
output_info: List[SpanInfo],
output_dimensions: int,
):
self.column_name: str = column_name
self.column_type: ColumnTransformType = ColumnTransformType(column_type)
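
Tying the two files together, a discrete column fitted in data_transformer.py would plausibly be wrapped as below. This is a sketch only: the "discrete" literal is assumed by analogy with the "continuous" comparison in inverse_transform, it assumes SpanInfo is the class whose __init__(dim, activation_fn) opens the hunk above, and the column name and stand-in values are invented.

from sdgx.models.components.optimize.sdv_ctgan.types import (
    ActivationFuncType,
    ColumnTransformInfo,
    SpanInfo,
)

num_categories = 1   # e.g. what a label/frequency encoder reports
encoder = None       # stand-in for the fitted encoder instance

info = ColumnTransformInfo(
    column_name="payment_type",   # hypothetical column
    column_type="discrete",       # assumed enum value
    transform=encoder,
    output_info=[SpanInfo(num_categories, ActivationFuncType.LINEAR)],
    output_dimensions=num_categories,
)
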
3 changes: 2 additions & 1 deletion sdgx/models/components/utils.py
@@ -1,6 +1,6 @@
from enum import Enum, EnumMeta
from functools import cached_property
from typing import List, Iterable
from typing import Iterable, List

import numpy as np

@@ -22,6 +22,7 @@ def __contains__(cls, item):
class StrValuedBaseEnum(Enum, metaclass=StrValuedEnumMeta):
def __hash__(self):
return hash(self.value)

@property
def value(self):
return str(super().value)
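
The containment semantics that the tests below rely on (single values and iterables both work with `in`) come from StrValuedEnumMeta.__contains__, whose body is outside this diff. A self-contained sketch of that behaviour, not the verified implementation:

from collections.abc import Iterable
from enum import Enum, EnumMeta


class StrValuedEnumMeta(EnumMeta):
    def __contains__(cls, item):
        # Compare against the set of string values; an iterable is "in" the
        # enum only if every element is.
        values = {str(member.value) for member in cls}
        if isinstance(item, Iterable) and not isinstance(item, str):
            return all(str(i) in values for i in item)
        return str(item) in values


class Color(Enum, metaclass=StrValuedEnumMeta):
    RED = "1"
    GREEN = "2"


assert "1" in Color
assert ["1", "2"] in Color
assert ["1", "2", "3"] not in Color
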
16 changes: 10 additions & 6 deletions tests/data_models/test_enums.py
@@ -1,19 +1,23 @@
import pytest

from sdgx.models.components.utils import StrValuedBaseEnum as SE


class A(SE):
a = "1"
b = "2"


def test_se():
a = A("1")
b = A("2")
assert isinstance(a.value, str) and a.value == "1"
assert b.value == "2"
assert A.values == {"1", "2"}
assert a != b
assert ['1', '2', '3'] not in A
assert '1' in A
assert ['2'] in A
assert ['1', '2'] in A
assert '1' == a
assert 1 != a
assert ["1", "2", "3"] not in A
assert "1" in A
assert ["2"] in A
assert ["1", "2"] in A
assert "1" == a
assert 1 != a
8 changes: 4 additions & 4 deletions tests/data_models/test_metadata.py
@@ -6,7 +6,7 @@

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata, CategoricalEncoderType
from sdgx.data_models.metadata import CategoricalEncoderType, Metadata
from sdgx.exceptions import MetadataInvalidError


@@ -138,11 +138,10 @@ def test_demo_multi_table_data_metadata_child(demo_multi_data_child_matadata):
# check dump
assert "column_data_type" in demo_multi_data_child_matadata.dump().keys()


def test_meta_encoder(metadata: Metadata):
metadata = metadata.model_copy()
metadata.categorical_threshold = {
1: "aaa"
}
metadata.categorical_threshold = {1: "aaa"}
with pytest.raises(MetadataInvalidError):
metadata.check()
metadata.categorical_threshold[1] = CategoricalEncoderType.ONEHOT
@@ -164,5 +163,6 @@ def test_meta_encoder(metadata: Metadata):
with pytest.raises(MetadataInvalidError):
metadata.check()


if __name__ == "__main__":
pytest.main(["-vv", "-s", __file__])
