
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 26, 2024
1 parent 6551756 commit 226ddc5
Showing 12 changed files with 116 additions and 81 deletions.
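The edits below are formatter output rather than hand-written changes. This commit does not show the repository's hook configuration, but the reformatting matches what black produces: function signatures and boolean conditions re-indented and re-wrapped, single quotes normalized to double quotes, spaces added around complex slice bounds, and spaces removed around ** with simple operands. As a rough sketch only (it assumes black is installed and is one of the configured hooks), the slice fix in dataframe_connector.py can be reproduced with black's Python API:

import black

# A changed line from sdgx/data_connectors/dataframe_connector.py, reduced to a
# self-contained statement for illustration (df, offset, limit, length are stand-ins).
src = "chunk = df.iloc[offset:min(offset + limit, length)]\n"

# black inserts spaces around the slice colon because both bounds are
# complex expressions (the PEP 8 slice rule).
print(black.format_str(src, mode=black.Mode()), end="")
# chunk = df.iloc[offset : min(offset + limit, length)]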
12 changes: 6 additions & 6 deletions sdgx/data_connectors/dataframe_connector.py
@@ -11,10 +11,10 @@

class DataFrameConnector(DataConnector):
def __init__(
self,
df: pd.DataFrame,
*args,
**kwargs,
self,
df: pd.DataFrame,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.df: pd.DataFrame = df
@@ -24,7 +24,7 @@ def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | Non
if offset >= length:
return None
limit = limit or length
return self.df.iloc[offset:min(offset + limit, length)]
return self.df.iloc[offset : min(offset + limit, length)]

def _columns(self) -> list[str]:
return list(self.df.columns)
@@ -35,7 +35,7 @@ def generator() -> Generator[pd.DataFrame, None, None]:
if offset < length:
current = offset
while current < length:
yield self.df.iloc[current:min(current + chunksize, length - 1)]
yield self.df.iloc[current : min(current + chunksize, length - 1)]
current += chunksize

return generator()
73 changes: 40 additions & 33 deletions sdgx/data_models/metadata.py
@@ -5,17 +5,20 @@
from collections.abc import Iterable
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Set, Literal
from typing import Any, Dict, List, Literal, Set

import pandas as pd
from pydantic import BaseModel, Field, field_validator

from sdgx.data_loader import DataLoader
from sdgx.data_models.inspectors.base import RelationshipInspector
from sdgx.data_models.inspectors.manager import InspectorManager
from sdgx.exceptions import MetadataInitError, MetadataInvalidError
from sdgx.utils import logger

CategoricalEncoderType = Literal["onehot", "label"]


class Metadata(BaseModel):
"""
Metadata is mainly used to describe the data types of all columns in a single data table.
@@ -87,7 +90,11 @@ def check_column_list(cls, value) -> Any:
"""

def check_categorical_threshold(self, num_categories):
return num_categories > self.categorical_threshold if self.categorical_threshold is not None else False
return (
num_categories > self.categorical_threshold
if self.categorical_threshold is not None
else False
)

@property
def tag_fields(self) -> Iterable[str]:
@@ -115,16 +122,16 @@ def __eq__(self, other):
if not isinstance(other, Metadata):
return super().__eq__(other)
return (
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
)

def query(self, field: str) -> Iterable[str]:
@@ -184,9 +191,9 @@ def set(self, key: str, value: Any):

old_value = self.get(key)
if (
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
):
raise MetadataInitError(
f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
@@ -276,14 +283,14 @@ def update(self, attributes: dict[str, Any]):

@classmethod
def from_dataloader(
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
@@ -343,12 +350,12 @@ def from_dataloader(

@classmethod
def from_dataframe(
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
@@ -532,10 +539,10 @@ def get_column_data_type(self, column_name: str):
# find the dtype who has most high inspector level
for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
if (
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
):
current_level = self.column_inspect_level[each_key]
current_type = each_key
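The import reordering at the top of metadata.py (Literal and Set swapped into alphabetical order) is consistent with isort's default behavior. A minimal sketch under the same assumption, i.e. that isort runs as one of the configured pre-commit hooks:

import isort

# The pre-fix import line from sdgx/data_models/metadata.py.
src = "from typing import Any, Dict, List, Set, Literal\n"

# isort.code() returns the source with imported names sorted alphabetically.
print(isort.code(src), end="")
# from typing import Any, Dict, List, Literal, Set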
4 changes: 3 additions & 1 deletion sdgx/data_processors/formatters/datetime.py
@@ -134,7 +134,9 @@ def datetime_formatter(each_value, datetime_format):
datetime_obj = datetime.strptime(str(each_value), datetime_format)
each_stamp = datetime.timestamp(datetime_obj)
except Exception as e:
logger.warning(f"An error occured when convert str to timestamp {e}, we set as mean.")
logger.warning(
f"An error occured when convert str to timestamp {e}, we set as mean."
)
logger.warning(f"Input parameters: ({str(each_value)}, {datetime_format})")
logger.warning(f"Input type: ({type(each_value)}, {type(datetime_format)})")
each_stamp = np.nan
2 changes: 1 addition & 1 deletion sdgx/models/components/optimize/ndarray_loader.py
@@ -55,7 +55,7 @@ def store(self, ndarray: ndarray):
np.save(self._get_cache_filename(self.store_index), ndarray)
self.store_index += 1
else:
for ndarray in np.split(ndarray, indices_or_sections=ndarray.shape[1], axis=1):
for ndarray in np.split(ndarray, indices_or_sections=ndarray.shape[1], axis=1):
self.ndarray_list.append(ndarray)
self.store_index += 1

12 changes: 10 additions & 2 deletions sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
@@ -47,7 +47,11 @@ def is_onehot_encoding_column(column_info: SpanInfo):

# Prepare an interval matrix for efficiently sample conditional vector
max_category = max(
[column_info[0].dim for column_info in output_info if is_onehot_encoding_column(column_info)],
[
column_info[0].dim
for column_info in output_info
if is_onehot_encoding_column(column_info)
],
default=0,
)

@@ -56,7 +60,11 @@ def is_onehot_encoding_column(column_info: SpanInfo):
self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
self._n_discrete_columns = n_discrete_columns
self._n_categories = sum(
[column_info[0].dim for column_info in output_info if is_onehot_encoding_column(column_info)]
[
column_info[0].dim
for column_info in output_info
if is_onehot_encoding_column(column_info)
]
)

st = 0
30 changes: 19 additions & 11 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
@@ -10,13 +10,16 @@
from tqdm import autonotebook as tqdm

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata, CategoricalEncoderType
from sdgx.data_models.metadata import CategoricalEncoderType, Metadata
from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
from sdgx.models.components.sdv_rdt.transformers import (
ClusterBasedNormalizer,
OneHotEncoder,
)
from sdgx.models.components.sdv_rdt.transformers.categorical import LabelEncoder, NormalizedLabelEncoder
from sdgx.models.components.sdv_rdt.transformers.categorical import (
LabelEncoder,
NormalizedLabelEncoder,
)
from sdgx.utils import logger

SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])
@@ -66,7 +69,7 @@ def _fit_continuous(self, data):
column_name=column_name,
column_type="continuous",
transform=gm,
output_info=[SpanInfo(1, "tanh"), SpanInfo(num_components, "softmax")],
output_info=[SpanInfo(1, "tanh"), SpanInfo(num_components, "softmax")],
output_dimensions=1 + num_components,
)

@@ -89,9 +92,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
activate_fn = "softmax"

checked = self.metadata.check_categorical_threshold(num_categories)
if encoder_type == 'onehot' or not checked:
if encoder_type == "onehot" or not checked:
pass
elif encoder_type == 'label':
elif encoder_type == "label":
encoder = NormalizedLabelEncoder(order_by="alphabetical")
encoder.fit(data, column_name)
num_categories = 1
@@ -126,9 +129,14 @@ def fit(self, data_loader: DataLoader, discrete_columns=()):
# or column_name in self.metadata.label_columns
logger.debug(f"Fitting discrete column {column_name}...")

column_transform_info = self._fit_discrete(data_loader[[column_name]],
self.metadata.categorical_encoder[
column_name] if column_name in self.metadata.categorical_encoder else 'onehot')
column_transform_info = self._fit_discrete(
data_loader[[column_name]],
(
self.metadata.categorical_encoder[column_name]
if column_name in self.metadata.categorical_encoder
else "onehot"
),
)
else:
logger.debug(f"Fitting continuous column {column_name}...")
column_transform_info = self._fit_continuous(data_loader[[column_name]])
@@ -198,7 +206,7 @@ def _parallel_transform(self, raw_data, column_transform_info_list) -> NDArrayLo

loader = NDArrayLoader(save_to_file=False)
for ndarray in tqdm.tqdm(
p(processes), desc="Transforming data", total=len(processes), delay=3
p(processes), desc="Transforming data", total=len(processes), delay=3
):
loader.store(ndarray.astype(float))
return loader
@@ -244,10 +252,10 @@ def inverse_transform(self, data, sigmas=None):
column_names = []

for column_transform_info in tqdm.tqdm(
self._column_transform_info_list, desc="Inverse transforming", delay=3
self._column_transform_info_list, desc="Inverse transforming", delay=3
):
dim = column_transform_info.output_dimensions
column_data = data[:, st: st + dim]
column_data = data[:, st : st + dim]
if column_transform_info.column_type == "continuous":
recovered_column_data = self._inverse_transform_continuous(
column_transform_info, column_data, sigmas, st
16 changes: 9 additions & 7 deletions sdgx/models/components/sdv_ctgan/synthesizers/base.py
@@ -76,9 +76,9 @@ def __getstate__(self):

random_states = self.random_states
if (
isinstance(random_states, tuple)
and isinstance(random_states[0], np.random.RandomState)
and isinstance(random_states[1], torch.Generator)
isinstance(random_states, tuple)
and isinstance(random_states[0], np.random.RandomState)
and isinstance(random_states[1], torch.Generator)
):
state["_numpy_random_state"] = random_states[0].get_state()
state["_torch_random_state"] = random_states[1].get_state()
@@ -112,7 +112,9 @@ def save(self, path):
self.set_device(device_backup)

@classmethod
def load(cls, path: Union[str, Path], device: str = "cuda" if torch.cuda.is_available() else "cpu"):
def load(
cls, path: Union[str, Path], device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
"""Load the model stored in the passed arg `path`."""
with open(path, "rb") as f:
model = cloudpickle.load(f)
@@ -135,9 +137,9 @@ def set_random_state(self, random_state):
torch.Generator().manual_seed(random_state),
)
elif (
isinstance(random_state, tuple)
and isinstance(random_state[0], np.random.RandomState)
and isinstance(random_state[1], torch.Generator)
isinstance(random_state, tuple)
and isinstance(random_state[0], np.random.RandomState)
and isinstance(random_state[1], torch.Generator)
):
self.random_states = random_state
else:
5 changes: 3 additions & 2 deletions sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py
@@ -22,7 +22,8 @@
from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
BaseSynthesizer,
random_state, BatchedSynthesizer,
BatchedSynthesizer,
random_state,
)


@@ -262,7 +263,7 @@ def _cond_loss(self, data, c, m):
st_c = 0
for column_info in self._transformer.output_info_list:
for span_info in column_info:
if len(column_info) != 1 or span_info.activation_fn != "softmax": # todo: to be revised
if len(column_info) != 1 or span_info.activation_fn != "softmax":  # todo: to be revised
# not discrete column
st += span_info.dim
else:
24 changes: 12 additions & 12 deletions sdgx/models/components/sdv_ctgan/synthesizers/tvae.py
@@ -9,8 +9,8 @@

from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
BatchedSynthesizer,
random_state,
BatchedSynthesizer
)


@@ -85,7 +85,7 @@ def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
ed = st + span_info.dim
std = sigmas[st]
eq = x[:, st] - torch.tanh(recon_x[:, st])
loss.append((eq ** 2 / 2 / (std ** 2)).sum())
loss.append((eq**2 / 2 / (std**2)).sum())
loss.append(torch.log(std) * x.size()[0])
st = ed

@@ -99,23 +99,23 @@ def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
st = ed

assert st == recon_x.size()[1]
KLD = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
return sum(loss) * factor / x.size()[0], KLD / x.size()[0]


class TVAE(BatchedSynthesizer):
"""TVAE."""

def __init__(
self,
embedding_dim=128,
compress_dims=(128, 128),
decompress_dims=(128, 128),
l2scale=1e-5,
batch_size=500,
epochs=300,
loss_factor=2,
cuda=True,
self,
embedding_dim=128,
compress_dims=(128, 128),
decompress_dims=(128, 128),
l2scale=1e-5,
batch_size=500,
epochs=300,
loss_factor=2,
cuda=True,
):
super().__init__(batch_size)
self.embedding_dim = embedding_dim