Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cyantangerine committed Nov 29, 2024
1 parent 2c7a565 commit b780806
Show file tree
Hide file tree
Showing 14 changed files with 324 additions and 139 deletions.
2 changes: 1 addition & 1 deletion example/sdgx_example_metadata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@
" \"industry\": \"frequency\",\n",
" \"work_year\": \"label\"\n",
" # \"work_type\" using default encoder, we are not specified its encoder.\n",
" # \"issue_date\" and \"earlies_credit_mon\" are datetime columns. We are not specified its encoder.\n",
" # \"issue_date\" and \"earlies_credit_mon\" are datetime columns. We not need to specify its encoder when we use DatetimeFormatter in training, because it transformed as float. \n",
"}"
],
"outputs": [],
Expand Down
146 changes: 113 additions & 33 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def check_column_list(cls, value) -> Any:
"""

def get_column_encoder_by_categorical_threshold(
self, num_categories: int
self, num_categories: int
) -> Union[CategoricalEncoderType, None]:
encoder_type = None
if self.categorical_threshold is None:
Expand Down Expand Up @@ -135,16 +135,16 @@ def __eq__(self, other):
if not isinstance(other, Metadata):
return super().__eq__(other)
return (
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
)

def query(self, field: str) -> Iterable[str]:
Expand Down Expand Up @@ -204,9 +204,9 @@ def set(self, key: str, value: Any):

old_value = self.get(key)
if (
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
):
raise MetadataInitError(
f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
Expand All @@ -215,7 +215,13 @@ def set(self, key: str, value: Any):
if isinstance(old_value, Iterable) and not isinstance(old_value, str):
# list, set, tuple...
value = value if isinstance(value, Iterable) and not isinstance(value, str) else [value]
value = type(old_value)(value)
try:
value = type(old_value)(value)
except TypeError as e:
if type(old_value) == defaultdict:
value = dict(value)
else:
raise e

if key in self.model_fields:
setattr(self, key, value)
Expand Down Expand Up @@ -296,14 +302,14 @@ def update(self, attributes: dict[str, Any]):

@classmethod
def from_dataloader(
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
Expand Down Expand Up @@ -363,12 +369,12 @@ def from_dataloader(

@classmethod
def from_dataframe(
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
Expand Down Expand Up @@ -552,10 +558,10 @@ def get_column_data_type(self, column_name: str):
# find the dtype who has most high inspector level
for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
if (
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
):
current_level = self.column_inspect_level[each_key]
current_type = each_key
Expand All @@ -575,3 +581,77 @@ def get_column_pii(self, column_name: str):
if column_name in self.pii_columns:
return True
return False

def change_column_type(self, column_names: str | List[str], column_original_type: str, column_new_type: str):
"""Change the type of column."""
if not column_names:
return
if isinstance(column_names, str):
column_names = [column_names]
all_fields = list(self.tag_fields)
original_type = f"{column_original_type}_columns"
new_type = f"{column_new_type}_columns"
if original_type not in all_fields:
raise MetadataInvalidError(f"Column type {column_original_type} not exist in metadata.")
if new_type not in all_fields:
raise MetadataInvalidError(f"Column type {column_new_type} not exist in metadata.")
type_columns = self.get(original_type)
diff = set(column_names).difference(type_columns)
if diff:
raise MetadataInvalidError(f"Columns {column_names} not exist in {original_type}.")
self.add(new_type, column_names)
type_columns = type_columns.difference(column_names)
self.set(original_type, type_columns)

def remove_column(self, column_names: List[str] | str):
"""
Remove a column from all columns type.
Args:
column_names: List[str]: To removed columns name list.
"""
if not column_names:
return
if isinstance(column_names, str):
column_names = [column_names]
column_names = frozenset(column_names)
inter = column_names.intersection(self.column_list)
if not inter:
raise MetadataInvalidError(f"Columns {inter} not exist in metadata.")

def do_remove_columns(key, get=True, to_removes=column_names):
obj = self
if get:
target = obj.get(key)
else:
target = getattr(obj, key)
res = None
if isinstance(target, list):
res = [item for item in target if item not in to_removes]
elif isinstance(target, dict):
if key == "numeric_format":
obj.set(key, {
k: {
v2 for v2 in v if v2 not in to_removes
} for k, v in target.items()
})
else:
res = {
k: v for k, v in target.items()
if k not in to_removes
}
elif isinstance(target, set):
res = target.difference(to_removes)

if res is not None:
if get:
obj.set(key, res)
else:
setattr(obj, key, res)

to_remove_attribute = list(self.tag_fields)
to_remove_attribute.extend(list(self.format_fields))
for attr in to_remove_attribute:
do_remove_columns(attr)
for attr in ["column_list", "primary_keys"]:
do_remove_columns(attr, False)
self.check()
39 changes: 19 additions & 20 deletions sdgx/data_processors/formatters/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
f"Column {each_col} has no datetime_format, DatetimeFormatter will REMOVE this column!"
)

# Remove successful formatted datetime columns from metadata.discrete_columns
if not (set(datetime_columns) - set(metadata.discrete_columns)):
metadata.change_column_type(datetime_columns, "discrete", "datetime")
# Remove dead_columns from metadata
metadata.remove_column(dead_columns)

self.datetime_columns = datetime_columns
self.dead_columns = dead_columns

Expand Down Expand Up @@ -189,25 +195,17 @@ def convert_timestamp_to_datetime(timestamp_column_list, format_dict, processed_
Returns:
- result_data (pd.DataFrame): DataFrame with timestamp columns converted to datetime format.
"""

def convert_single_column_timestamp_to_str(column_data: pd.Series, datetime_format: str):
"""
convert each single column timestamp(int) to datetime string.
"""
res = []
for each_stamp in column_data:
try:
each_str = datetime.fromtimestamp(each_stamp).strftime(datetime_format)
except Exception as e:
logger.debug(f"An error occured when convert timestamp to str {e}.")
each_str = "No Datetime"

res.append(each_str)
res = pd.Series(res)
res = res.astype(str)
return res
TODO:
if the value <0, the result will be `No Datetime`, try to fix it.
"""
def column_timestamp_formatter(each_stamp: int, timestamp_format: str) -> str:
try:
each_str = datetime.fromtimestamp(each_stamp).strftime(timestamp_format)
except Exception as e:
logger.debug(f"An error occured when convert timestamp to str {e}.")
each_str = "No Datetime"
return each_str

# Copy the processed data to result_data
result_data = processed_data.copy()
Expand All @@ -217,8 +215,9 @@ def convert_single_column_timestamp_to_str(column_data: pd.Series, datetime_form
# Check if the column is in the DataFrame
if column in result_data.columns:
# Convert the timestamp to datetime format using the format provided in datetime_column_dict
result_data[column] = convert_single_column_timestamp_to_str(
result_data[column], format_dict[column]
result_data[column] = result_data[column].apply(
column_timestamp_formatter,
timestamp_format=format_dict[column]
)
else:
logger.error(f"Column {column} not in processed data's column list!")
Expand Down
6 changes: 5 additions & 1 deletion sdgx/models/components/optimize/ndarray_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ def get_all(self) -> ndarray:
return np.concatenate([array for array in self.iter()], axis=1)

@cached_property
def __shape_0(self):
return self.load(0).shape[0]

@property
def shape(self) -> tuple[int, int]:
return self.load(0).shape[0], self.store_index
return self.__shape_0, self.store_index

def __len__(self):
return self.shape[0]
Expand Down
1 change: 1 addition & 0 deletions sdgx/models/components/optimize/sdv_ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def is_onehot_encoding_column(column_info: SpanInfo):
st = ed
else:
st += sum([span_info.dim for span_info in column_info])
assert st == data.shape[1]

def _random_choice_prob_index(self, discrete_column_id):
probs = self._discrete_column_category_prob[discrete_column_id]
Expand Down
17 changes: 9 additions & 8 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections import namedtuple
from typing import List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -102,8 +103,8 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
encoder.fit(data, column_name)
num_categories = len(encoder.dummies)
activate_fn = "softmax"
# if selected_encoder_type is not specified or using onehot num_categories > threshold, change the encoder.
if not selected_encoder_type or self.metadata and num_categories != -1:
# if selected_encoder_type is not specified and using onehot num_categories > threshold, change the encoder.
if not selected_encoder_type and self.metadata and num_categories != -1:
encoder_type = (
self.metadata.get_column_encoder_by_categorical_threshold(num_categories)
or encoder_type
Expand Down Expand Up @@ -145,12 +146,12 @@ def fit(self, data_loader: DataLoader, discrete_columns=()):
This step also counts the #columns in matrix data and span information.
"""
self.output_info_list = []
self.output_dimensions = 0
self.dataframe = True
self.output_info_list: List[SpanInfo] = []
self.output_dimensions: int = 0
self.dataframe: bool = True

self._column_raw_dtypes = data_loader[: data_loader.chunksize].infer_objects().dtypes
self._column_transform_info_list = []
self._column_transform_info_list: List[ColumnTransformInfo] = []
for column_name in tqdm.tqdm(data_loader.columns(), desc="Preparing data", delay=3):
if column_name in discrete_columns:
# or column_name in self.metadata.label_columns
Expand Down Expand Up @@ -183,8 +184,8 @@ def _transform_continuous(self, column_transform_info, data):

def _transform_discrete(self, column_transform_info, data):
logger.debug(f"Transforming discrete column {column_transform_info.column_name}...")
ohe = column_transform_info.transform
return ohe.transform(data).to_numpy()
encoder = column_transform_info.transform
return encoder.transform(data).to_numpy()

def _synchronous_transform(self, raw_data, column_transform_info_list) -> NDArrayLoader:
"""Take a Pandas DataFrame and transform columns synchronous.
Expand Down
3 changes: 1 addition & 2 deletions sdgx/models/components/sdv_ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,7 @@ def _cond_loss(self, data, c, m):
for column_info in self._transformer.output_info_list:
for span_info in column_info:
if len(column_info) != 1 or span_info.activation_fn != "softmax":
# TODO may need to revise for label encoder?
# not discrete column
# not onehot column
st += span_info.dim
else:
ed = st + span_info.dim
Expand Down
7 changes: 5 additions & 2 deletions sdgx/models/components/sdv_rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _reverse_transform_by_row(self, data):
"""Reverse transform the data by iterating over each row."""
return data.apply(self._get_category_from_start).astype(self.dtype)

def _reverse_transform(self, data):
def _reverse_transform(self, data, normalize=False):
"""Convert float values back to the original categorical values.
Args:
Expand All @@ -241,7 +241,7 @@ def _reverse_transform(self, data):
Returns:
pandas.Series
"""
data = data.clip(0, 1)
data = data.clip(-1 if normalize else 0, 1)
num_rows = len(data)
num_categories = len(self.means)

Expand Down Expand Up @@ -675,3 +675,6 @@ def _fit(self, data):
"""
self.dtype = data.dtype
self.intervals, self.means, self.starts = self._get_intervals(data, normalized=True)

def _reverse_transform(self, data):
return super()._reverse_transform(data, True)
Loading

0 comments on commit b780806

Please sign in to comment.