Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 1, 2024
1 parent 96b6eac commit 1813c2e
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 39 deletions.
14 changes: 8 additions & 6 deletions sdgx/data_connectors/dataframe_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from functools import cached_property
from typing import Callable, Generator

import pandas as pd

from sdgx.data_connectors.base import DataConnector
Expand All @@ -28,10 +29,10 @@ class DataFrameConnector(DataConnector):
"""

def __init__(
self,
df: pd.DataFrame,
*args,
**kwargs,
self,
df: pd.DataFrame,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.df: pd.DataFrame = df
Expand All @@ -41,7 +42,7 @@ def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | Non
if offset >= length:
return None
limit = limit or length
return self.df.iloc[offset: min(offset + limit, length)]
return self.df.iloc[offset : min(offset + limit, length)]

def _columns(self) -> list[str]:
    """Return the column labels of ``self.df`` as a plain Python list.

    The labels are listed in the same order as they appear on the
    DataFrame's ``columns`` index.
    """
    column_index = self.df.columns
    return [*column_index]
Expand All @@ -52,14 +53,15 @@ def generator() -> Generator[pd.DataFrame, None, None]:
if offset < length:
current = offset
while current < length:
yield self.df.iloc[current: min(current + chunksize, length)]
yield self.df.iloc[current : min(current + chunksize, length)]
current += chunksize

return generator()


from sdgx.data_connectors.extension import hookimpl


@hookimpl
def register(manager):
    """Register ``DataFrameConnector`` with ``manager`` under the name "DataFrameConnector".

    NOTE(review): ``hookimpl`` is imported from ``sdgx.data_connectors.extension``;
    presumably it marks this function as a plugin hook implementation — confirm
    against that module.
    """
    connector_cls = DataFrameConnector
    manager.register("DataFrameConnector", connector_cls)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from sdgx.models.components.sdv_rdt.transformers import (
ClusterBasedNormalizer,
FrequencyEncoder
FrequencyEncoder,
)

# TODO(Enhance) - Use different types of encoders for discrete columns, e.g. ordered columns, high-cardinality columns...
Expand Down
7 changes: 6 additions & 1 deletion sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
class DataSampler(object):
"""DataSampler samples the conditional vector and corresponding data for CTGAN."""

def __init__(self, data: NDArrayLoader | np.ndarray, output_info: List[List[SpanInfo]], log_frequency: bool):
def __init__(
self,
data: NDArrayLoader | np.ndarray,
output_info: List[List[SpanInfo]],
log_frequency: bool,
):
self._data: NDArrayLoader | np.ndarray = data

def is_onehot_encoding_column(column_info: List[SpanInfo]):
Expand Down
27 changes: 13 additions & 14 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
from sdgx.data_models.metadata import CategoricalEncoderType, Metadata
from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
from sdgx.models.components.optimize.sdv_ctgan.types import (
CategoricalEncoderInstanceType,
ActivationFuncType,
CategoricalEncoderInstanceType,
ColumnTransformInfo,
SpanInfo,
ColumnTransformInfo
)
from sdgx.models.components.sdv_rdt.transformers import (
ClusterBasedNormalizer,
OneHotEncoder,
NormalizedFrequencyEncoder,
NormalizedLabelEncoder
NormalizedLabelEncoder,
OneHotEncoder,
)
from sdgx.utils import logger

Expand All @@ -48,10 +48,7 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, metadata=None):
self._weight_threshold = weight_threshold

def _fit_categorical_encoder(
self,
column_name: str,
data: pd.DataFrame,
encoder_type: CategoricalEncoderType | None
self, column_name: str, data: pd.DataFrame, encoder_type: CategoricalEncoderType | None
) -> Tuple[CategoricalEncoderInstanceType, int, ActivationFuncType]:
selected_encoder_type = None
# Load encoder from metadata
Expand All @@ -74,8 +71,8 @@ def _fit_categorical_encoder(
# If no encoder type was selected explicitly and metadata is available, let the metadata's categorical threshold choose the encoder from num_categories; fall back to encoder_type when it returns nothing.
if not selected_encoder_type and self.metadata and num_categories != -1:
encoder_type = (
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
)

if encoder_type == "onehot":
Expand Down Expand Up @@ -135,7 +132,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
"""
column_name = data.columns[0]

encoder, num_categories, activate_fn = self._fit_categorical_encoder(column_name, data, encoder_type)
encoder, num_categories, activate_fn = self._fit_categorical_encoder(
column_name, data, encoder_type
)
assert encoder and activate_fn

return ColumnTransformInfo(
Expand Down Expand Up @@ -232,7 +231,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLo

loader = NDArrayLoader.get_auto_save(raw_data)
for ndarray in tqdm.tqdm(
p(processes), desc="Transforming data", total=len(processes), delay=3
p(processes), desc="Transforming data", total=len(processes), delay=3
):
loader.store(ndarray.astype(float))
return loader
Expand Down Expand Up @@ -278,10 +277,10 @@ def inverse_transform(self, data, sigmas=None):
column_names = []

for column_transform_info in tqdm.tqdm(
self._column_transform_info_list, desc="Inverse transforming", delay=3
self._column_transform_info_list, desc="Inverse transforming", delay=3
):
dim = column_transform_info.output_dimensions
column_data = data[:, st: st + dim]
column_data = data[:, st : st + dim]
if column_transform_info.column_type == "continuous":
recovered_column_data = self._inverse_transform_continuous(
column_transform_info, column_data, sigmas, st
Expand Down
28 changes: 17 additions & 11 deletions sdgx/models/components/optimize/sdv_ctgan/types.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

from typing import Union, Literal, List
from typing import List, Literal, Union

from sdgx.models.components.sdv_rdt.transformers import (
OneHotEncoder,
ClusterBasedNormalizer,
NormalizedFrequencyEncoder,
NormalizedLabelEncoder
NormalizedLabelEncoder,
OneHotEncoder,
)

CategoricalEncoderInstanceType = Union[OneHotEncoder, NormalizedFrequencyEncoder, NormalizedLabelEncoder]
CategoricalEncoderInstanceType = Union[
OneHotEncoder, NormalizedFrequencyEncoder, NormalizedLabelEncoder
]
ContinuousEncoderInstanceType = Union[ClusterBasedNormalizer]
TransformerEncoderInstanceType = Union[CategoricalEncoderInstanceType, ContinuousEncoderInstanceType]
TransformerEncoderInstanceType = Union[
CategoricalEncoderInstanceType, ContinuousEncoderInstanceType
]
ActivationFuncType = Literal["softmax", "tanh", "linear"]
ColumnTransformType = Literal["discrete", "continuous"]

Expand All @@ -23,12 +27,14 @@ def __init__(self, dim: int, activation_fn: ActivationFuncType):


class ColumnTransformInfo:
def __init__(self,
column_name: str,
column_type: ColumnTransformType,
transform: TransformerEncoderInstanceType,
output_info: List[SpanInfo],
output_dimensions: int):
def __init__(
self,
column_name: str,
column_type: ColumnTransformType,
transform: TransformerEncoderInstanceType,
output_info: List[SpanInfo],
output_dimensions: int,
):
self.column_name: str = column_name
self.column_type: str = column_type
self.transform: TransformerEncoderInstanceType = transform
Expand Down
4 changes: 2 additions & 2 deletions sdgx/models/components/sdv_rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
CustomLabelEncoder,
FrequencyEncoder,
LabelEncoder,
OneHotEncoder,
NormalizedFrequencyEncoder,
NormalizedLabelEncoder,
NormalizedFrequencyEncoder
OneHotEncoder,
)
from sdgx.models.components.sdv_rdt.transformers.datetime import (
OptimizedTimestampEncoder,
Expand Down
6 changes: 2 additions & 4 deletions tests/test_ctgan_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@

from sdgx.data_connectors.dataframe_connector import DataFrameConnector
from sdgx.data_models.metadata import Metadata
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import (
DataTransformer,
)
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.optimize.sdv_ctgan.types import SpanInfo
from sdgx.models.components.sdv_rdt.transformers import (
ClusterBasedNormalizer,
NormalizedFrequencyEncoder,
NormalizedLabelEncoder,
OneHotEncoder,
ClusterBasedNormalizer
)
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
Expand Down

0 comments on commit 1813c2e

Please sign in to comment.