
Commit

fix
cyantangerine committed Dec 1, 2024
1 parent 1813c2e commit 0822085
Showing 6 changed files with 233 additions and 98 deletions.
90 changes: 56 additions & 34 deletions sdgx/data_models/metadata.py
@@ -14,9 +14,14 @@
from sdgx.data_models.inspectors.base import RelationshipInspector
from sdgx.data_models.inspectors.manager import InspectorManager
from sdgx.exceptions import MetadataInitError, MetadataInvalidError
from sdgx.models.components.utils import StrValuedBaseEnum
from sdgx.utils import logger

CategoricalEncoderType = Literal["onehot", "label", "frequency"]

class CategoricalEncoderType(StrValuedBaseEnum):
ONEHOT = "onehot"
LABEL = "label"
FREQUENCY = "frequency"


class Metadata(BaseModel):
@@ -90,7 +95,7 @@ def check_column_list(cls, value) -> Any:
"""

def get_column_encoder_by_categorical_threshold(
self, num_categories: int
self, num_categories: int
) -> Union[CategoricalEncoderType, None]:
encoder_type = None
if self.categorical_threshold is None:
@@ -135,16 +140,16 @@ def __eq__(self, other):
if not isinstance(other, Metadata):
return super().__eq__(other)
return (
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
)

def query(self, field: str) -> Iterable[str]:
@@ -204,9 +209,9 @@ def set(self, key: str, value: Any):

old_value = self.get(key)
if (
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
):
raise MetadataInitError(
f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
@@ -302,14 +307,14 @@ def update(self, attributes: dict[str, Any]):

@classmethod
def from_dataloader(
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
@@ -369,12 +374,12 @@ def from_dataloader(

@classmethod
def from_dataframe(
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
@@ -510,6 +515,23 @@ def check(self):
f"Found undefined column: {set(all_dtype_columns) - set(self.column_list)}."
)

if self.categorical_encoder is not None:
    for column in self.categorical_encoder.keys():
        if not isinstance(column, str) or column not in self.discrete_columns:
            raise MetadataInvalidError(
                f"categorical_encoder key {column} is invalid; it should be a str naming a discrete column."
            )
    for encoder_type in self.categorical_encoder.values():
        if encoder_type not in CategoricalEncoderType:
            raise MetadataInvalidError(
                f"categorical_encoder value {encoder_type} is invalid; supported encoder types are {list(CategoricalEncoderType)}."
            )

if self.categorical_threshold is not None:
    for threshold in self.categorical_threshold.keys():
        if not isinstance(threshold, int) or threshold < 0:
            raise MetadataInvalidError(
                f"categorical_threshold key {threshold} is invalid; it should be a non-negative int."
            )
    for encoder_type in self.categorical_threshold.values():
        if encoder_type not in CategoricalEncoderType:
            raise MetadataInvalidError(
                f"categorical_threshold value {encoder_type} is invalid; supported encoder types are {list(CategoricalEncoderType)}."
            )

logger.debug("Metadata check succeed.")

def update_primary_key(self, primary_keys: Iterable[str] | str):
@@ -558,10 +580,10 @@ def get_column_data_type(self, column_name: str):
# find the dtype who has most high inspector level
for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
if (
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
):
current_level = self.column_inspect_level[each_key]
current_type = each_key
@@ -583,7 +605,7 @@ def get_column_pii(self, column_name: str):
return False

def change_column_type(
self, column_names: str | List[str], column_original_type: str, column_new_type: str
self, column_names: str | List[str], column_original_type: str, column_new_type: str
):
"""Change the type of column."""
if not column_names:
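The new CategoricalEncoderType enum and the added check() validation both rely on StrValuedBaseEnum accepting plain strings in membership tests (e.g. "onehot" in CategoricalEncoderType). Below is a minimal, self-contained sketch of that behaviour and of the per-entry validation; the _ValueContainsMeta metaclass and check_categorical_encoder helper are hypothetical stand-ins for illustration, not the sdgx implementation.

# Hypothetical stand-ins; the real StrValuedBaseEnum lives in sdgx.models.components.utils.
from enum import Enum, EnumMeta


class _ValueContainsMeta(EnumMeta):
    def __contains__(cls, item):
        # Let raw strings such as "onehot" pass membership tests against the enum.
        return item in {member.value for member in cls}


class CategoricalEncoderType(str, Enum, metaclass=_ValueContainsMeta):
    ONEHOT = "onehot"
    LABEL = "label"
    FREQUENCY = "frequency"


def check_categorical_encoder(categorical_encoder: dict, discrete_columns: set) -> None:
    """Validate a column -> encoder-type mapping the way the updated check() does."""
    for column, encoder_type in categorical_encoder.items():
        if not isinstance(column, str) or column not in discrete_columns:
            raise ValueError(f"categorical_encoder key {column!r} must name a discrete column.")
        if encoder_type not in CategoricalEncoderType:
            raise ValueError(
                f"encoder type {encoder_type!r} not supported; expected one of "
                f"{[m.value for m in CategoricalEncoderType]}."
            )


check_categorical_encoder({"gender": "onehot"}, discrete_columns={"gender"})  # passes silently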
113 changes: 62 additions & 51 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import List, Tuple
from typing import List, Tuple, NamedTuple, Dict, Any, Callable, Type

import numpy as np
import pandas as pd
@@ -26,6 +26,29 @@
)
from sdgx.utils import logger

CategoricalEncoderParams = NamedTuple("CategoricalEncoderParams", (
("encoder", Callable[[], CategoricalEncoderInstanceType]),
("categories_caculator", Callable[[CategoricalEncoderInstanceType], int]),
("activate_fn", ActivationFuncType)
))
CategoricalEncoderMapper: Dict[CategoricalEncoderType, CategoricalEncoderParams] = {
CategoricalEncoderType.ONEHOT: CategoricalEncoderParams(
lambda: OneHotEncoder(),
lambda encoder: len(encoder.dummies),
ActivationFuncType.SOFTMAX
),
CategoricalEncoderType.LABEL: CategoricalEncoderParams(
lambda: NormalizedLabelEncoder(order_by="alphabetical"),
lambda encoder: 1,
ActivationFuncType.LINEAR
),
CategoricalEncoderType.FREQUENCY: CategoricalEncoderParams(
lambda: NormalizedFrequencyEncoder(),
lambda encoder: 1,
ActivationFuncType.LINEAR
)
}


class DataTransformer(object):
"""Data Transformer.
@@ -48,51 +71,16 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, metadata=None):
self._weight_threshold = weight_threshold

def _fit_categorical_encoder(
self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType | None
self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType
) -> Tuple[CategoricalEncoderInstanceType, int, ActivationFuncType]:
selected_encoder_type = None
# Load encoder from metadata
if encoder_type is None and self.metadata:
selected_encoder_type = encoder_type = self.metadata.get_column_encoder_by_name(
column_name
)
# if the encoder is not be specified, using onehot.
if encoder_type is None:
encoder_type = "onehot"
# if the encoder is onehot, or not be specified.
num_categories = -1 # if zero may cause crash to onehot.
if encoder_type == "onehot":
encoder = OneHotEncoder()
encoder.fit(data, column_name)
num_categories = len(encoder.dummies)
# Notice: if `activate_fn` is modified, the function `is_onehot_encoding_column` in `DataSampler` should also be modified.
activate_fn = "softmax"

# if selected_encoder_type is not specified and using onehot num_categories > threshold, change the encoder.
if not selected_encoder_type and self.metadata and num_categories != -1:
encoder_type = (
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
)

if encoder_type == "onehot":
pass
elif encoder_type == "label":
encoder = NormalizedLabelEncoder(order_by="alphabetical")
encoder.fit(data, column_name)
num_categories = 1
activate_fn = "linear"
elif encoder_type == "frequency":
encoder = NormalizedFrequencyEncoder()
encoder.fit(data, column_name)
num_categories = 1
activate_fn = "linear"
else:
raise ValueError(
"column encoder must be either 'onehot'(default), 'label' or 'frequency', not {0}".format(
encoder_type
)
)
if encoder_type not in CategoricalEncoderMapper.keys():
raise ValueError("Unsupported encoder type {0}.".format(encoder_type))
p: CategoricalEncoderParams = CategoricalEncoderMapper[encoder_type]
encoder = p.encoder()
encoder.fit(data, column_name)
num_categories = p.categories_caculator(encoder)
# Notice: if `activate_fn` is modified, the function `is_onehot_encoding_column` in `DataSampler` should also be modified.
activate_fn = p.activate_fn
return encoder, num_categories, activate_fn

def _fit_continuous(self, data):
@@ -130,11 +118,34 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
namedtuple:
A ``ColumnTransformInfo`` object.
"""
encoder, activate_fn, selected_encoder_type = None, None, None
column_name = data.columns[0]

encoder, num_categories, activate_fn = self._fit_categorical_encoder(
column_name, data, encoder_type
)
# Load encoder from metadata
if encoder_type is None and self.metadata:
selected_encoder_type = encoder_type = self.metadata.get_column_encoder_by_name(
column_name
)
# If the encoder is not specified, fall back to onehot.
if encoder_type is None:
encoder_type = "onehot"
# If the encoder is onehot, or was not specified.
num_categories = -1  # zero may crash onehot, so -1 is used as the unset sentinel.
if encoder_type == "onehot":
encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type)

# If no encoder was explicitly selected and onehot's num_categories exceeds the configured threshold, switch encoders.
if not selected_encoder_type and self.metadata and num_categories != -1:
encoder_type = (
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
)

if encoder_type == "onehot":
pass
else:
encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type)

assert encoder and activate_fn

return ColumnTransformInfo(
@@ -231,7 +242,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLoader:

loader = NDArrayLoader.get_auto_save(raw_data)
for ndarray in tqdm.tqdm(
p(processes), desc="Transforming data", total=len(processes), delay=3
p(processes), desc="Transforming data", total=len(processes), delay=3
):
loader.store(ndarray.astype(float))
return loader
@@ -277,10 +288,10 @@ def inverse_transform(self, data, sigmas=None):
column_names = []

for column_transform_info in tqdm.tqdm(
self._column_transform_info_list, desc="Inverse transforming", delay=3
self._column_transform_info_list, desc="Inverse transforming", delay=3
):
dim = column_transform_info.output_dimensions
column_data = data[:, st : st + dim]
column_data = data[:, st: st + dim]
if column_transform_info.column_type == "continuous":
recovered_column_data = self._inverse_transform_continuous(
column_transform_info, column_data, sigmas, st
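The refactored _fit_categorical_encoder replaces the previous if/elif chain with a lookup table mapping each encoder type to a factory, a category-count function, and an activation. Below is a compressed, runnable sketch of that dispatch pattern; the _ToyOneHot class, fit_categorical function, and the single-entry table are hypothetical simplifications, not the sdgx encoders.

# Hypothetical simplification of the CategoricalEncoderMapper dispatch; the real
# encoders come from sdgx.models.components.sdv_rdt.transformers.
from typing import Callable, Dict, NamedTuple

import pandas as pd


class _ToyOneHot:
    """Tiny stand-in that records the distinct values it was fitted on."""

    def fit(self, data: pd.DataFrame, column: str) -> None:
        self.dummies = sorted(data[column].unique())


class EncoderParams(NamedTuple):
    factory: Callable[[], object]            # builds a fresh encoder instance
    num_categories: Callable[[object], int]  # output width this encoder contributes
    activation: str                          # activation function for this block


MAPPER: Dict[str, EncoderParams] = {
    "onehot": EncoderParams(_ToyOneHot, lambda enc: len(enc.dummies), "softmax"),
}


def fit_categorical(column: str, data: pd.DataFrame, encoder_type: str):
    if encoder_type not in MAPPER:
        raise ValueError(f"Unsupported encoder type {encoder_type}.")
    params = MAPPER[encoder_type]
    encoder = params.factory()
    encoder.fit(data, column)
    return encoder, params.num_categories(encoder), params.activation


df = pd.DataFrame({"color": ["red", "green", "red"]})
_, width, activation = fit_categorical("color", df, "onehot")
print(width, activation)  # 2 softmax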
37 changes: 25 additions & 12 deletions sdgx/models/components/optimize/sdv_ctgan/types.py
@@ -1,13 +1,15 @@
from __future__ import annotations

from typing import List, Literal, Union
from enum import unique
from typing import List, Union

from sdgx.models.components.sdv_rdt.transformers import (
ClusterBasedNormalizer,
NormalizedFrequencyEncoder,
NormalizedLabelEncoder,
OneHotEncoder,
)
from sdgx.models.components.utils import StrValuedBaseEnum

CategoricalEncoderInstanceType = Union[
OneHotEncoder, NormalizedFrequencyEncoder, NormalizedLabelEncoder
@@ -16,27 +18,38 @@
TransformerEncoderInstanceType = Union[
CategoricalEncoderInstanceType, ContinuousEncoderInstanceType
]
ActivationFuncType = Literal["softmax", "tanh", "linear"]
ColumnTransformType = Literal["discrete", "continuous"]


@unique
class ActivationFuncType(StrValuedBaseEnum):
SOFTMAX = "softmax"
TANH = "tanh"
LINEAR = "linear"


@unique
class ColumnTransformType(StrValuedBaseEnum):
DISCRETE = "discrete"
CONTINUOUS = "continuous"


class SpanInfo:
def __init__(self, dim: int, activation_fn: ActivationFuncType):
def __init__(self, dim: int, activation_fn: ActivationFuncType | str):
self.dim: int = dim
self.activation_fn: ActivationFuncType = activation_fn
self.activation_fn: ActivationFuncType = ActivationFuncType(activation_fn)


class ColumnTransformInfo:
def __init__(
self,
column_name: str,
column_type: ColumnTransformType,
transform: TransformerEncoderInstanceType,
output_info: List[SpanInfo],
output_dimensions: int,
self,
column_name: str,
column_type: ColumnTransformType | str,
transform: TransformerEncoderInstanceType,
output_info: List[SpanInfo],
output_dimensions: int,
):
self.column_name: str = column_name
self.column_type: str = column_type
self.column_type: ColumnTransformType = ColumnTransformType(column_type)
self.transform: TransformerEncoderInstanceType = transform
self.output_info: List[SpanInfo] = output_info
self.output_dimensions: int = output_dimensions
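
ActivationFuncType and ColumnTransformType now subclass StrValuedBaseEnum, and SpanInfo / ColumnTransformInfo coerce plain strings through the enum constructor, so existing call sites that pass "softmax" or "discrete" keep working. A brief sketch of that coercion, using a plain str-backed Enum as a hypothetical stand-in for StrValuedBaseEnum:

# Hypothetical stand-in for StrValuedBaseEnum; by-value lookup is standard Enum behaviour.
from enum import Enum, unique


@unique
class ActivationFuncType(str, Enum):
    SOFTMAX = "softmax"
    TANH = "tanh"
    LINEAR = "linear"


class SpanInfo:
    def __init__(self, dim: int, activation_fn: "ActivationFuncType | str"):
        self.dim = dim
        # Accept either the enum member or its raw string value.
        self.activation_fn = ActivationFuncType(activation_fn)


span = SpanInfo(3, "softmax")
assert span.activation_fn is ActivationFuncType.SOFTMAX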
