Commit fb008de

cyantangerine committed Nov 26, 2024
2 parents: 7b6f414 + 226ddc5
Showing 12 changed files with 102 additions and 79 deletions.
12 changes: 6 additions & 6 deletions sdgx/data_connectors/dataframe_connector.py
@@ -11,10 +11,10 @@

class DataFrameConnector(DataConnector):
def __init__(
-            self,
-            df: pd.DataFrame,
-            *args,
-            **kwargs,
+        self,
+        df: pd.DataFrame,
+        *args,
+        **kwargs,
):
super().__init__(*args, **kwargs)
self.df: pd.DataFrame = df
@@ -24,7 +24,7 @@ def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | None:
if offset >= length:
return None
limit = limit or length
-        return self.df.iloc[offset:min(offset + limit, length)]
+        return self.df.iloc[offset : min(offset + limit, length)]

def _columns(self) -> list[str]:
return list(self.df.columns)
@@ -35,7 +35,7 @@ def generator() -> Generator[pd.DataFrame, None, None]:
if offset < length:
current = offset
while current < length:
-                yield self.df.iloc[current:min(current + chunksize, length - 1)]
+                yield self.df.iloc[current : min(current + chunksize, length - 1)]
current += chunksize

return generator()
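
As an aside, here is a minimal pandas-only sketch of the slicing pattern this file uses (hypothetical frame and sizes, no sdgx imports). One observation: _read clamps its stop index to length, while the generator clamps to length - 1, which appears to drop the final row of the last chunk; the sketch clamps to length in both places:

import pandas as pd

df = pd.DataFrame({"a": range(10)})
length = len(df)

# _read-style slice: clamp the stop index so an overshooting request is truncated.
offset, limit = 8, 5
chunk = df.iloc[offset : min(offset + limit, length)]  # rows 8..9

# generator-style chunking across the whole frame
chunksize = 4
chunks = [df.iloc[cur : min(cur + chunksize, length)] for cur in range(0, length, chunksize)]
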
69 changes: 37 additions & 32 deletions sdgx/data_models/metadata.py
@@ -9,6 +9,7 @@

import pandas as pd
from pydantic import BaseModel, Field, field_validator
+
from sdgx.data_loader import DataLoader
from sdgx.data_models.inspectors.base import RelationshipInspector
from sdgx.data_models.inspectors.manager import InspectorManager
@@ -89,7 +90,11 @@ def check_column_list(cls, value) -> Any:
"""

def check_categorical_threshold(self, num_categories):
-        return num_categories > self.categorical_threshold if self.categorical_threshold is not None else False
+        return (
+            num_categories > self.categorical_threshold
+            if self.categorical_threshold is not None
+            else False
+        )

@property
def tag_fields(self) -> Iterable[str]:
@@ -117,16 +122,16 @@ def __eq__(self, other):
if not isinstance(other, Metadata):
return super().__eq__(other)
return (
-                set(self.tag_fields) == set(other.tag_fields)
-                and all(
-                    self.get(key) == other.get(key)
-                    for key in set(chain(self.tag_fields, other.tag_fields))
-                )
-                and all(
-                    self.get(key) == other.get(key)
-                    for key in set(chain(self.format_fields, other.format_fields))
-                )
-                and self.version == other.version
+            set(self.tag_fields) == set(other.tag_fields)
+            and all(
+                self.get(key) == other.get(key)
+                for key in set(chain(self.tag_fields, other.tag_fields))
+            )
+            and all(
+                self.get(key) == other.get(key)
+                for key in set(chain(self.format_fields, other.format_fields))
+            )
+            and self.version == other.version
)

def query(self, field: str) -> Iterable[str]:
@@ -186,9 +191,9 @@ def set(self, key: str, value: Any):

old_value = self.get(key)
if (
-                key in self.model_fields
-                and key not in self.tag_fields
-                and key not in self.format_fields
+            key in self.model_fields
+            and key not in self.tag_fields
+            and key not in self.format_fields
):
raise MetadataInitError(
f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
@@ -278,14 +283,14 @@ def update(self, attributes: dict[str, Any]):

@classmethod
def from_dataloader(
-            cls,
-            dataloader: DataLoader,
-            max_chunk: int = 10,
-            primary_keys: Set[str] = None,
-            include_inspectors: Iterable[str] | None = None,
-            exclude_inspectors: Iterable[str] | None = None,
-            inspector_init_kwargs: dict[str, Any] | None = None,
-            check: bool = False,
+        cls,
+        dataloader: DataLoader,
+        max_chunk: int = 10,
+        primary_keys: Set[str] = None,
+        include_inspectors: Iterable[str] | None = None,
+        exclude_inspectors: Iterable[str] | None = None,
+        inspector_init_kwargs: dict[str, Any] | None = None,
+        check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
@@ -345,12 +350,12 @@ def from_dataloader(

@classmethod
def from_dataframe(
-            cls,
-            df: pd.DataFrame,
-            include_inspectors: list[str] | None = None,
-            exclude_inspectors: list[str] | None = None,
-            inspector_init_kwargs: dict[str, Any] | None = None,
-            check: bool = False,
+        cls,
+        df: pd.DataFrame,
+        include_inspectors: list[str] | None = None,
+        exclude_inspectors: list[str] | None = None,
+        inspector_init_kwargs: dict[str, Any] | None = None,
+        check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
@@ -534,10 +539,10 @@ def get_column_data_type(self, column_name: str):
# find the dtype who has most high inspector level
for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
if (
-                    each_key != "pii_columns"
-                    and each_key.endswith("_columns")
-                    and column_name in self.get(each_key)
-                    and current_level < self.column_inspect_level[each_key]
+                each_key != "pii_columns"
+                and each_key.endswith("_columns")
+                and column_name in self.get(each_key)
+                and current_level < self.column_inspect_level[each_key]
):
current_level = self.column_inspect_level[each_key]
current_type = each_key
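
For readers skimming the diff, a standalone restatement of the two pieces of logic touched in this file (the inspector levels below are hypothetical; the real values live in Metadata.column_inspect_level):

# check_categorical_threshold: a category count only "exceeds" the threshold
# when a threshold is actually configured.
def check_categorical_threshold(num_categories, categorical_threshold=None):
    return num_categories > categorical_threshold if categorical_threshold is not None else False

assert check_categorical_threshold(50, 10) is True
assert check_categorical_threshold(50, None) is False

# get_column_data_type: among all *_columns tags that contain the column,
# keep the tag with the highest inspector level (pii_columns excluded).
column_inspect_level = {"int_columns": 10, "datetime_columns": 30}  # assumed values
tags_with_column = ["int_columns", "datetime_columns"]
best_tag = max(tags_with_column, key=column_inspect_level.__getitem__)  # "datetime_columns"
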
4 changes: 3 additions & 1 deletion sdgx/data_processors/formatters/datetime.py
@@ -134,7 +134,9 @@ def datetime_formatter(each_value, datetime_format):
datetime_obj = datetime.strptime(str(each_value), datetime_format)
each_stamp = datetime.timestamp(datetime_obj)
except Exception as e:
-        logger.warning(f"An error occured when convert str to timestamp {e}, we set as mean.")
+        logger.warning(
+            f"An error occured when convert str to timestamp {e}, we set as mean."
+        )
logger.warning(f"Input parameters: ({str(each_value)}, {datetime_format})")
logger.warning(f"Input type: ({type(each_value)}, {type(datetime_format)})")
each_stamp = np.nan
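
Per the log messages above, the formatter parses each string with the column's datetime format and falls back to NaN on failure, to be imputed (as the mean) later. A standalone sketch of that fallback:

from datetime import datetime

import numpy as np

def to_timestamp(value, fmt):
    # Unparseable values become NaN instead of raising, mirroring the except branch.
    try:
        return datetime.timestamp(datetime.strptime(str(value), fmt))
    except Exception:
        return np.nan

to_timestamp("2024-11-26", "%Y-%m-%d")  # -> float epoch seconds
to_timestamp("not-a-date", "%Y-%m-%d")  # -> nan
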
2 changes: 1 addition & 1 deletion sdgx/models/components/optimize/ndarray_loader.py
@@ -68,7 +68,7 @@ def store(self, ndarray: ndarray):
np.save(self._get_cache_filename(self.store_index), ndarray)
self.store_index += 1
else:
-                for ndarray in np.split(ndarray, indices_or_sections=ndarray.shape[1], axis=1):
+            for ndarray in np.split(ndarray, indices_or_sections=ndarray.shape[1], axis=1):
self.ndarray_list.append(ndarray)
self.store_index += 1
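
For reference, splitting with indices_or_sections equal to the column count yields one single-column array per source column, which is what the loop above appends:

import numpy as np

arr = np.arange(6).reshape(2, 3)
cols = np.split(arr, indices_or_sections=arr.shape[1], axis=1)
assert len(cols) == 3 and cols[0].shape == (2, 1)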

12 changes: 10 additions & 2 deletions sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
@@ -47,7 +47,11 @@ def is_onehot_encoding_column(column_info: SpanInfo):

# Prepare an interval matrix for efficiently sample conditional vector
max_category = max(
-            [column_info[0].dim for column_info in output_info if is_onehot_encoding_column(column_info)],
+            [
+                column_info[0].dim
+                for column_info in output_info
+                if is_onehot_encoding_column(column_info)
+            ],
default=0,
)

@@ -56,7 +60,11 @@ def is_onehot_encoding_column(column_info: SpanInfo):
self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
self._n_discrete_columns = n_discrete_columns
self._n_categories = sum(
-            [column_info[0].dim for column_info in output_info if is_onehot_encoding_column(column_info)]
+            [
+                column_info[0].dim
+                for column_info in output_info
+                if is_onehot_encoding_column(column_info)
+            ]
)

st = 0
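
The max(..., default=0) above guards the degenerate case where no column is one-hot encoded (for example, every discrete column label-encoded): the comprehension is then empty and max() without a default would raise, while sum() of an empty list is simply 0:

dims = []  # hypothetical: no one-hot columns in output_info
max_category = max(dims, default=0)  # 0 rather than ValueError
n_categories = sum(dims)             # 0, no guard needed
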
16 changes: 7 additions & 9 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
@@ -9,15 +9,14 @@
from joblib import Parallel, delayed
from tqdm import autonotebook as tqdm

-from sdgx.data_connectors.dataframe_connector import DataFrameConnector
from sdgx.data_loader import DataLoader
-from sdgx.data_models.metadata import Metadata, CategoricalEncoderType
+from sdgx.data_models.metadata import CategoricalEncoderType, Metadata
from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
from sdgx.models.components.sdv_rdt.transformers import (
ClusterBasedNormalizer,
OneHotEncoder,
)
-from sdgx.models.components.sdv_rdt.transformers.categorical import LabelEncoder, NormalizedLabelEncoder
+from sdgx.models.components.sdv_rdt.transformers.categorical import NormalizedLabelEncoder
from sdgx.utils import logger

SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])
@@ -90,9 +89,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
activate_fn = "softmax"

checked = self.metadata.check_categorical_threshold(num_categories)
-        if encoder_type == 'onehot' or not checked:
+        if encoder_type == "onehot" or not checked:
pass
-        elif encoder_type == 'label':
+        elif encoder_type == "label":
encoder = NormalizedLabelEncoder(order_by="alphabetical")
encoder.fit(data, column_name)
num_categories = 1
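
Restating the branch above as a standalone rule, under the assumption that anything other than an explicit "label" request stays one-hot: a discrete column switches to NormalizedLabelEncoder only when a label encoder was requested AND the category count exceeds the metadata threshold, in which case it collapses to a single normalized column:

def pick_encoder(encoder_type, num_categories, threshold):
    exceeds = threshold is not None and num_categories > threshold
    if encoder_type == "label" and exceeds:
        return "label", 1              # NormalizedLabelEncoder, one output column
    return "onehot", num_categories    # softmax over num_categories columns
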
@@ -162,7 +161,6 @@ def _transform_discrete(self, column_transform_info, data):
ohe = column_transform_info.transform
        return ohe.transform(data).to_numpy()

-
def _synchronous_transform(self, raw_data, column_transform_info_list) -> NDArrayLoader:
"""Take a Pandas DataFrame and transform columns synchronous.
@@ -199,7 +197,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLoader:

loader = NDArrayLoader.get_auto_save(raw_data)
for ndarray in tqdm.tqdm(
-                p(processes), desc="Transforming data", total=len(processes), delay=3
+            p(processes), desc="Transforming data", total=len(processes), delay=3
):
loader.store(ndarray.astype(float))
return loader
@@ -245,10 +243,10 @@ def inverse_transform(self, data, sigmas=None):
column_names = []

for column_transform_info in tqdm.tqdm(
-                self._column_transform_info_list, desc="Inverse transforming", delay=3
+            self._column_transform_info_list, desc="Inverse transforming", delay=3
):
dim = column_transform_info.output_dimensions
-            column_data = data[:, st: st + dim]
+            column_data = data[:, st : st + dim]
if column_transform_info.column_type == "continuous":
recovered_column_data = self._inverse_transform_continuous(
column_transform_info, column_data, sigmas, st
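
The inverse pass walks the transformed matrix block by block: each fitted transformer consumed output_dimensions columns, so slices are taken back out in fit order. A NumPy-only sketch of the bookkeeping (hypothetical dimensions):

import numpy as np

data = np.arange(12.0).reshape(3, 4)
output_dims = [1, 3]  # assumed output_dimensions per column transformer
st = 0
for dim in output_dims:
    column_data = data[:, st : st + dim]  # this transformer's block
    st += dim
assert st == data.shape[1]  # every output column consumed exactly once
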
16 changes: 9 additions & 7 deletions sdgx/models/components/sdv_ctgan/synthesizers/base.py
@@ -76,9 +76,9 @@ def __getstate__(self):

random_states = self.random_states
if (
-                isinstance(random_states, tuple)
-                and isinstance(random_states[0], np.random.RandomState)
-                and isinstance(random_states[1], torch.Generator)
+            isinstance(random_states, tuple)
+            and isinstance(random_states[0], np.random.RandomState)
+            and isinstance(random_states[1], torch.Generator)
):
state["_numpy_random_state"] = random_states[0].get_state()
state["_torch_random_state"] = random_states[1].get_state()
@@ -112,7 +112,9 @@ def save(self, path):
self.set_device(device_backup)

@classmethod
-    def load(cls, path: Union[str, Path], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+    def load(
+        cls, path: Union[str, Path], device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    ):
"""Load the model stored in the passed arg `path`."""
with open(path, "rb") as f:
model = cloudpickle.load(f)
@@ -135,9 +137,9 @@ def set_random_state(self, random_state):
torch.Generator().manual_seed(random_state),
)
elif (
-                isinstance(random_state, tuple)
-                and isinstance(random_state[0], np.random.RandomState)
-                and isinstance(random_state[1], torch.Generator)
+            isinstance(random_state, tuple)
+            and isinstance(random_state[0], np.random.RandomState)
+            and isinstance(random_state[1], torch.Generator)
):
self.random_states = random_state
else:
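
set_random_state thus accepts either an int seed or an already-built pair; a sketch of both accepted forms (assuming numpy and torch are importable):

import numpy as np
import torch

seed = 42
# form 1: pass the int `seed` directly; it is expanded internally into the pair below
# form 2: pass the explicit (RandomState, Generator) pair
pair = (np.random.RandomState(seed), torch.Generator().manual_seed(seed))
assert isinstance(pair[0], np.random.RandomState)
assert isinstance(pair[1], torch.Generator)
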
7 changes: 4 additions & 3 deletions sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py
@@ -21,8 +21,8 @@
from sdgx.models.components.sdv_ctgan.data_sampler import DataSampler
from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
-    BaseSynthesizer,
-    random_state, BatchedSynthesizer,
+    BatchedSynthesizer,
+    random_state,
)


@@ -262,7 +262,8 @@ def _cond_loss(self, data, c, m):
st_c = 0
for column_info in self._transformer.output_info_list:
for span_info in column_info:
-                if len(column_info) != 1 or span_info.activation_fn != "softmax":  # TODO may need to revise for label encoder?
+                if len(column_info) != 1 or span_info.activation_fn != "softmax":
+                    # TODO may need to revise for label encoder?
# not discrete column
st += span_info.dim
else:
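
The guard above identifies discrete columns in the CTGAN output layout: a continuous column emits several spans (tanh plus a softmax over mixture modes), so only a column with exactly one span whose activation is softmax counts as discrete. The TODO flags that a label-encoded column (a single non-softmax span) would fall through this guard. A sketch of the walk with hypothetical spans:

from collections import namedtuple

SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])
output_info_list = [
    [SpanInfo(1, "tanh"), SpanInfo(10, "softmax")],  # continuous column: skipped
    [SpanInfo(4, "softmax")],                        # discrete column: used in cond loss
]
st = 0
for column_info in output_info_list:
    for span_info in column_info:
        if len(column_info) != 1 or span_info.activation_fn != "softmax":
            st += span_info.dim  # not discrete, just advance the offset
        else:
            # cross-entropy between data[:, st:st+dim] and the matching
            # conditional-vector slice would be computed here
            st += span_info.dim
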
24 changes: 12 additions & 12 deletions sdgx/models/components/sdv_ctgan/synthesizers/tvae.py
@@ -9,8 +9,8 @@

from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
+    BatchedSynthesizer,
random_state,
-    BatchedSynthesizer
)


@@ -85,7 +85,7 @@ def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
ed = st + span_info.dim
std = sigmas[st]
eq = x[:, st] - torch.tanh(recon_x[:, st])
-                loss.append((eq ** 2 / 2 / (std ** 2)).sum())
+                loss.append((eq**2 / 2 / (std**2)).sum())
loss.append(torch.log(std) * x.size()[0])
st = ed

@@ -99,23 +99,23 @@ def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
st = ed

assert st == recon_x.size()[1]
-    KLD = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
+    KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
return sum(loss) * factor / x.size()[0], KLD / x.size()[0]


class TVAE(BatchedSynthesizer):
"""TVAE."""

def __init__(
-            self,
-            embedding_dim=128,
-            compress_dims=(128, 128),
-            decompress_dims=(128, 128),
-            l2scale=1e-5,
-            batch_size=500,
-            epochs=300,
-            loss_factor=2,
-            cuda=True,
+        self,
+        embedding_dim=128,
+        compress_dims=(128, 128),
+        decompress_dims=(128, 128),
+        l2scale=1e-5,
+        batch_size=500,
+        epochs=300,
+        loss_factor=2,
+        cuda=True,
):
super().__init__(batch_size)
self.embedding_dim = embedding_dim
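
For context, the two terms assembled in _loss_function are a per-span Gaussian reconstruction loss and the standard VAE KL divergence against a unit Gaussian; a minimal torch sketch with hypothetical shapes:

import torch

x = torch.randn(16, 1)        # one tanh span of the real data
recon_x = torch.randn(16, 1)  # decoder output for that span
std = torch.tensor(0.1)       # sigma learned for the span

eq = x[:, 0] - torch.tanh(recon_x[:, 0])
rec = (eq**2 / 2 / (std**2)).sum() + torch.log(std) * x.size()[0]

mu, logvar = torch.zeros(16, 8), torch.zeros(16, 8)
KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())  # exactly 0 at the prior
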
1 change: 1 addition & 0 deletions sdgx/models/components/sdv_rdt/transformers/categorical.py
@@ -1,4 +1,5 @@
"""Transformers for categorical data."""

import math
import warnings

Expand Down
7 changes: 5 additions & 2 deletions sdgx/models/ml/single_table/ctgan.py
@@ -1,5 +1,5 @@
from __future__ import annotations
-from tqdm import autonotebook as tqdm

import time
from pathlib import Path
from typing import List
@@ -19,6 +19,7 @@
Sequential,
functional,
)
+from tqdm import autonotebook as tqdm

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
@@ -221,7 +222,9 @@ def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **kwargs):
self._fit(len(self._ndarry_loader))
logger.info("CTGAN training finished.")

-    def _pre_fit(self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata: Metadata = None) -> NDArrayLoader:
+    def _pre_fit(
+        self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata: Metadata = None
+    ) -> NDArrayLoader:
if not discrete_columns:
discrete_columns = []
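
The guard that opens _pre_fit is the usual defence against a mutable default argument; a one-function sketch:

def pre_fit_sketch(discrete_columns=None):
    if not discrete_columns:
        discrete_columns = []  # fresh list per call, never shared between calls
    return discrete_columns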
