Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 26, 2024
1 parent fb008de commit 3d9af8f
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Iterable
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Set, Literal, Union
from typing import Any, Dict, List, Literal, Set, Union

import pandas as pd
from pydantic import BaseModel, Field, field_validator
Expand Down
3 changes: 2 additions & 1 deletion sdgx/models/components/optimize/ndarray_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __init__(self, cache_root: str | Path = DEFAULT_CACHE_ROOT, save_to_file=Tru
def get_auto_save(raw_data) -> NDArrayLoader:
save_to_file = True
if isinstance(raw_data, pd.DataFrame) or (
isinstance(raw_data, DataLoader) and isinstance(raw_data.data_connector, DataFrameConnector)
isinstance(raw_data, DataLoader)
and isinstance(raw_data.data_connector, DataFrameConnector)
):
save_to_file = False
return NDArrayLoader(save_to_file=save_to_file)
Expand Down
13 changes: 8 additions & 5 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
ClusterBasedNormalizer,
OneHotEncoder,
)
from sdgx.models.components.sdv_rdt.transformers.categorical import NormalizedLabelEncoder
from sdgx.models.components.sdv_rdt.transformers.categorical import (
NormalizedLabelEncoder,
)
from sdgx.utils import logger

SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])
Expand Down Expand Up @@ -126,11 +128,12 @@ def fit(self, data_loader: DataLoader, discrete_columns=()):
# or column_name in self.metadata.label_columns
logger.debug(f"Fitting discrete column {column_name}...")
encoder_type = None
if self.metadata.categorical_encoder and column_name in self.metadata.categorical_encoder:
if (
self.metadata.categorical_encoder
and column_name in self.metadata.categorical_encoder
):
encoder_type = self.metadata.categorical_encoder[column_name]
column_transform_info = self._fit_discrete(
data_loader[[column_name]], encoder_type
)
column_transform_info = self._fit_discrete(data_loader[[column_name]], encoder_type)
else:
logger.debug(f"Fitting continuous column {column_name}...")
column_transform_info = self._fit_continuous(data_loader[[column_name]])
Expand Down
2 changes: 1 addition & 1 deletion sdgx/models/ml/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def save(self, save_dir: str | Path):
return SDVBaseSynthesizer.save(self, save_dir / self.MODEL_SAVE_NAME)

@classmethod
def load(cls, save_dir: str | Path, device: str=None) -> "CTGANSynthesizerModel":
def load(cls, save_dir: str | Path, device: str = None) -> "CTGANSynthesizerModel":
return SDVBaseSynthesizer.load(save_dir / cls.MODEL_SAVE_NAME, device)

@staticmethod
Expand Down

0 comments on commit 3d9af8f

Please sign in to comment.