From 2065b0bc3819796a15091c154fd5123530208a7b Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Fri, 19 Apr 2024 15:14:47 -0500 Subject: [PATCH 01/32] =?UTF-8?q?Bump=20version:=201.12.1=20=E2=86=92=201.?= =?UTF-8?q?12.2.dev0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- sdv/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3c3a75717..901c68965 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,7 +154,7 @@ namespaces = false version = {attr = 'sdv.__version__'} [tool.bumpversion] -current_version = "1.12.1" +current_version = "1.12.2.dev0" parse = '(?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))?' serialize = [ '{major}.{minor}.{patch}.{release}{candidate}', diff --git a/sdv/__init__.py b/sdv/__init__.py index e3a60df2d..ef209a09a 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = 'DataCebo, Inc.' __email__ = 'info@sdv.dev' -__version__ = '1.12.1' +__version__ = '1.12.2.dev0' import sys From 1be52f6af74670fef2fe93ef047875def574036a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 19 Apr 2024 16:12:15 -0500 Subject: [PATCH 02/32] Latest Code Analysis (#1940) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- static_code_analysis.txt | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/static_code_analysis.txt b/static_code_analysis.txt index 74bcc20da..3efa21ea2 100644 --- a/static_code_analysis.txt +++ b/static_code_analysis.txt @@ -1,15 +1,15 @@ -Run started:2024-04-16 22:26:14.110085 +Run started:2024-04-19 21:07:50.508524 Test results: >> Issue: [B110:try_except_pass] Try, Except, Pass detected. 
Severity: Low Confidence: High CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b110_try_except_pass.html - Location: ./sdv/_utils.py:320:8 -319 -320 except Exception: -321 pass -322 + Location: ./sdv/_utils.py:321:8 +320 +321 except Exception: +322 pass +323 -------------------------------------------------- >> Issue: [B105:hardcoded_password_string] Possible hardcoded password: '#' @@ -57,7 +57,7 @@ Test results: -------------------------------------------------- Code scanned: - Total lines of code: 10878 + Total lines of code: 10909 Total lines skipped (#nosec): 0 Total potential issues skipped due to specifically being disabled (e.g., #nosec BXXX): 0 From 1dd84259340a6ae7a9a7e83c85cbae21c5886b2a Mon Sep 17 00:00:00 2001 From: SDV Team <98988753+sdv-team@users.noreply.github.com> Date: Mon, 22 Apr 2024 09:06:52 -0400 Subject: [PATCH 03/32] Automated Latest Dependency Updates (#1942) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- latest_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/latest_requirements.txt b/latest_requirements.txt index 725809553..6ea9b9347 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -5,6 +5,6 @@ deepecho==0.6.0 graphviz==0.20.3 numpy==1.26.4 pandas==2.2.2 -rdt==1.11.1 +rdt==1.12.0 sdmetrics==0.14.0 tqdm==4.66.2 From 63cdaabc722c0f0903b39ade3e2359505a517aaf Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Mon, 22 Apr 2024 08:21:56 -0700 Subject: [PATCH 04/32] Fix warning (#1941) --- sdv/__init__.py | 2 +- tests/unit/test___init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdv/__init__.py b/sdv/__init__.py index ef209a09a..409cb7735 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -94,7 +94,7 @@ def _find_addons(): addon = entry_point.load() except Exception as e: # pylint: disable=broad-exception-caught msg = ( - f'Failed to load "{entry_point.name}" from "{entry_point.version}" ' + f'Failed to load "{entry_point.name}" from "{entry_point.value}" ' f'with error:\n{e}' ) warnings.warn(msg) diff --git a/tests/unit/test___init__.py b/tests/unit/test___init__.py index e94f3b214..6a4b1edfa 100644 --- a/tests/unit/test___init__.py +++ b/tests/unit/test___init__.py @@ -64,7 +64,7 @@ def entry_point_error(): bad_entry_point = Mock() bad_entry_point.name = 'bad_entry_point' - bad_entry_point.version = 'bad_module' + bad_entry_point.value = 'bad_module' bad_entry_point.load.side_effect = entry_point_error entry_points_mock.return_value = [bad_entry_point] msg = 'Failed to load "bad_entry_point" from "bad_module" with error:\nbad value' From 856baa2ccac6669535938bc5cdabf437fdb03892 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:58:45 +0100 Subject: [PATCH 05/32] Switch parameter order in drop_unknown_references (#1955) --- sdv/utils/poc.py | 6 +++--- tests/integration/utils/test_poc.py | 8 ++++---- tests/unit/utils/test_poc.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index da36c6653..deb5dfba9 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -12,15 +12,15 @@ _print_simplified_schema_summary, _simplify_data, _simplify_metadata) -def drop_unknown_references(metadata, data, drop_missing_values=True, verbose=True): +def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True): """Drop rows with 
unknown foreign keys. Args: - metadata (MultiTableMetadata): - Metadata of the datasets. data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). + metadata (MultiTableMetadata): + Metadata of the datasets. drop_missing_values (bool): Boolean describing whether or not to also drop foreign keys with missing values If True, drop rows with missing values in the foreign keys. diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index e8189b6ea..981b8798e 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -76,7 +76,7 @@ def test_drop_unknown_references(metadata, data, capsys): with pytest.raises(InvalidDataError, match=expected_message): metadata.validate_data(data) - cleaned_data = drop_unknown_references(metadata, data) + cleaned_data = drop_unknown_references(data, metadata) metadata.validate_data(cleaned_data) captured = capsys.readouterr() @@ -100,7 +100,7 @@ def test_drop_unknown_references_valid_data(metadata, data, capsys): data['child'].loc[4, 'parent_id'] = 2 # Run - result = drop_unknown_references(metadata, data) + result = drop_unknown_references(data, metadata) captured = capsys.readouterr() # Assert @@ -122,7 +122,7 @@ def test_drop_unknown_references_drop_missing_values(metadata, data, capsys): data['child'].loc[4, 'parent_id'] = np.nan # Run - cleaned_data = drop_unknown_references(metadata, data) + cleaned_data = drop_unknown_references(data, metadata) metadata.validate_data(cleaned_data) captured = capsys.readouterr() @@ -146,7 +146,7 @@ def test_drop_unknown_references_not_drop_missing_values(metadata, data): # Run cleaned_data = drop_unknown_references( - metadata, data, drop_missing_values=False, verbose=False + data, metadata, drop_missing_values=False, verbose=False ) # Assert diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index a18224460..ba93205eb 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -65,7 +65,7 @@ def test_drop_unknown_references(mock_get_rows_to_drop, mock_stdout_write): }) # Run - result = drop_unknown_references(metadata, data) + result = drop_unknown_references(data, metadata) # Assert expected_pattern = re.compile( @@ -127,7 +127,7 @@ def test_drop_unknown_references_valid_data_mock(mock_stdout_write): } # Run - result = drop_unknown_references(metadata, data) + result = drop_unknown_references(data, metadata) # Assert expected_pattern = re.compile( @@ -198,7 +198,7 @@ def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_r }) # Run - result = drop_unknown_references(metadata, data, verbose=False) + result = drop_unknown_references(data, metadata, verbose=False) # Assert metadata.validate.assert_called_once() @@ -278,7 +278,7 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop }) # Run - result = drop_unknown_references(metadata, data, drop_missing_values=False, verbose=False) + result = drop_unknown_references(data, metadata, drop_missing_values=False, verbose=False) # Assert mock_get_rows_to_drop.assert_called_once() @@ -360,7 +360,7 @@ def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): 'Try providing different data for this table.' 
) with pytest.raises(InvalidDataError, match=expected_message): - drop_unknown_references(metadata, data) + drop_unknown_references(data, metadata) @patch('sdv.utils.poc._get_total_estimated_columns') From 71ab960067ac35d217625550b1c65795e6355794 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Tue, 23 Apr 2024 12:49:49 -0400 Subject: [PATCH 06/32] Only run unit and integration tests on oldest and latest python versions for macos (#1957) --- .github/workflows/integration.yml | 7 ++++++- .github/workflows/minimum.yml | 7 ++++++- .github/workflows/unit.yml | 7 ++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 99dae2eb8..3418bf2e7 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -11,7 +11,12 @@ jobs: strategy: matrix: python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12'] - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, windows-latest] + include: + - os: macos-latest + python-version: '3.8' + - os: macos-latest + python-version: '3.12' steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 616b9a10d..f97dc4fbe 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -11,7 +11,12 @@ jobs: strategy: matrix: python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12'] - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, windows-latest] + include: + - os: macos-latest + python-version: '3.8' + - os: macos-latest + python-version: '3.12' steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index 328704e00..870a0f28d 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -11,7 +11,12 @@ jobs: strategy: matrix: python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12'] - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, windows-latest] + include: + - os: macos-latest + python-version: '3.8' + - os: macos-latest + python-version: '3.12' steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} From 17c0d09adbcc298feded1db1fee53535367b6e0d Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Wed, 24 Apr 2024 16:39:23 +0100 Subject: [PATCH 07/32] Add get_table_metadata function (#1956) --- sdv/metadata/multi_table.py | 14 ++++++++ .../integration/metadata/test_multi_table.py | 32 ++++++++++++++++++- tests/unit/metadata/test_multi_table.py | 16 ++++++++++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 194dc7ba7..ff188d704 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -896,6 +896,20 @@ def get_column_names(self, table_name, **kwargs): self._validate_table_exists(table_name) return self.tables[table_name].get_column_names(**kwargs) + def get_table_metadata(self, table_name): + """Return the metadata for a table. + + Args: + table_name (str): + The name of the table to get the metadata for. + + Returns: + SingleTableMetadata: + The metadata for the given table. 
+ """ + self._validate_table_exists(table_name) + return deepcopy(self.tables[table_name]) + def visualize(self, show_table_details='full', show_relationship_labels=True, output_filepath=None): """Create a visualization of the multi-table dataset. diff --git a/tests/integration/metadata/test_multi_table.py b/tests/integration/metadata/test_multi_table.py index 10982ce53..8a7b5d56d 100644 --- a/tests/integration/metadata/test_multi_table.py +++ b/tests/integration/metadata/test_multi_table.py @@ -4,7 +4,7 @@ from unittest.mock import patch from sdv.datasets.demo import download_demo -from sdv.metadata import MultiTableMetadata +from sdv.metadata import MultiTableMetadata, SingleTableMetadata from tests.utils import get_multi_table_metadata @@ -377,3 +377,33 @@ def test_get_column_names(): # Assert assert set(matches) == {'upravna_enota', 'id_nesreca'} + + +def test_get_table_metadata(): + """Test the ``get_table_metadata`` method.""" + # Setup + metadata = get_multi_table_metadata() + metadata.add_column('nesreca', 'latitude', sdtype='latitude') + metadata.add_column('nesreca', 'longitude', sdtype='longitude') + metadata.add_column_relationship('nesreca', 'gps', ['latitude', 'longitude']) + + # Run + table_metadata = metadata.get_table_metadata('nesreca') + + # Assert + assert isinstance(table_metadata, SingleTableMetadata) + expected_metadata = { + 'primary_key': 'id_nesreca', + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'columns': { + 'upravna_enota': {'sdtype': 'id'}, + 'id_nesreca': {'sdtype': 'id'}, + 'nesreca_val': {'sdtype': 'numerical'}, + 'latitude': {'sdtype': 'latitude', 'pii': True}, + 'longitude': {'sdtype': 'longitude', 'pii': True} + }, + 'column_relationships': [ + {'type': 'gps', 'column_names': ['latitude', 'longitude']} + ] + } + assert table_metadata.to_dict() == expected_metadata diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index b6089bdf0..7f53f0cfc 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -2188,6 +2188,22 @@ def test_get_column_names(self): # Assert table1.get_column_names.assert_called_once_with(sdtype='email', pii=True) + @patch('sdv.metadata.multi_table.deepcopy') + def test_get_table_metadata(self, deepcopy_mock): + """Test the ``get_table_metadata`` method.""" + # Setup + metadata = MultiTableMetadata() + metadata._validate_table_exists = Mock() + table1 = Mock() + metadata.tables = {'table1': table1} + + # Run + metadata.get_table_metadata('table1') + + # Assert + metadata._validate_table_exists.assert_called_once_with('table1') + deepcopy_mock.assert_called_once_with(table1) + def test__detect_relationships(self): """Test relationships are automatically detected and the foreign key sdtype is updated.""" # Setup From 275955dcaaa81dcf3385b03965839843b80c255b Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:55:25 +0200 Subject: [PATCH 08/32] Add usage logging (#1920) --- Makefile | 8 +- latest_requirements.txt | 1 + pyproject.toml | 1 + sdv/__init__.py | 5 +- sdv/logging/__init__.py | 9 ++ sdv/logging/sdv_logger_config.yml | 27 ++++ sdv/logging/utils.py | 98 ++++++++++++ sdv/metadata/multi_table.py | 19 +++ sdv/metadata/single_table.py | 12 ++ sdv/multi_table/base.py | 123 +++++++++++++-- sdv/multi_table/hma.py | 18 ++- sdv/single_table/base.py | 93 ++++++++++- sdv/single_table/copulagan.py | 4 +- sdv/single_table/copulas.py | 4 +- sdv/single_table/ctgan.py | 9 +- 
tests/integration/multi_table/test_hma.py | 85 +++++++++- tests/integration/single_table/test_base.py | 78 ++++++++++ tests/unit/logging/__init__.py | 0 tests/unit/logging/test_utils.py | 85 ++++++++++ tests/unit/metadata/test_multi_table.py | 18 ++- tests/unit/metadata/test_single_table.py | 20 ++- tests/unit/multi_table/test_base.py | 163 ++++++++++++++++---- tests/unit/multi_table/test_hma.py | 10 +- tests/unit/single_table/test_base.py | 147 ++++++++++++++---- tests/unit/single_table/test_copulagan.py | 1 + tests/unit/single_table/test_copulas.py | 3 +- tests/unit/single_table/test_ctgan.py | 2 + tests/utils.py | 17 ++ 28 files changed, 951 insertions(+), 109 deletions(-) create mode 100644 sdv/logging/__init__.py create mode 100644 sdv/logging/sdv_logger_config.yml create mode 100644 sdv/logging/utils.py create mode 100644 tests/unit/logging/__init__.py create mode 100644 tests/unit/logging/test_utils.py diff --git a/Makefile b/Makefile index f31c000d1..c953b15b6 100644 --- a/Makefile +++ b/Makefile @@ -123,12 +123,8 @@ test-integration: ## run tests quickly with the default Python test-readme: ## run the readme snippets invoke readme -.PHONY: test-tutorials -test-tutorials: ## run the tutorial notebooks - invoke tutorials - .PHONY: test -test: test-unit test-integration test-readme test-tutorials ## test everything that needs test dependencies +test: test-unit test-integration test-readme ## test everything that needs test dependencies .PHONY: test-all test-all: ## run tests on every Python version with tox @@ -265,5 +261,5 @@ release-major: check-release bumpversion-major release .PHONY: check-deps check-deps: - $(eval allow_list='cloudpickle=|graphviz=|numpy=|pandas=|tqdm=|copulas=|ctgan=|deepecho=|rdt=|sdmetrics=') + $(eval allow_list='cloudpickle=|graphviz=|numpy=|pandas=|tqdm=|copulas=|ctgan=|deepecho=|rdt=|sdmetrics=|platformdirs=') pip freeze | grep -v "SDV.git" | grep -E $(allow_list) | sort > $(OUTPUT_FILEPATH) diff --git a/latest_requirements.txt b/latest_requirements.txt index 6ea9b9347..7c9f63bac 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -8,3 +8,4 @@ pandas==2.2.2 rdt==1.12.0 sdmetrics==0.14.0 tqdm==4.66.2 +platformdirs==4.2.0 diff --git a/pyproject.toml b/pyproject.toml index 901c68965..8e75e0b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ 'deepecho>=0.6.0', 'rdt>=1.12.0', 'sdmetrics>=0.14.0', + 'platformdirs>=4.0' ] [project.urls] diff --git a/sdv/__init__.py b/sdv/__init__.py index 409cb7735..4c29c84de 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -16,8 +16,8 @@ from types import ModuleType from sdv import ( - constraints, data_processing, datasets, evaluation, io, lite, metadata, metrics, multi_table, - sampling, sequential, single_table, version) + constraints, data_processing, datasets, evaluation, io, lite, logging, metadata, metrics, + multi_table, sampling, sequential, single_table, version) __all__ = [ 'constraints', @@ -26,6 +26,7 @@ 'evaluation', 'io', 'lite', + 'logging', 'metadata', 'metrics', 'multi_table', diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py new file mode 100644 index 000000000..436a1a442 --- /dev/null +++ b/sdv/logging/__init__.py @@ -0,0 +1,9 @@ +"""Module for configuring loggers within the SDV library.""" + +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config + +__all__ = ( + 'disable_single_table_logger', + 'get_sdv_logger', + 'get_sdv_logger_config', +) diff --git a/sdv/logging/sdv_logger_config.yml 
b/sdv/logging/sdv_logger_config.yml new file mode 100644 index 000000000..64104495f --- /dev/null +++ b/sdv/logging/sdv_logger_config.yml @@ -0,0 +1,27 @@ +log_registry: 'local' +version: 1 +loggers: + SingleTableSynthesizer: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log + MultiTableSynthesizer: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log + MultiTableMetadata: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log + SingleTableMetadata: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py new file mode 100644 index 000000000..2f6a13be4 --- /dev/null +++ b/sdv/logging/utils.py @@ -0,0 +1,98 @@ +"""Utilities for configuring logging within the SDV library.""" + +import contextlib +import logging +from functools import lru_cache +from pathlib import Path + +import platformdirs +import yaml + + +def get_sdv_logger_config(): + """Return a dictionary with the logging configuration.""" + logging_path = Path(__file__).parent + with open(logging_path / 'sdv_logger_config.yml', 'r') as f: + logger_conf = yaml.safe_load(f) + + # Logfile to be in this same directory + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + store_path.mkdir(parents=True, exist_ok=True) + for logger in logger_conf.get('loggers', {}).values(): + handler = logger.get('handlers', {}) + if handler.get('filename') == 'sdv_logs.log': + handler['filename'] = store_path / handler['filename'] + + return logger_conf + + +@contextlib.contextmanager +def disable_single_table_logger(): + """Temporarily disables logging for the single table synthesizers. + + This context manager temporarily removes all handlers associated with + the ``SingleTableSynthesizer`` logger, disabling logging for that module + within the current context. After the context exits, the + removed handlers are restored to the logger. + """ + # Logging without ``SingleTableSynthesizer`` + single_table_logger = logging.getLogger('SingleTableSynthesizer') + handlers = single_table_logger.handlers + single_table_logger.handlers = [] + try: + yield + finally: + for handler in handlers: + single_table_logger.addHandler(handler) + + +@lru_cache() +def get_sdv_logger(logger_name): + """Get a logger instance with the specified name and configuration. + + This function retrieves or creates a logger instance with the specified name + and applies configuration settings based on the logger's name and the logging + configuration. + + Args: + logger_name (str): + The name of the logger to retrieve or create. + + Returns: + logging.Logger: + A logger instance configured according to the logging configuration + and the specific settings for the given logger name. 
+ """ + logger_conf = get_sdv_logger_config() + if logger_conf.get('log_registry') is None: + # Return a logger without any extra settings and avoid writing into files or other streams + return logging.getLogger(logger_name) + + if logger_conf.get('log_registry') == 'local': + logger = logging.getLogger(logger_name) + if logger_name in logger_conf.get('loggers'): + formatter = None + config = logger_conf.get('loggers').get(logger_name) + log_level = getattr(logging, config.get('level', 'INFO')) + if config.get('format'): + formatter = logging.Formatter(config.get('format')) + + logger.setLevel(log_level) + logger.propagate = config.get('propagate', False) + handler = config.get('handlers') + handlers = handler.get('class') + handlers = [handlers] if isinstance(handlers, str) else handlers + for handler_class in handlers: + if handler_class == 'logging.FileHandler': + logfile = handler.get('filename') + file_handler = logging.FileHandler(logfile) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'): + ch = logging.StreamHandler() + ch.setLevel(log_level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logger diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index ff188d704..b7ccb7046 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1,5 +1,6 @@ """Multi Table Metadata.""" +import datetime import json import logging import warnings @@ -11,6 +12,7 @@ from sdv._utils import _cast_to_iterable, _load_data_from_csv from sdv.errors import InvalidDataError +from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.single_table import SingleTableMetadata @@ -19,6 +21,7 @@ create_columns_node, create_summarized_columns_node, visualize_graph) LOGGER = logging.getLogger(__name__) +MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata') WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format'] @@ -1054,6 +1057,22 @@ def save_to_json(self, filepath): """ validate_file_does_not_exist(filepath) metadata = self.to_dict() + total_columns = 0 + for table in self.tables.values(): + total_columns += len(table.columns) + + MULTITABLEMETADATA_LOGGER.info( + '\nMetadata Save:\n' + ' Timestamp: %s\n' + ' Statistics about the metadata:\n' + ' Total number of tables: %s\n' + ' Total number of columns: %s\n' + ' Total number of relationships: %s', + datetime.datetime.now(), + len(self.tables), + total_columns, + len(self.relationships) + ) with open(filepath, 'w', encoding='utf-8') as metadata_file: json.dump(metadata, metadata_file, indent=4) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index ce81ad3a0..4f8b1db94 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -16,6 +16,7 @@ _cast_to_iterable, _format_invalid_values_string, _get_datetime_format, _is_boolean_type, _is_datetime_type, _is_numerical_type, _load_data_from_csv, _validate_datetime_format) from sdv.errors import InvalidDataError +from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.utils import read_json, validate_file_does_not_exist @@ -23,6 +24,7 @@ create_columns_node, create_summarized_columns_node, visualize_graph) LOGGER = 
logging.getLogger(__name__) +SINGLETABLEMETADATA_LOGGER = get_sdv_logger('SingleTableMetadata') class SingleTableMetadata: @@ -1206,6 +1208,16 @@ def save_to_json(self, filepath): validate_file_does_not_exist(filepath) metadata = self.to_dict() metadata['METADATA_SPEC_VERSION'] = self.METADATA_SPEC_VERSION + SINGLETABLEMETADATA_LOGGER.info( + '\nMetadata Save:\n' + ' Timestamp: %s\n' + ' Statistics about the metadata:\n' + ' Total number of tables: 1' + ' Total number of columns: %s' + ' Total number of relationships: 0', + datetime.now(), + len(self.columns) + ) with open(filepath, 'w', encoding='utf-8') as metadata_file: json.dump(metadata, metadata_file, indent=4) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 751740b3e..4afbee0e9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -16,8 +16,11 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer +SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') + class BaseMultiTableSynthesizer: """Base class for multi table synthesizers. @@ -56,13 +59,15 @@ def _set_temp_numpy_seed(self): np.random.set_state(initial_state) def _initialize_models(self): - for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) - self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, - locales=self.locales, - **synthesizer_parameters - ) + with disable_single_table_logger(): + for table_name, table_metadata in self.metadata.tables.items(): + synthesizer_parameters = self._table_parameters.get(table_name, {}) + self._table_synthesizers[table_name] = self._synthesizer( + metadata=table_metadata, + locales=self.locales, + table_name=table_name, + **synthesizer_parameters + ) def _get_pbar_args(self, **kwargs): """Return a dictionary with the updated keyword args for a progress bar.""" @@ -113,6 +118,15 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) + SYNTHESIZER_LOGGER.info( + '\nInstance:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id + ) def _get_root_parents(self): """Get the set of root parents in the graph.""" @@ -186,6 +200,8 @@ def set_table_parameters(self, table_name, table_parameters): A dictionary with the parameters as keys and the values to be used to instantiate the table's synthesizer. """ + # Ensure that we set the name of the table no matter what + table_parameters.update({'table_name': table_name}) self._table_synthesizers[table_name] = self._synthesizer( metadata=self.metadata.tables[table_name], **table_parameters @@ -371,9 +387,33 @@ def fit_processed_data(self, processed_data): processed_data (dict): Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``. 
""" + total_rows = 0 + total_columns = 0 + for table in processed_data.values(): + total_rows += len(table) + total_columns += len(table.columns) + + SYNTHESIZER_LOGGER.info( + '\nFit processed data:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(processed_data), + total_rows, + total_columns, + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - augmented_data = self._augment_tables(processed_data) - self._model_tables(augmented_data) + with disable_single_table_logger(): + augmented_data = self._augment_tables(processed_data) + self._model_tables(augmented_data) + self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') self._fitted_sdv_version = getattr(version, 'public', None) @@ -387,6 +427,28 @@ def fit(self, data): Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format (before any transformations). """ + total_rows = 0 + total_columns = 0 + for table in data.values(): + total_rows += len(table) + total_columns += len(table.columns) + + SYNTHESIZER_LOGGER.info( + '\nFit:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(data), + total_rows, + total_columns, + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() @@ -419,9 +481,31 @@ def sample(self, scale=1.0): raise SynthesizerInputError( f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.") - with self._set_temp_numpy_seed(): + with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) + total_rows = 0 + total_columns = 0 + for table in sampled_data.values(): + total_rows += len(table) + total_columns += len(table.columns) + + SYNTHESIZER_LOGGER.info( + '\nSample:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the sample size:\n' + ' Total number of tables: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(sampled_data), + total_rows, + total_columns, + self._synthesizer_id, + ) return sampled_data def get_learned_distributions(self, table_name): @@ -586,6 +670,16 @@ def save(self, filepath): filepath (str): Path where the instance will be serialized. 
""" + synthesizer_id = getattr(self, '_synthesizer_id', None) + SYNTHESIZER_LOGGER.info( + '\nSave:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + synthesizer_id + ) with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -609,4 +703,13 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) + SYNTHESIZER_LOGGER.info( + '\nLoad:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) return synthesizer diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 6ce1d37ab..1e9ce6c2a 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -312,9 +312,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: + synthesizer_parameters = self._table_parameters[child_name] + synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - **self._table_parameters[child_name] + **synthesizer_parameters ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) row = synthesizer._get_parameters() @@ -521,7 +523,12 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.tables[child_name] - synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + synthesizer_parameters = self._table_parameters[child_name] + synthesizer_parameters.update({'table_name': child_name}) + synthesizer = self._synthesizer( + table_meta, + **synthesizer_parameters + ) synthesizer._set_parameters(parameters, default_parameters) synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor @@ -615,7 +622,12 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() - synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name]) + synthesizer_parameters = self._table_parameters[table_name] + synthesizer_parameters.update({'table_name': table_name}) + synthesizer = self._synthesizer( + table_meta, + **synthesizer_parameters + ) synthesizer._set_parameters(parameters) try: likelihoods[parent_id] = synthesizer._get_likelihood(table_rows) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 0466dbd42..41d271fd1 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -24,9 +24,12 @@ from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.logging.utils import get_sdv_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path LOGGER = logging.getLogger(__name__) +SYNTHESIZER_LOGGER = get_sdv_logger('SingleTableSynthesizer') + COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 TMP_FILE_NAME = '.sample.csv.temp' @@ -85,7 +88,7 @@ def _check_metadata_updated(self): ) def 
__init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US']): + locales=['en_US'], table_name=None): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata self.metadata.validate() @@ -93,11 +96,13 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self.enforce_min_max_values = enforce_min_max_values self.enforce_rounding = enforce_rounding self.locales = locales + self.table_name = table_name self._data_processor = DataProcessor( metadata=self.metadata, enforce_rounding=self.enforce_rounding, enforce_min_max_values=self.enforce_min_max_values, - locales=self.locales + locales=self.locales, + table_name=self.table_name ) self._fitted = False self._random_state_set = False @@ -107,6 +112,15 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) + SYNTHESIZER_LOGGER.info( + '\nInstance:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id + ) def set_address_columns(self, column_names, anonymization_level='full'): """Set the address multi-column transformer.""" @@ -389,6 +403,22 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ + SYNTHESIZER_LOGGER.info( + '\nFit processed data:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(processed_data), + len(processed_data.columns), + self._synthesizer_id, + ) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: self._fit(processed_data) @@ -405,6 +435,22 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. """ + SYNTHESIZER_LOGGER.info( + '\nFit:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(data), + len(data.columns), + self._synthesizer_id, + ) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False @@ -420,6 +466,17 @@ def save(self, filepath): filepath (str): Path where the synthesizer instance will be serialized. 
""" + synthesizer_id = getattr(self, '_synthesizer_id', None) + SYNTHESIZER_LOGGER.info( + '\nSave:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + synthesizer_id + ) + with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -443,6 +500,17 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) + if synthesizer.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nLoad:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) + return synthesizer @@ -806,11 +874,12 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file pandas.DataFrame: Sampled data. """ + sample_timestamp = datetime.datetime.now() has_constraints = bool(self._data_processor._constraints) has_batches = batch_size is not None and batch_size != num_rows show_progress_bar = has_constraints or has_batches - return self._sample_with_progress_bar( + sampled_data = self._sample_with_progress_bar( num_rows, max_tries_per_batch, batch_size, @@ -818,6 +887,24 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) + SYNTHESIZER_LOGGER.info( + '\nSample:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + sample_timestamp, + self.__class__.__name__, + len(sampled_data), + len(sampled_data.columns), + self._synthesizer_id, + ) + + return sampled_data + def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None): """Sample rows with conditions. 
diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index c9309b45c..63b22d22b 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -121,7 +121,8 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, numerical_distributions=None, default_distribution=None): + pac=10, cuda=True, numerical_distributions=None, default_distribution=None, + table_name=None): super().__init__( metadata, @@ -142,6 +143,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, epochs=epochs, pac=pac, cuda=cuda, + table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 4fc213949..c19b7536d 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -91,12 +91,14 @@ def get_distribution_class(cls, distribution): return cls._DISTRIBUTIONS[distribution] def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], numerical_distributions=None, default_distribution=None): + locales=['en_US'], numerical_distributions=None, default_distribution=None, + table_name=None): super().__init__( metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, locales=locales, + table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) self.numerical_distributions = numerical_distributions or {} diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index c6c5d3d0c..d59c3fca0 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -155,13 +155,14 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True): + pac=10, cuda=True, table_name=None): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, - locales=locales + locales=locales, + table_name=table_name ) self.embedding_dim = embedding_dim @@ -338,12 +339,14 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True): + l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True, + table_name=None): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, + table_name=table_name ) self.embedding_dim = embedding_dim self.compress_dims = compress_dims diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4cc189975..8c2a97ac6 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2,9 +2,12 @@ import importlib.metadata import re import warnings +from pathlib import Path +from unittest.mock import patch import numpy as np import pandas as pd +import 
platformdirs import pytest from faker import Faker from rdt.transformers import FloatFormatter @@ -148,7 +151,8 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, + 'table_name': 'characters' } families_params = hmasynthesizer.get_table_parameters('families') assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -157,7 +161,8 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, + 'table_name': 'families' } char_families_params = hmasynthesizer.get_table_parameters('character_families') assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -166,7 +171,8 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, + 'table_name': 'character_families' } assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' @@ -551,7 +557,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): custom_synthesizer.set_table_parameters( table_name='hotels', table_parameters={ - 'default_distribution': 'truncnorm' + 'default_distribution': 'truncnorm', } ) @@ -1664,3 +1670,74 @@ def test_hma_relationship_validity(): # Assert assert report.get_details('Relationship Validity')['Score'].mean() == 1.0 + + +@patch('sdv.multi_table.base.generate_synthesizer_id') +@patch('sdv.multi_table.base.datetime') +def test_synthesizer_logger(mock_datetime, mock_generate_id): + """Test that the synthesizer logger logs the expected messages.""" + # Setup + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + file_name = 'sdv_logs.log' + + synth_id = 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_generate_id.return_value = synth_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' + data, metadata = download_demo('multi_table', 'fake_hotels') + + # Run + instance = HMASynthesizer(metadata) + + # Assert + with open(store_path / file_name) as f: + instance_lines = f.readlines()[-4:] + + assert ''.join(instance_lines) == ( + 'Instance:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.fit(data) + + # Assert + with open(store_path / file_name) as f: + fit_lines = f.readlines()[-17:] + + assert ''.join(fit_lines) == ( + 'Fit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 15\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 11\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.sample(1) + with open(store_path / file_name) as f: + sample_lines = f.readlines()[-8:] + + # Assert + assert ''.join(sample_lines) == ( + 'Sample:\n' + ' Timestamp: 
2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 15\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 73b25c97c..8c7ea2601 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -2,10 +2,12 @@ import importlib.metadata import re import warnings +from pathlib import Path from unittest.mock import patch import numpy as np import pandas as pd +import platformdirs import pytest from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder @@ -777,3 +779,79 @@ def test_fit_raises_version_error(): ) with pytest.raises(VersionError, match=expected_message): instance.fit(data) + + +@patch('sdv.single_table.base.generate_synthesizer_id') +@patch('sdv.single_table.base.datetime') +def test_synthesizer_logger(mock_datetime, mock_generate_id): + """Test that the synthesizer logger logs the expected messages.""" + # Setup + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + file_name = 'sdv_logs.log' + + synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_generate_id.return_value = synth_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' + data = pd.DataFrame({ + 'col 1': [1, 2, 3], + 'col 2': [4, 5, 6], + 'col 3': ['a', 'b', 'c'], + }) + metadata = SingleTableMetadata() + metadata.detect_from_dataframe(data) + + # Run + instance = GaussianCopulaSynthesizer(metadata) + + # Assert + with open(store_path / file_name) as f: + instance_lines = f.readlines()[-4:] + + assert ''.join(instance_lines) == ( + 'Instance:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.fit(data) + + # Assert + with open(store_path / file_name) as f: + fit_lines = f.readlines()[-17:] + + assert ''.join(fit_lines) == ( + 'Fit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.sample(100) + with open(store_path / file_name) as f: + sample_lines = f.readlines()[-8:] + + assert ''.join(sample_lines) == ( + 'Sample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 100\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) diff --git a/tests/unit/logging/__init__.py b/tests/unit/logging/__init__.py new file mode 100644 index 000000000..e69de29bb 
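
Note: the unit tests below exercise the logging helpers added in sdv/logging/utils.py. A minimal usage sketch, not part of the patch, of the disable_single_table_logger context manager that the multi-table synthesizer uses to silence per-table messages while it drives the individual table synthesizers:

    from sdv.logging import disable_single_table_logger, get_sdv_logger

    logger = get_sdv_logger('SingleTableSynthesizer')
    with disable_single_table_logger():
        # Handlers are detached inside the context, so this is not written to sdv_logs.log.
        logger.info('per-table fit message suppressed')
    # The original handlers are restored once the context exits.
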
diff --git a/tests/unit/logging/test_utils.py b/tests/unit/logging/test_utils.py new file mode 100644 index 000000000..316ae9083 --- /dev/null +++ b/tests/unit/logging/test_utils.py @@ -0,0 +1,85 @@ +"""Test ``SDV`` logging utilities.""" +import logging +from unittest.mock import Mock, mock_open, patch + +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config + + +def test_get_sdv_logger_config(): + """Test the ``get_sdv_logger_config``. + + Test that a ``yaml_content`` is being converted to ``dictionary`` and is returned + by the ``get_sdv_logger_config``. + """ + yaml_content = """ + log_registry: 'local' + loggers: + test_logger: + level: DEBUG + handlers: + class: logging.StreamHandler + """ + # Run + with patch('builtins.open', mock_open(read_data=yaml_content)): + # Test if the function returns a dictionary + logger_conf = get_sdv_logger_config() + + # Assert + assert isinstance(logger_conf, dict) + assert logger_conf == { + 'log_registry': 'local', + 'loggers': { + 'test_logger': { + 'level': 'DEBUG', + 'handlers': { + 'class': 'logging.StreamHandler' + } + } + } + } + + +@patch('sdv.logging.utils.logging.getLogger') +def test_disable_single_table_logger(mock_getlogger): + # Setup + mock_logger = Mock() + handler = Mock() + mock_logger.handlers = [handler] + mock_logger.removeHandler.side_effect = lambda x: mock_logger.handlers.pop(0) + mock_logger.addHandler.side_effect = lambda x: mock_logger.handlers.append(x) + mock_getlogger.return_value = mock_logger + + # Run + with disable_single_table_logger(): + assert len(mock_logger.handlers) == 0 + + # Assert + assert len(mock_logger.handlers) == 1 + + +@patch('sdv.logging.utils.logging.StreamHandler') +@patch('sdv.logging.utils.logging.getLogger') +@patch('sdv.logging.utils.get_sdv_logger_config') +def test_get_sdv_logger(mock_get_sdv_logger_config, mock_getlogger, mock_streamhandler): + # Setup + mock_logger_conf = { + 'log_registry': 'local', + 'loggers': { + 'test_logger': { + 'level': 'DEBUG', + 'handlers': { + 'class': 'logging.StreamHandler' + } + } + } + } + mock_get_sdv_logger_config.return_value = mock_logger_conf + mock_logger_instance = Mock() + mock_getlogger.return_value = mock_logger_instance + + # Run + get_sdv_logger('test_logger') + + # Assert + mock_logger_instance.setLevel.assert_called_once_with(logging.DEBUG) + mock_logger_instance.addHandler.assert_called_once() diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 7f53f0cfc..34004325c 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1,6 +1,7 @@ """Test Multi Table Metadata.""" import json +import logging import re from collections import defaultdict from unittest.mock import Mock, call, patch @@ -12,7 +13,7 @@ from sdv.errors import InvalidDataError from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata, SingleTableMetadata -from tests.utils import get_multi_table_data, get_multi_table_metadata +from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata class TestMultiTableMetadata: @@ -2843,7 +2844,8 @@ def test_save_to_json_file_exists(self, mock_path): with pytest.raises(ValueError, match=error_msg): instance.save_to_json('filepath.json') - def test_save_to_json(self, tmp_path): + @patch('sdv.metadata.multi_table.datetime') + def test_save_to_json(self, mock_datetime, tmp_path, caplog): """Test the ``save_to_json`` method. 
Test that ``save_to_json`` stores a ``json`` file and dumps the instance dict into @@ -2860,16 +2862,26 @@ def test_save_to_json(self, tmp_path): # Setup instance = MultiTableMetadata() instance._reset_updated_flag = Mock() + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run / Assert file_name = tmp_path / 'multitable.json' - instance.save_to_json(file_name) + with catch_sdv_logs(caplog, logging.INFO, logger='MultiTableMetadata'): + instance.save_to_json(file_name) with open(file_name, 'rb') as multi_table_file: saved_metadata = json.load(multi_table_file) assert saved_metadata == instance.to_dict() instance._reset_updated_flag.assert_called_once() + assert caplog.messages[0] == ( + '\nMetadata Save:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Statistics about the metadata:\n' + ' Total number of tables: 0\n' + ' Total number of columns: 0\n' + ' Total number of relationships: 0' + ) def test__convert_relationships(self): """Test the ``_convert_relationships`` method. diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index 1b41414b0..d51b08ecc 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -1,6 +1,7 @@ """Test Single Table Metadata.""" import json +import logging import re import warnings from datetime import datetime @@ -13,6 +14,7 @@ from sdv.errors import InvalidDataError from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.single_table import SingleTableMetadata +from tests.utils import catch_sdv_logs class TestSingleTableMetadata: @@ -2835,7 +2837,8 @@ def test_save_to_json_file_exists(self, mock_path): with pytest.raises(ValueError, match=error_msg): instance.save_to_json('filepath.json') - def test_save_to_json(self, tmp_path): + @patch('sdv.metadata.single_table.datetime') + def test_save_to_json(self, mock_datetime, tmp_path, caplog): """Test the ``save_to_json`` method. Test that ``save_to_json`` stores a ``json`` file and dumps the instance dict into @@ -2850,12 +2853,23 @@ def test_save_to_json(self, tmp_path): - Creates a json representation of the instance. 
""" # Setup + mock_datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = SingleTableMetadata() - # Run / Assert + # Run file_name = tmp_path / 'singletable.json' - instance.save_to_json(file_name) + with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableMetadata'): + instance.save_to_json(file_name) + # Assert + assert caplog.messages[0] == ( + '\nMetadata Save:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Statistics about the metadata:\n' + ' Total number of tables: 1' + ' Total number of columns: 0' + ' Total number of relationships: 0' + ) with open(file_name, 'rb') as single_table_file: saved_metadata = json.load(single_table_file) assert saved_metadata == instance.to_dict() diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 99330715e..b4de98e6a 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1,3 +1,4 @@ +import logging import re import warnings from collections import defaultdict @@ -17,7 +18,7 @@ from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer -from tests.utils import get_multi_table_data, get_multi_table_metadata +from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata class TestBaseMultiTableSynthesizer: @@ -51,9 +52,10 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales), - call(metadata=instance.metadata.tables['oseba'], locales=locales), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales) + locales=locales, table_name='nesreca'), + call(metadata=instance.metadata.tables['oseba'], locales=locales, table_name='oseba'), + call(metadata=instance.metadata.tables['upravna_enota'], locales=locales, + table_name='upravna_enota') ]) def test__get_pbar_args(self): @@ -100,22 +102,26 @@ def test__print(self, mock_print): # Assert mock_print.assert_called_once_with('Fitting', end='') + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.generate_synthesizer_id') @patch('sdv.multi_table.base.BaseMultiTableSynthesizer._check_metadata_updated') - def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_id): + def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_id, + mock_datetime, caplog): """Test that when creating a new instance this sets the defaults. Test that the metadata object is being stored and also being validated. Afterwards, this calls the ``self._initialize_models`` which creates the initial instances of those. 
""" # Setup - synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + synthesizer_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' metadata = get_multi_table_metadata() metadata.validate = Mock() # Run - instance = BaseMultiTableSynthesizer(metadata) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + instance = BaseMultiTableSynthesizer(metadata) # Assert assert instance.metadata == metadata @@ -127,6 +133,11 @@ def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_i mock_check_metadata_updated.assert_called_once() mock_generate_synthesizer_id.assert_called_once_with(instance) assert instance._synthesizer_id == synthesizer_id + assert caplog.messages[0] == ( + '\nInstance:\n Timestamp: 2024-04-19 16:20:10.037183\n Synthesizer class name: ' + 'BaseMultiTableSynthesizer\n Synthesizer id: ' + 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test__init__column_relationship_warning(self): """Test that a warning is raised only once when the metadata has column relationships.""" @@ -269,6 +280,7 @@ def test_get_table_parameters_empty(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], + 'table_name': 'oseba', 'numerical_distributions': {} } } @@ -289,6 +301,7 @@ def test_get_table_parameters_has_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], + 'table_name': 'oseba', 'numerical_distributions': {} } @@ -320,13 +333,17 @@ def test_set_table_parameters(self): # Assert table_parameters = instance.get_table_parameters('oseba') - assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'} + assert instance._table_parameters['oseba'] == { + 'default_distribution': 'gamma', + 'table_name': 'oseba' + } assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer' assert table_parameters['synthesizer_parameters'] == { 'default_distribution': 'gamma', 'enforce_min_max_values': True, 'locales': ['en_US'], 'enforce_rounding': True, + 'table_name': 'oseba', 'numerical_distributions': {} } @@ -818,27 +835,43 @@ def test_preprocess_warning(self, mock_warnings): "please refit the model using 'fit' or 'fit_processed_data'." ) - def test_fit_processed_data(self): + @patch('sdv.multi_table.base.datetime') + def test_fit_processed_data(self, mock_datetime, caplog): """Test that fit processed data calls ``_augment_tables`` and ``_model_tables``. Ensure that the ``fit_processed_data`` augments the tables and then models those using the ``_model_tables`` method. Then sets the state to fitted. 
""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, - _fitted_sdv_enterprise_version=None + _fitted_sdv_enterprise_version=None, + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) - data = Mock() - data.copy.return_value = data + processed_data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert - instance._augment_tables.assert_called_once_with(data) + instance._augment_tables.assert_called_once_with(processed_data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted + assert caplog.messages[0] == ( + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 6\n' + ' Total number of columns: 4\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_fit_processed_data_empty_table(self): """Test attributes are properly set when data is empty and that _fit is not called.""" @@ -847,10 +880,13 @@ def test_fit_processed_data_empty_table(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None ) - data = pd.DataFrame() + processed_data = { + 'table1': pd.DataFrame(), + 'table2': pd.DataFrame() + } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance._fit.assert_not_called() @@ -866,7 +902,10 @@ def test_fit_processed_data_raises_version_error(self): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + processed_data = { + 'table1': pd.DataFrame(), + 'table2': pd.DataFrame() + } # Run and Assert error_msg = ( @@ -875,32 +914,49 @@ def test_fit_processed_data_raises_version_error(self): 'Please create a new synthesizer.' 
) with pytest.raises(VersionError, match=error_msg): - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance.preprocess.assert_not_called() instance.fit_processed_data.assert_not_called() instance._check_metadata_updated.assert_not_called() + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base._validate_foreign_keys_not_null') - def test_fit(self, mock_validate_foreign_keys_not_null): + def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, - _fitted_sdv_enterprise_version=None + _fitted_sdv_enterprise_version=None, + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) instance.metadata = Mock() - data = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run - BaseMultiTableSynthesizer.fit(instance, data) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + BaseMultiTableSynthesizer.fit(instance, data) # Assert mock_validate_foreign_keys_not_null.assert_called_once_with(instance.metadata, data) instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() + assert caplog.messages[0] == ( + '\nFit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 6\n' + ' Total number of columns: 4\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_fit_raises_version_error(self): """Test that fit will raise a ``VersionError`` if the current version is bigger.""" @@ -910,7 +966,10 @@ def test_fit_raises_version_error(self): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run and Assert error_msg = ( @@ -986,18 +1045,38 @@ def test_sample_validate_input(self): with pytest.raises(SynthesizerInputError, match=msg): instance.sample(scale=scale) - def test_sample(self): + @patch('sdv.multi_table.base.datetime') + def test_sample(self, mock_datetime, caplog): """Test that ``sample`` calls the ``_sample`` with the given arguments.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) - instance._sample = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } + instance._sample = Mock(return_value=data) + + synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + instance._synthesizer_id = synth_id # Run - instance.sample(scale=1.5) + with catch_sdv_logs(caplog, logging.INFO, logger='MultiTableSynthesizer'): + instance.sample(scale=1.5) # Assert instance._sample.assert_called_once_with(scale=1.5) + assert caplog.messages[0] == ( + '\nSample:\n' + ' 
Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: BaseMultiTableSynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 6\n' + ' Total number of columns: 4\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_get_learned_distributions_raises_an_unfitted_error(self): """Test that ``get_learned_distributions`` raises an error when model is not fitted.""" @@ -1386,19 +1465,31 @@ def test_get_info_with_enterprise(self, mock_version): 'fitted_sdv_enterprise_version': '1.1.0' } + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.cloudpickle') - def test_save(self, cloudpickle_mock, tmp_path): + def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup - synthesizer = Mock() + synthesizer = Mock( + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run filepath = tmp_path / 'output.pkl' - BaseMultiTableSynthesizer.save(synthesizer, filepath) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + BaseMultiTableSynthesizer.save(synthesizer, filepath) # Assert cloudpickle_mock.dump.assert_called_once_with(synthesizer, ANY) + assert caplog.messages[0] == ( + '\nSave:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.generate_synthesizer_id') @patch('sdv.multi_table.base.check_synthesizer_version') @patch('sdv.multi_table.base.check_sdv_versions_and_warn') @@ -1406,16 +1497,18 @@ def test_save(self, cloudpickle_mock, tmp_path): @patch('builtins.open', new_callable=mock_open) def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_warn, mock_check_synthesizer_version, - mock_generate_synthesizer_id): + mock_generate_synthesizer_id, mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup - synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + synthesizer_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' mock_generate_synthesizer_id.return_value = synthesizer_id synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) cloudpickle_mock.load.return_value = synthesizer_mock # Run - loaded_instance = BaseMultiTableSynthesizer.load('synth.pkl') + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + loaded_instance = BaseMultiTableSynthesizer.load('synth.pkl') # Assert mock_file.assert_called_once_with('synth.pkl', 'rb') @@ -1425,3 +1518,9 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_synthesizer_version.assert_called_once_with(synthesizer_mock) assert loaded_instance._synthesizer_id == synthesizer_id mock_generate_synthesizer_id.assert_called_once_with(synthesizer_mock) + assert caplog.messages[0] == ( + '\nLoad:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 3dab339ba..6bbf7ef6f 100644 --- a/tests/unit/multi_table/test_hma.py +++ 
b/tests/unit/multi_table/test_hma.py @@ -502,7 +502,7 @@ def test__recreate_child_synthesizer(self): # Assert assert synthesizer == instance._synthesizer.return_value assert synthesizer._data_processor == table_synthesizer._data_processor - instance._synthesizer.assert_called_once_with(table_meta, a=1) + instance._synthesizer.assert_called_once_with(table_meta, table_name='users', a=1) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value, {'colA': 'default_param', 'colB': 'default_param'} @@ -674,7 +674,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) @@ -823,7 +823,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): table_name='child_uniform', table_parameters={'default_distribution': 'uniform'} ) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) distributions = synthesizer._get_distributions() # Run estimation @@ -953,7 +953,7 @@ def test__estimate_num_columns_to_be_modeled(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) @@ -1068,7 +1068,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 509cee456..289cf7519 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1,3 +1,4 @@ +import logging import re from datetime import date, datetime from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch @@ -17,6 +18,7 @@ from sdv.single_table import ( CopulaGANSynthesizer, CTGANSynthesizer, GaussianCopulaSynthesizer, TVAESynthesizer) from sdv.single_table.base import COND_IDX, BaseSingleTableSynthesizer +from tests.utils import catch_sdv_logs class TestBaseSingleTableSynthesizer: @@ -59,19 +61,22 @@ def test__check_metadata_updated(self): # Assert instance.metadata._updated = False + @patch('sdv.single_table.base.datetime') @patch('sdv.single_table.base.generate_synthesizer_id') @patch('sdv.single_table.base.DataProcessor') @patch('sdv.single_table.base.BaseSingleTableSynthesizer._check_metadata_updated') def test___init__(self, mock_check_metadata_updated, mock_data_processor, - mock_generate_synthesizer_id): + mock_generate_synthesizer_id, mock_datetime, caplog): """Test instantiating with default values.""" # Setup metadata = Mock() synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run - instance = BaseSingleTableSynthesizer(metadata) + with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableSynthesizer'): + instance = BaseSingleTableSynthesizer(metadata) # Assert assert instance.enforce_min_max_values is True @@ -84,11 +89,17 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor, metadata=metadata, 
enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales + locales=instance.locales, + table_name=None ) metadata.validate.assert_called_once_with() mock_check_metadata_updated.assert_called_once() mock_generate_synthesizer_id.assert_called_once_with(instance) + assert caplog.messages[0] == ( + '\nInstance:\n Timestamp: 2024-04-19 16:20:10.037183\n Synthesizer class name: ' + 'BaseSingleTableSynthesizer\n Synthesizer id: ' + 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) @patch('sdv.single_table.base.DataProcessor') def test___init__custom(self, mock_data_processor): @@ -113,7 +124,8 @@ def test___init__custom(self, mock_data_processor): metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales + locales=instance.locales, + table_name=None ) metadata.validate.assert_called_once_with() @@ -172,7 +184,8 @@ def test_get_parameters(self, mock_data_processor): assert parameters == { 'enforce_min_max_values': False, 'enforce_rounding': False, - 'locales': 'en_CA' + 'locales': 'en_CA', + 'table_name': None } @patch('sdv.single_table.base.DataProcessor') @@ -341,21 +354,35 @@ def test__fit(self, mock_data_processor): with pytest.raises(NotImplementedError, match=''): instance._fit(data) - def test_fit_processed_data(self): + @patch('sdv.single_table.base.datetime') + def test_fit_processed_data(self, mock_datetime, caplog): """Test that ``fit_processed_data`` calls the ``_fit``.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None ) - processed_data = Mock() - processed_data.empty = False + processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) # Run - BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data) + with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'): + BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance._fit.assert_called_once_with(processed_data) + assert caplog.messages[0] == ( + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 1\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_fit_processed_data_raises_version_error(self): """Test that ``fit`` raises ``VersionError`` @@ -367,8 +394,9 @@ def test_fit_processed_data_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, + table_name=None ) - processed_data = Mock() + processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True instance._fitted = True @@ -381,30 +409,45 @@ def test_fit_processed_data_raises_version_error(self): with pytest.raises(VersionError, match=error_msg): BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data) - def test_fit(self): + @patch('sdv.single_table.base.datetime') + def test_fit(self, mock_datetime, caplog): """Test that ``fit`` calls ``preprocess`` and the ``fit_processed_data``. 
When fitting, the synthsizer has to ``preprocess`` the data and with the output of this method, call the ``fit_processed_data`` """ # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None ) - processed_data = Mock() + data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) instance._random_state_set = True instance._fitted = True # Run - BaseSingleTableSynthesizer.fit(instance, processed_data) + with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'): + BaseSingleTableSynthesizer.fit(instance, data) # Assert assert instance._random_state_set is False instance._data_processor.reset_sampling.assert_called_once_with() - instance._preprocess.assert_called_once_with(processed_data) + instance._preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance._preprocess.return_value) instance._check_metadata_updated.assert_called_once() + assert caplog.messages[0] == ( + '\nFit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 2\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_fit_raises_version_error(self): """Test that ``fit`` raises ``VersionError`` @@ -416,8 +459,9 @@ def test_fit_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, + table_name=None ) - processed_data = Mock() + data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True instance._fitted = True @@ -428,7 +472,7 @@ def test_fit_raises_version_error(self): 'create a new synthesizer.' 
) with pytest.raises(VersionError, match=error_msg): - BaseSingleTableSynthesizer.fit(instance, processed_data) + BaseSingleTableSynthesizer.fit(instance, data) def test__validate_constraints(self): """Test that ``_validate_constraints`` calls ``fit`` and returns any errors.""" @@ -1362,24 +1406,31 @@ def test__sample_with_progress_bar_removing_temp_file( mock_os.remove.assert_called_once_with('.sample.csv.temp') mock_os.path.exists.assert_called_once_with('.sample.csv.temp') - def test_sample(self): + @patch('sdv.single_table.base.datetime') + def test_sample(self, mock_datetime, caplog): """Test that we use ``_sample_with_progress_bar`` in this method.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' num_rows = 10 max_tries_per_batch = 50 batch_size = 5 output_file_path = 'temp.csv' - instance = Mock() + instance = Mock( + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None + ) instance.get_metadata.return_value._constraints = False + instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) # Run - result = BaseSingleTableSynthesizer.sample( - instance, - num_rows, - max_tries_per_batch, - batch_size, - output_file_path, - ) + with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableSynthesizer'): + result = BaseSingleTableSynthesizer.sample( + instance, + num_rows, + max_tries_per_batch, + batch_size, + output_file_path, + ) # Assert instance._sample_with_progress_bar.assert_called_once_with( @@ -1389,7 +1440,17 @@ def test_sample(self): 'temp.csv', show_progress_bar=True ) - assert result == instance._sample_with_progress_bar.return_value + pd.testing.assert_frame_equal(result, pd.DataFrame({'col': [1, 2, 3]})) + assert caplog.messages[0] == ( + '\nSample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 1\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test__validate_conditions_unseen_columns(self): """Test that conditions are within the ``data_processor`` fields.""" @@ -1742,35 +1803,51 @@ def test__validate_known_columns_a_few_nans(self): with pytest.warns(UserWarning, match=warn_msg): synthesizer._validate_known_columns(conditions) + @patch('sdv.single_table.base.datetime') @patch('sdv.single_table.base.cloudpickle') - def test_save(self, cloudpickle_mock, tmp_path): + def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup - synthesizer = Mock() + synthesizer = Mock( + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None + ) + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run filepath = tmp_path / 'output.pkl' - BaseSingleTableSynthesizer.save(synthesizer, filepath) + with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'): + BaseSingleTableSynthesizer.save(synthesizer, filepath) # Assert cloudpickle_mock.dump.assert_called_once_with(synthesizer, ANY) + assert caplog.messages[0] == ( + '\nSave:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + @patch('sdv.single_table.base.datetime') @patch('sdv.single_table.base.generate_synthesizer_id') 
@patch('sdv.single_table.base.check_synthesizer_version') @patch('sdv.single_table.base.check_sdv_versions_and_warn') @patch('sdv.single_table.base.cloudpickle') @patch('builtins.open', new_callable=mock_open) def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_warn, - mock_check_synthesizer_version, mock_generate_synthesizer_id): + mock_check_synthesizer_version, mock_generate_synthesizer_id, + mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup - synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) + synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None, table_name=None) + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id cloudpickle_mock.load.return_value = synthesizer_mock # Run - loaded_instance = BaseSingleTableSynthesizer.load('synth.pkl') + with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'): + loaded_instance = BaseSingleTableSynthesizer.load('synth.pkl') # Assert mock_file.assert_called_once_with('synth.pkl', 'rb') @@ -1780,6 +1857,12 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war assert loaded_instance._synthesizer_id == synthesizer_id mock_check_synthesizer_version.assert_called_once_with(synthesizer_mock) mock_generate_synthesizer_id.assert_called_once_with(synthesizer_mock) + assert caplog.messages[0] == ( + '\nLoad:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_load_custom_constraint_classes(self): """Test that ``load_custom_constraint_classes`` calls the ``DataProcessor``'s method.""" diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index a818dc85a..762e28f58 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -175,6 +175,7 @@ def test_get_params(self): 'cuda': True, 'numerical_distributions': {}, 'default_distribution': 'beta', + 'table_name': None } @patch('sdv.single_table.copulagan.rdt') diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 3c96028c3..02ec24b14 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -130,7 +130,8 @@ def test_get_parameters(self): 'enforce_rounding': True, 'locales': ['en_US'], 'numerical_distributions': {}, - 'default_distribution': 'beta' + 'default_distribution': 'beta', + 'table_name': None } @patch('sdv.single_table.copulas.LOGGER') diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index 967823094..e18e27552 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -151,6 +151,7 @@ def test_get_parameters(self): 'epochs': 300, 'pac': 10, 'cuda': True, + 'table_name': None } def test__estimate_num_columns(self): @@ -426,6 +427,7 @@ def test_get_parameters(self): 'epochs': 300, 'loss_factor': 2, 'cuda': True, + 'table_name': None } @patch('sdv.single_table.ctgan.TVAE') diff --git a/tests/utils.py b/tests/utils.py index a1d819eb7..bf13b9f02 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,9 @@ """Utils for testing.""" +import contextlib + import pandas as pd +from sdv.logging import get_sdv_logger from 
sdv.metadata.multi_table import MultiTableMetadata @@ -99,3 +102,17 @@ def get_multi_table_data(): } return data + + +@contextlib.contextmanager +def catch_sdv_logs(caplog, level, logger): + """Context manager to capture logs from an SDV logger.""" + logger = get_sdv_logger(logger) + orig_level = logger.level + logger.setLevel(level) + logger.addHandler(caplog.handler) + try: + yield + finally: + logger.setLevel(orig_level) + logger.removeHandler(caplog.handler) From 8aa3b5ef86f807fbd519d9709c22740b041d8852 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Mon, 29 Apr 2024 12:50:04 +0200 Subject: [PATCH 09/32] Include `sdv_logger_config.yml` with the package (#1963) --- pyproject.toml | 9 +++++---- sdv/multi_table/base.py | 3 --- sdv/multi_table/hma.py | 12 +++--------- sdv/sequential/par.py | 2 +- sdv/single_table/base.py | 23 ++++++++++------------- sdv/single_table/copulagan.py | 4 +--- sdv/single_table/copulas.py | 4 +--- sdv/single_table/ctgan.py | 7 ++----- tests/integration/multi_table/test_hma.py | 9 +++------ tests/unit/multi_table/test_base.py | 15 ++++----------- tests/unit/multi_table/test_hma.py | 2 +- tests/unit/single_table/test_base.py | 19 +++++-------------- tests/unit/single_table/test_copulagan.py | 3 +-- tests/unit/single_table/test_copulas.py | 3 +-- tests/unit/single_table/test_ctgan.py | 6 ++---- 15 files changed, 40 insertions(+), 81 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e75e0b4b..9326308d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,11 +25,11 @@ dependencies = [ 'botocore>=1.31', 'cloudpickle>=2.1.0', 'graphviz>=0.13.2', - "numpy>=1.20.0;python_version<'3.10'", + "numpy>=1.21.0;python_version<'3.10'", "numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'", "numpy>=1.26.0,<2;python_version>='3.12'", - "pandas>=1.1.3;python_version<'3.10'", - "pandas>=1.3.4;python_version>='3.10' and python_version<'3.11'", + "pandas>=1.4.0;python_version<'3.10'", + "pandas>=1.4.0;python_version>='3.10' and python_version<'3.11'", "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'", "pandas>=2.1.1;python_version>='3.12'", 'tqdm>=4.29', @@ -141,7 +141,8 @@ namespaces = false 'make.bat', '*.jpg', '*.png', - '*.gif' + '*.gif', + 'sdv_logger_config.yml' ] [tool.setuptools.exclude-package-data] diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 4afbee0e9..00efe700e 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -65,7 +65,6 @@ def _initialize_models(self): self._table_synthesizers[table_name] = self._synthesizer( metadata=table_metadata, locales=self.locales, - table_name=table_name, **synthesizer_parameters ) @@ -200,8 +199,6 @@ def set_table_parameters(self, table_name, table_parameters): A dictionary with the parameters as keys and the values to be used to instantiate the table's synthesizer. 
""" - # Ensure that we set the name of the table no matter what - table_parameters.update({'table_name': table_name}) self._table_synthesizers[table_name] = self._synthesizer( metadata=self.metadata.tables[table_name], **table_parameters diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 1e9ce6c2a..9f4d5da30 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -312,11 +312,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: - synthesizer_parameters = self._table_parameters[child_name] - synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - **synthesizer_parameters + **self._table_parameters[child_name] ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) row = synthesizer._get_parameters() @@ -523,11 +521,9 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.tables[child_name] - synthesizer_parameters = self._table_parameters[child_name] - synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - **synthesizer_parameters + **self._table_parameters[child_name] ) synthesizer._set_parameters(parameters, default_parameters) synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor @@ -622,11 +618,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() - synthesizer_parameters = self._table_parameters[table_name] - synthesizer_parameters.update({'table_name': table_name}) synthesizer = self._synthesizer( table_meta, - **synthesizer_parameters + **self._table_parameters[table_name] ) synthesizer._set_parameters(parameters) try: diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index e06c96864..4c7a80a36 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -92,7 +92,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, - locales=locales + locales=locales, ) sequence_key = self.metadata.sequence_key diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 41d271fd1..60656a4a2 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -88,7 +88,7 @@ def _check_metadata_updated(self): ) def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], table_name=None): + locales=['en_US']): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata self.metadata.validate() @@ -96,13 +96,11 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self.enforce_min_max_values = enforce_min_max_values self.enforce_rounding = enforce_rounding self.locales = locales - self.table_name = table_name self._data_processor = DataProcessor( metadata=self.metadata, enforce_rounding=self.enforce_rounding, enforce_min_max_values=self.enforce_min_max_values, locales=self.locales, - table_name=self.table_name ) self._fitted = False self._random_state_set = False @@ -500,16 +498,15 @@ def load(cls, 
filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - if synthesizer.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nLoad:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - synthesizer.__class__.__name__, - synthesizer._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info( + '\nLoad:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) return synthesizer diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index 63b22d22b..c9309b45c 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -121,8 +121,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, numerical_distributions=None, default_distribution=None, - table_name=None): + pac=10, cuda=True, numerical_distributions=None, default_distribution=None): super().__init__( metadata, @@ -143,7 +142,6 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, epochs=epochs, pac=pac, cuda=cuda, - table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index c19b7536d..4fc213949 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -91,14 +91,12 @@ def get_distribution_class(cls, distribution): return cls._DISTRIBUTIONS[distribution] def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], numerical_distributions=None, default_distribution=None, - table_name=None): + locales=['en_US'], numerical_distributions=None, default_distribution=None): super().__init__( metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, locales=locales, - table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) self.numerical_distributions = numerical_distributions or {} diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index d59c3fca0..860c66487 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -155,14 +155,13 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, table_name=None): + pac=10, cuda=True): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, locales=locales, - table_name=table_name ) self.embedding_dim = embedding_dim @@ -339,14 +338,12 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True, - table_name=None): + l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True): super().__init__( 
metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, - table_name=table_name ) self.embedding_dim = embedding_dim self.compress_dims = compress_dims diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 8c2a97ac6..f7c499d28 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -151,8 +151,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {}, - 'table_name': 'characters' + 'numerical_distributions': {} } families_params = hmasynthesizer.get_table_parameters('families') assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -161,8 +160,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {}, - 'table_name': 'families' + 'numerical_distributions': {} } char_families_params = hmasynthesizer.get_table_parameters('character_families') assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -171,8 +169,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {}, - 'table_name': 'character_families' + 'numerical_distributions': {} } assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index b4de98e6a..0cad058b0 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -52,10 +52,9 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales, table_name='nesreca'), - call(metadata=instance.metadata.tables['oseba'], locales=locales, table_name='oseba'), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales, - table_name='upravna_enota') + locales=locales), + call(metadata=instance.metadata.tables['oseba'], locales=locales), + call(metadata=instance.metadata.tables['upravna_enota'], locales=locales) ]) def test__get_pbar_args(self): @@ -280,7 +279,6 @@ def test_get_table_parameters_empty(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'table_name': 'oseba', 'numerical_distributions': {} } } @@ -301,7 +299,6 @@ def test_get_table_parameters_has_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'table_name': 'oseba', 'numerical_distributions': {} } @@ -333,17 +330,13 @@ def test_set_table_parameters(self): # Assert table_parameters = instance.get_table_parameters('oseba') - assert instance._table_parameters['oseba'] == { - 'default_distribution': 'gamma', - 'table_name': 'oseba' - } + assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'} assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer' assert table_parameters['synthesizer_parameters'] == { 'default_distribution': 'gamma', 'enforce_min_max_values': True, 'locales': ['en_US'], 'enforce_rounding': True, - 'table_name': 'oseba', 'numerical_distributions': {} } diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 6bbf7ef6f..c40e7b080 100644 --- a/tests/unit/multi_table/test_hma.py +++ 
b/tests/unit/multi_table/test_hma.py @@ -502,7 +502,7 @@ def test__recreate_child_synthesizer(self): # Assert assert synthesizer == instance._synthesizer.return_value assert synthesizer._data_processor == table_synthesizer._data_processor - instance._synthesizer.assert_called_once_with(table_meta, table_name='users', a=1) + instance._synthesizer.assert_called_once_with(table_meta, a=1) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value, {'colA': 'default_param', 'colB': 'default_param'} diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 289cf7519..197141e69 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -89,8 +89,7 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor, metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales, - table_name=None + locales=instance.locales ) metadata.validate.assert_called_once_with() mock_check_metadata_updated.assert_called_once() @@ -124,8 +123,7 @@ def test___init__custom(self, mock_data_processor): metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales, - table_name=None + locales=instance.locales ) metadata.validate.assert_called_once_with() @@ -184,8 +182,7 @@ def test_get_parameters(self, mock_data_processor): assert parameters == { 'enforce_min_max_values': False, 'enforce_rounding': False, - 'locales': 'en_CA', - 'table_name': None + 'locales': 'en_CA' } @patch('sdv.single_table.base.DataProcessor') @@ -362,8 +359,7 @@ def test_fit_processed_data(self, mock_datetime, caplog): instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, - _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) @@ -394,7 +390,6 @@ def test_fit_processed_data_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, - table_name=None ) processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True @@ -422,7 +417,6 @@ def test_fit(self, mock_datetime, caplog): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None ) data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) instance._random_state_set = True @@ -459,7 +453,6 @@ def test_fit_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, - table_name=None ) data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True @@ -1417,7 +1410,6 @@ def test_sample(self, mock_datetime, caplog): output_file_path = 'temp.csv' instance = Mock( _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None ) instance.get_metadata.return_value._constraints = False instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) @@ -1810,7 +1802,6 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): # Setup synthesizer = Mock( _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None ) 
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' @@ -1839,7 +1830,7 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup - synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None, table_name=None) + synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index 762e28f58..6909c86d2 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -174,8 +174,7 @@ def test_get_params(self): 'pac': 10, 'cuda': True, 'numerical_distributions': {}, - 'default_distribution': 'beta', - 'table_name': None + 'default_distribution': 'beta' } @patch('sdv.single_table.copulagan.rdt') diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 02ec24b14..3c96028c3 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -130,8 +130,7 @@ def test_get_parameters(self): 'enforce_rounding': True, 'locales': ['en_US'], 'numerical_distributions': {}, - 'default_distribution': 'beta', - 'table_name': None + 'default_distribution': 'beta' } @patch('sdv.single_table.copulas.LOGGER') diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index e18e27552..ddbdfc91c 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -150,8 +150,7 @@ def test_get_parameters(self): 'verbose': False, 'epochs': 300, 'pac': 10, - 'cuda': True, - 'table_name': None + 'cuda': True } def test__estimate_num_columns(self): @@ -426,8 +425,7 @@ def test_get_parameters(self): 'batch_size': 500, 'epochs': 300, 'loss_factor': 2, - 'cuda': True, - 'table_name': None + 'cuda': True } @patch('sdv.single_table.ctgan.TVAE') From be3b72ca1774cbcf64ebb8b55364f9ec1d68474a Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Mon, 29 Apr 2024 22:31:55 +0200 Subject: [PATCH 10/32] Implement CSVHandler (#1958) --- sdv/io/local/__init__.py | 8 + sdv/io/local/local.py | 194 +++++++++++++++++++++ tests/integration/io/__init__.py | 0 tests/integration/io/local/__init__.py | 0 tests/integration/io/local/test_local.py | 32 ++++ tests/unit/io/__init__.py | 0 tests/unit/io/local/__init__.py | 0 tests/unit/io/local/test_local.py | 212 +++++++++++++++++++++++ 8 files changed, 446 insertions(+) create mode 100644 sdv/io/local/__init__.py create mode 100644 sdv/io/local/local.py create mode 100644 tests/integration/io/__init__.py create mode 100644 tests/integration/io/local/__init__.py create mode 100644 tests/integration/io/local/test_local.py create mode 100644 tests/unit/io/__init__.py create mode 100644 tests/unit/io/local/__init__.py create mode 100644 tests/unit/io/local/test_local.py diff --git a/sdv/io/local/__init__.py b/sdv/io/local/__init__.py new file mode 100644 index 000000000..a233b25be --- /dev/null +++ b/sdv/io/local/__init__.py @@ -0,0 +1,8 @@ +"""Local I/O module.""" + +from sdv.io.local.local import BaseLocalHandler, CSVHandler + +__all__ = ( + 'BaseLocalHandler', + 'CSVHandler' +) diff --git a/sdv/io/local/local.py 
b/sdv/io/local/local.py new file mode 100644 index 000000000..0d81ab634 --- /dev/null +++ b/sdv/io/local/local.py @@ -0,0 +1,194 @@ +"""Local file handlers.""" +import codecs +import inspect +import os +from pathlib import Path + +import pandas as pd + +from sdv.metadata import MultiTableMetadata + + +class BaseLocalHandler: + """Base class for local handlers.""" + + def __init__(self, decimal='.', float_format=None): + self.decimal = decimal + self.float_format = float_format + + def _infer_metadata(self, data): + """Detect the metadata for all tables in a dictionary of dataframes. + + Args: + data (dict): + Dictionary of table names to dataframes. + + Returns: + MultiTableMetadata: + An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata + properties from the data. + """ + metadata = MultiTableMetadata() + metadata.detect_from_dataframes(data) + return metadata + + def read(self): + """Read data from files and returns it along with metadata. + + This method must be implemented by subclasses. + + Returns: + tuple: + A tuple containing the read data as a dictionary and metadata. The dictionary maps + table names to pandas DataFrames. The metadata is an object describing the data. + """ + raise NotImplementedError() + + def write(self): + """Write data to files. + + This method must be implemented by subclasses. + """ + raise NotImplementedError() + + +class CSVHandler(BaseLocalHandler): + """A class for handling CSV files. + + Args: + sep (str): + The separator used for reading and writing CSV files. Defaults to ``,``. + encoding (str): + The character encoding to use for reading and writing CSV files. Defaults to ``UTF``. + decimal (str): + The character used to denote the decimal point. Defaults to ``.``. + float_format (str or None): + The formatting string for floating-point numbers. Optional. + quotechar (str): + Character used to denote the start and end of a quoted item. + Quoted items can include the delimiter and it will be ignored. Defaults to '"'. + quoting (int or None): + Control field quoting behavior. Default is 0. + + Raises: + ValueError: + If the provided encoding is not available in the system. + """ + + def __init__(self, sep=',', encoding='UTF', decimal='.', float_format=None, + quotechar='"', quoting=0): + super().__init__(decimal, float_format) + try: + codecs.lookup(encoding) + except LookupError as error: + raise ValueError( + f"The provided encoding '{encoding}' is not available in your system." + ) from error + + self.sep = sep + self.encoding = encoding + self.quotechar = quotechar + self.quoting = quoting + + def read(self, folder_name, file_names=None): + """Read data from CSV files and returns it along with metadata. + + Args: + folder_name (str): + The name of the folder containing CSV files. + file_names (list of str, optional): + The names of CSV files to read. If None, all files ending with '.csv' + in the folder are read. + + Returns: + tuple: + A tuple containing the data as a dictionary and metadata. The dictionary maps + table names to pandas DataFrames. The metadata is an object describing the data. + + Raises: + FileNotFoundError: + If the specified files do not exist in the folder. 
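A short usage sketch of this reader, assuming a local folder named 'data_folder' that contains a 'users.csv' file (both names are illustrative):

from sdv.io.local import CSVHandler

handler = CSVHandler()  # defaults: sep=',', encoding='UTF', decimal='.'
data, metadata = handler.read('data_folder', file_names=['users.csv'])
# data maps the table name 'users' to a pandas DataFrame;
# metadata is the MultiTableMetadata detected from the loaded tables.
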
+ """ + data = {} + metadata = MultiTableMetadata() + + folder_path = Path(folder_name) + + if file_names is None: + # If file_names is None, read all files in the folder ending with ".csv" + file_paths = folder_path.glob('*.csv') + else: + # Validate if the given files exist in the folder + file_names = file_names + missing_files = [ + file + for file in file_names + if not (folder_path / file).exists() + ] + if missing_files: + raise FileNotFoundError( + f"The following files do not exist in the folder: {', '.join(missing_files)}." + ) + + file_paths = [folder_path / file for file in file_names] + + # Read CSV files + kwargs = { + 'sep': self.sep, + 'encoding': self.encoding, + 'parse_dates': False, + 'low_memory': False, + 'decimal': self.decimal, + 'on_bad_lines': 'warn', + 'quotechar': self.quotechar, + 'quoting': self.quoting + } + + args = inspect.getfullargspec(pd.read_csv) + if 'on_bad_lines' not in args.kwonlyargs: + kwargs.pop('on_bad_lines') + kwargs['error_bad_lines'] = False + + for file_path in file_paths: + table_name = file_path.stem # Remove file extension to get table name + data[table_name] = pd.read_csv( + file_path, + **kwargs + ) + + metadata = self._infer_metadata(data) + return data, metadata + + def write(self, synthetic_data, folder_name, file_name_suffix=None, mode='x'): + """Write synthetic data to CSV files. + + Args: + synthetic_data (dict): + A dictionary mapping table names to pandas DataFrames containing synthetic data. + folder_name (str): + The name of the folder to write CSV files to. + file_name_suffix (str, optional): + An optional suffix to add to each file name. If ``None``, no suffix is added. + mode (str, optional): + The mode of writing to use. Defaults to 'x'. + 'x': Write to new files, raising errors if existing files exist with the same name. + 'w': Write to new files, clearing any existing files that exist. + 'a': Append the new CSV rows to any existing files. 
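The matching write call, sketched with an illustrative folder name and a one-table dictionary; mode='w' is the overwrite behaviour listed above:

import pandas as pd

from sdv.io.local import CSVHandler

handler = CSVHandler()
synthetic_data = {'users': pd.DataFrame({'user_id': [1, 2], 'name': ['A', 'B']})}

# Creates 'output_folder/users_synthetic.csv', replacing it if it already exists.
handler.write(synthetic_data, 'output_folder', file_name_suffix='_synthetic', mode='w')
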
+ """ + folder_path = Path(folder_name) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + for table_name, table_data in synthetic_data.items(): + file_name = f'{table_name}{file_name_suffix}' if file_name_suffix else f'{table_name}' + file_path = f'{folder_path / file_name}.csv' + table_data.to_csv( + file_path, + sep=self.sep, + encoding=self.encoding, + index=False, + float_format=self.float_format, + quotechar=self.quotechar, + quoting=self.quoting, + mode=mode, + ) diff --git a/tests/integration/io/__init__.py b/tests/integration/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/io/local/__init__.py b/tests/integration/io/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py new file mode 100644 index 000000000..87b3c80ea --- /dev/null +++ b/tests/integration/io/local/test_local.py @@ -0,0 +1,32 @@ +import pandas as pd + +from sdv.io.local import CSVHandler +from sdv.metadata import MultiTableMetadata + + +class TestCSVHandler: + + def test_integration_read_write(self, tmpdir): + """Test end to end the read and write methods of ``CSVHandler``.""" + # Prepare synthetic data + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + # Write synthetic data to CSV files + handler = CSVHandler() + handler.write(synthetic_data, tmpdir) + + # Read data from CSV files + data, metadata = handler.read(tmpdir) + + # Check if data was read correctly + assert len(data) == 2 + assert 'table1' in data + assert 'table2' in data + assert isinstance(metadata, MultiTableMetadata) is True + + # Check if the dataframes match the original synthetic data + pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) + pd.testing.assert_frame_equal(data['table2'], synthetic_data['table2']) diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/io/local/__init__.py b/tests/unit/io/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/io/local/test_local.py b/tests/unit/io/local/test_local.py new file mode 100644 index 000000000..e69d18636 --- /dev/null +++ b/tests/unit/io/local/test_local.py @@ -0,0 +1,212 @@ +"""Unit tests for local file handlers.""" +import os +from pathlib import Path +from unittest.mock import patch + +import pandas as pd +import pytest + +from sdv.io.local.local import CSVHandler +from sdv.metadata.multi_table import MultiTableMetadata + + +class TestCSVHandler: + + def test___init__(self): + """Test the dafault initialization of the class.""" + # Run + instance = CSVHandler() + + # Assert + assert instance.decimal == '.' 
+ assert instance.float_format is None + assert instance.encoding == 'UTF' + assert instance.sep == ',' + assert instance.quotechar == '"' + assert instance.quoting == 0 + + def test___init___custom(self): + """Test custom initialization of the class.""" + # Run + instance = CSVHandler( + sep=';', + encoding='utf-8', + decimal=',', + float_format='%.2f', + quotechar="'", + quoting=2 + ) + + # Assert + assert instance.decimal == ',' + assert instance.float_format == '%.2f' + assert instance.encoding == 'utf-8' + assert instance.sep == ';' + assert instance.quotechar == "'" + assert instance.quoting == 2 + + def test___init___error_encoding(self): + """Test custom initialization of the class.""" + # Run and Assert + error_msg = "The provided encoding 'sdvutf-8' is not available in your system." + with pytest.raises(ValueError, match=error_msg): + CSVHandler(sep=';', encoding='sdvutf-8', decimal=',', float_format='%.2f') + + @patch('sdv.io.local.local.Path.glob') + @patch('pandas.read_csv') + def test_read(self, mock_read_csv, mock_glob): + """Test the read method of CSVHandler class with a folder.""" + # Setup + mock_glob.return_value = [ + Path('/path/to/data/parent.csv'), + Path('/path/to/data/child.csv') + ] + mock_read_csv.side_effect = [ + pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + ] + + handler = CSVHandler() + + # Run + data, metadata = handler.read('/path/to/data') + + # Assert + assert len(data) == 2 + assert 'parent' in data + assert 'child' in data + assert isinstance(metadata, MultiTableMetadata) + assert mock_read_csv.call_count == 2 + pd.testing.assert_frame_equal( + data['parent'], + pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}) + ) + pd.testing.assert_frame_equal( + data['child'], + pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + ) + + def test_read_files(self, tmpdir): + """Test the read method of CSVHandler class with given ``file_names``.""" + # Setup + file_path = Path(tmpdir) + pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}).to_csv( + file_path / 'parent.csv', + index=False + ) + pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}).to_csv( + file_path / 'child.csv', + index=False + ) + + handler = CSVHandler() + + # Run + data, metadata = handler.read(tmpdir, file_names=['parent.csv']) + + # Assert + assert 'parent' in data + pd.testing.assert_frame_equal( + data['parent'], + pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}) + ) + + def test_read_files_missing(self, tmpdir): + """Test the read method of CSVHandler with missing ``file_names``.""" + # Setup + file_path = Path(tmpdir) + pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}).to_csv( + file_path / 'parent.csv', + index=False + ) + pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}).to_csv( + file_path / 'child.csv', + index=False + ) + + handler = CSVHandler() + + # Run and Assert + error_msg = 'The following files do not exist in the folder: grandchild.csv, parents.csv.' 
+ with pytest.raises(FileNotFoundError, match=error_msg): + handler.read(tmpdir, file_names=['grandchild.csv', 'parents.csv']) + + def test_write(self, tmpdir): + """Test the write functionality of a CSVHandler.""" + # Setup + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + handler = CSVHandler() + + assert os.path.exists(tmpdir / 'synthetic_data') is False + + # Run + handler.write(synthetic_data, tmpdir / 'synthetic_data', file_name_suffix='_synthetic') + + # Assert + assert 'table1_synthetic.csv' in os.listdir(tmpdir / 'synthetic_data') + assert 'table2_synthetic.csv' in os.listdir(tmpdir / 'synthetic_data') + + def test_write_file_exists(self, tmpdir): + """Test that an error is raised when it exists and the mode is `x`.""" + # Setup + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + os.makedirs(tmpdir / 'synthetic_data') + synthetic_data['table1'].to_csv(tmpdir / 'synthetic_data' / 'table1.csv', index=False) + handler = CSVHandler() + + # Run + with pytest.raises(FileExistsError): + handler.write(synthetic_data, tmpdir / 'synthetic_data') + + def test_write_file_exists_mode_is_a(self, tmpdir): + """Test the write functionality of a CSVHandler when the mode is ``a``.""" + # Setup + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + os.makedirs(tmpdir / 'synthetic_data') + synthetic_data['table1'].to_csv(tmpdir / 'synthetic_data' / 'table1.csv', index=False) + handler = CSVHandler() + + # Run + handler.write(synthetic_data, tmpdir / 'synthetic_data', mode='a') + + # Assert + dataframe = pd.read_csv(tmpdir / 'synthetic_data' / 'table1.csv') + expected_dataframe = pd.DataFrame({ + 'col1': ['1', '2', '3', 'col1', '1', '2', '3'], + 'col2': ['a', 'b', 'c', 'col2', 'a', 'b', 'c'] + }) + pd.testing.assert_frame_equal(dataframe, expected_dataframe) + + def test_write_file_exists_mode_is_w(self, tmpdir): + """Test the write functionality of a CSVHandler when the mode is ``w``.""" + # Setup + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + os.makedirs(tmpdir / 'synthetic_data') + synthetic_data['table1'].to_csv(tmpdir / 'synthetic_data' / 'table1.csv', index=False) + handler = CSVHandler() + + # Run + handler.write(synthetic_data, tmpdir / 'synthetic_data', mode='w') + + # Assert + dataframe = pd.read_csv(tmpdir / 'synthetic_data' / 'table1.csv') + expected_dataframe = pd.DataFrame({ + 'col1': [1, 2, 3], + 'col2': ['a', 'b', 'c'] + }) + pd.testing.assert_frame_equal(dataframe, expected_dataframe) From 060bae99b8b0ceb51ddad37fada920911a90d385 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Tue, 30 Apr 2024 08:26:21 +0100 Subject: [PATCH 11/32] Add `get_random_subset` poc utility function (#1928) --- pyproject.toml | 3 +- sdv/multi_table/base.py | 7 - sdv/multi_table/hma.py | 2 +- sdv/multi_table/utils.py | 285 +++++- sdv/utils/poc.py | 96 +- tests/integration/utils/test_poc.py | 93 +- tests/unit/multi_table/test_utils.py | 1206 +++++++++++++++++++++++++- tests/unit/utils/test_poc.py | 174 +++- 8 files changed, 1792 insertions(+), 74 deletions(-) diff --git a/pyproject.toml 
b/pyproject.toml index 9326308d5..af250ec3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,7 @@ dependencies = [ "numpy>=1.21.0;python_version<'3.10'", "numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'", "numpy>=1.26.0,<2;python_version>='3.12'", - "pandas>=1.4.0;python_version<'3.10'", - "pandas>=1.4.0;python_version>='3.10' and python_version<'3.11'", + "pandas>=1.4.0;python_version<'3.11'", "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'", "pandas>=2.1.1;python_version>='3.12'", 'tqdm>=4.29', diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 00efe700e..114f40739 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -127,13 +127,6 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._synthesizer_id ) - def _get_root_parents(self): - """Get the set of root parents in the graph.""" - non_root_tables = set(self.metadata._get_parent_map().keys()) - root_parents = set(self.metadata.tables.keys()) - non_root_tables - - return root_parents - def set_address_columns(self, table_name, column_names, anonymization_level='full'): """Set the address multi-column transformer. diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 9f4d5da30..a769c3e13 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -205,7 +205,7 @@ def get_learned_distributions(self, table_name): Dictionary containing the distributions used or detected for each column and the learned parameters for those. """ - if table_name not in self._get_root_parents(): + if table_name not in _get_root_tables(self.metadata.relationships): raise SynthesizerInputError( f"Learned distributions are not available for the '{table_name}' table. " 'Please choose a table that does not have any parents.' 
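The hma.py change above makes get_learned_distributions derive the root tables directly from metadata.relationships through _get_root_tables (the sdv._utils helper whose import is extended in the utils.py diff below), replacing the removed _get_root_parents method. A minimal sketch of that derivation, assuming each relationship dict exposes 'parent_table_name' and 'child_table_name' keys as in the diffs below; find_root_tables and the sample table names are illustrative stand-ins, not the actual sdv._utils implementation:

    # Sketch: a root table appears as a parent in some relationship
    # but never as a child in any relationship.
    def find_root_tables(relationships):
        parents = {rel['parent_table_name'] for rel in relationships}
        children = {rel['child_table_name'] for rel in relationships}
        return parents - children

    relationships = [
        {'parent_table_name': 'users', 'child_table_name': 'sessions'},
        {'parent_table_name': 'sessions', 'child_table_name': 'transactions'},
    ]
    find_root_tables(relationships)  # {'users'}
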
diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 2c372e6c6..cf3bb30e1 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -1,18 +1,26 @@ """Utility functions for the MultiTable models.""" import math +import warnings from collections import defaultdict from copy import deepcopy import numpy as np import pandas as pd -from sdv._utils import _get_root_tables +from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null +from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError from sdv.multi_table import HMASynthesizer from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS MODELABLE_SDTYPE = ['categorical', 'numerical', 'datetime', 'boolean'] +def _get_child_tables(relationships): + parent_tables = {rel['parent_table_name'] for rel in relationships} + child_tables = {rel['child_table_name'] for rel in relationships} + return child_tables - parent_tables + + def _get_relationships_for_child(relationships, child_table): return [rel for rel in relationships if rel['child_table_name'] == child_table] @@ -79,6 +87,34 @@ def _get_all_descendant_per_root_at_order_n(relationships, order): return all_descendants +def _get_ancestors(relationships, child_table): + """Get the ancestors of the child table.""" + ancestors = set() + parent_relationships = _get_relationships_for_child(relationships, child_table) + for relationship in parent_relationships: + parent_table = relationship['parent_table_name'] + ancestors.add(parent_table) + ancestors.update(_get_ancestors(relationships, parent_table)) + + return ancestors + + +def _get_disconnected_roots_from_table(relationships, table): + """Get the disconnected roots table from the given table.""" + root_tables = _get_root_tables(relationships) + child_tables = _get_child_tables(relationships) + if table in child_tables: + return root_tables - _get_ancestors(relationships, table) + + connected_roots = set() + for child in child_tables: + child_ancestor = _get_ancestors(relationships, child) + if table in child_ancestor: + connected_roots.update(root_tables.intersection(child_ancestor)) + + return root_tables - connected_roots + + def _simplify_relationships_and_tables(metadata, tables_to_drop): """Simplify the relationships and tables of the metadata. @@ -339,7 +375,7 @@ def _print_simplified_schema_summary(data_before, data_after): print('\n'.join(message)) # noqa: T001 -def _get_rows_to_drop(metadata, data): +def _get_rows_to_drop(data, metadata): """Get the rows to drop to ensure referential integrity. 
The logic of this function is to start at the root tables, look at invalid references @@ -392,3 +428,248 @@ def _get_rows_to_drop(metadata, data): relationships = [rel for rel in relationships if rel not in relationships_parent] return table_to_idx_to_drop + + +def _get_nan_fk_indices_table(data, relationships, table): + """Get the indexes of the rows to drop that have NaN foreign keys.""" + idx_with_nan_foreign_key = set() + relationships_for_table = _get_relationships_for_child(relationships, table) + for relationship in relationships_for_table: + child_column = relationship['child_foreign_key'] + idx_with_nan_foreign_key.update( + data[table][data[table][child_column].isna()].index + ) + + return idx_with_nan_foreign_key + + +def _drop_rows(data, metadata, drop_missing_values): + table_to_idx_to_drop = _get_rows_to_drop(data, metadata) + for table in sorted(metadata.tables): + idx_to_drop = table_to_idx_to_drop[table] + data[table] = data[table].drop(idx_to_drop) + if drop_missing_values: + idx_with_nan_fk = _get_nan_fk_indices_table( + data, metadata.relationships, table + ) + data[table] = data[table].drop(idx_with_nan_fk) + + if data[table].empty: + raise InvalidDataError([ + f"All references in table '{table}' are unknown and must be dropped." + 'Try providing different data for this table.' + ]) + + +def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep): + """Subsample the disconnected roots tables and their descendants.""" + relationships = metadata.relationships + roots = _get_disconnected_roots_from_table(relationships, table) + for root in roots: + data[root] = data[root].sample(frac=ratio_to_keep) + + _drop_rows(data, metadata, drop_missing_values=True) + + +def _subsample_table_and_descendants(data, metadata, table, num_rows): + """Subsample the table and its descendants. + + The logic is to first subsample all the NaN foreign keys of the table. + We raise an error if we cannot reach referential integrity while keeping the number of rows. + Then, we drop rows of the descendants to ensure referential integrity. + + Args: + data (dict): + Dictionary that maps each table name (string) to the data for that + table (pandas.DataFrame). + metadata (MultiTableMetadata): + Metadata of the datasets. + table (str): + Name of the table. + """ + idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table) + num_rows_to_drop = len(data[table]) - num_rows + if len(idx_nan_fk) > num_rows_to_drop: + raise SamplingError( + f"Referential integrity cannot be reached for table '{table}' while keeping " + f'{num_rows} rows. Please try again with a bigger number of rows.' + ) + else: + data[table] = data[table].drop(idx_nan_fk) + + data[table] = data[table].sample(num_rows) + _drop_rows(data, metadata, drop_missing_values=True) + + +def _get_primary_keys_referenced(data, metadata): + """Get the primary keys referenced by the relationships. + + Args: + data (dict): + Dictionary that maps each table name (string) to the data for that + table (pandas.DataFrame). + metadata (MultiTableMetadata): + Metadata of the datasets. + + Returns: + dict: + Dictionary that maps the table name to a set of their primary keys referenced. 
+ """ + relationships = metadata.relationships + primary_keys_referenced = defaultdict(set) + for relationship in relationships: + parent_table = relationship['parent_table_name'] + child_table = relationship['child_table_name'] + foreign_key = relationship['child_foreign_key'] + primary_keys_referenced[parent_table].update(set(data[child_table][foreign_key].unique())) + + return primary_keys_referenced + + +def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_before, + dereferenced_pk_parent): + """Subsample the parent table. + + The strategy here is to: + - Drop the rows that are no longer referenced by the descendants. + - Drop a proportional amount of never-referenced rows. + + Args: + parent_table (pandas.DataFrame): + Parent table to subsample. + parent_primary_key (str): + Name of the primary key of the parent table. + parent_pk_referenced_before (set): + Set of the primary keys referenced before any subsampling. + dereferenced_pk_parent (set): + Set of the primary keys that are no longer referenced by the descendants. + + Returns: + pandas.DataFrame: + Subsampled parent table. + """ + total_referenced = len(parent_pk_referenced_before) + total_dropped = len(dereferenced_pk_parent) + drop_proportion = total_dropped / total_referenced + + parent_table = parent_table[~parent_table[parent_primary_key].isin(dereferenced_pk_parent)] + unreferenced_data = parent_table[ + ~parent_table[parent_primary_key].isin(parent_pk_referenced_before) + ] + + # Randomly drop a proportional amount of never-referenced rows + unreferenced_data_to_drop = unreferenced_data.sample(frac=drop_proportion) + parent_table = parent_table.drop(unreferenced_data_to_drop.index) + if parent_table.empty: + raise InvalidDataError([ + f"All references in table '{parent_primary_key}' are unknown and must be dropped." + 'Try providing different data for this table.' + ]) + + return parent_table + + +def _subsample_ancestors(data, metadata, table, primary_keys_referenced): + """Subsample the ancestors of the table. + + The strategy here is to recursively subsample the direct parents of the table until the + root tables are reached. + + Args: + data (dict): + Dictionary that maps each table name (string) to the data for that + table (pandas.DataFrame). + metadata (MultiTableMetadata): + Metadata of the datasets. + table (str): + Name of the table. + primary_keys_referenced (dict): + Dictionary that maps the table name to a set of their primary keys referenced + before any subsampling. + """ + relationships = metadata.relationships + pk_referenced = _get_primary_keys_referenced(data, metadata) + direct_relationships = _get_relationships_for_child(relationships, table) + direct_parents = {rel['parent_table_name'] for rel in direct_relationships} + for parent in sorted(direct_parents): + parent_primary_key = metadata.tables[parent].primary_key + pk_referenced_before = primary_keys_referenced[parent] + dereferenced_primary_keys = pk_referenced_before - pk_referenced[parent] + data[parent] = _subsample_parent( + data[parent], parent_primary_key, pk_referenced_before, + dereferenced_primary_keys + ) + if dereferenced_primary_keys: + primary_keys_referenced[parent] = pk_referenced[parent] + + _subsample_ancestors(data, metadata, parent, primary_keys_referenced) + + +def _subsample_data(data, metadata, main_table_name, num_rows): + """Subsample multi-table table based on a table and a number of rows. 
+ + The strategy is to: + - Subsample the disconnected roots tables by keeping a similar proportion of data + than the main table. Ensure referential integrity. + - Subsample the main table and its descendants to ensure referential integrity. + - Subsample the ancestors of the main table by removing primary key rows that are no longer + referenced by the descendants and some unreferenced rows. + + Args: + metadata (MultiTableMetadata): + Metadata of the datasets. + data (dict): + Dictionary that maps each table name (string) to the data for that + table (pandas.DataFrame). + main_table_name (str): + Name of the main table. + num_rows (int): + Number of rows to keep in the main table. + + Returns: + dict: + Dictionary with the subsampled dataframes. + """ + result = deepcopy(data) + primary_keys_referenced = _get_primary_keys_referenced(result, metadata) + ratio_to_keep = num_rows / len(result[main_table_name]) + try: + _validate_foreign_keys_not_null(metadata, result) + except SynthesizerInputError: + warnings.warn( + 'The data contains null values in foreign key columns. ' + 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc' + ' to drop these rows before using ``get_random_subset``.' + ) + + try: + _subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep) + _subsample_table_and_descendants(result, metadata, main_table_name, num_rows) + _subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced) + _drop_rows(result, metadata, drop_missing_values=True) # Drop remaining NaN foreign keys + except InvalidDataError as error: + if 'All references in table' not in str(error.args[0]): + raise error + else: + raise SamplingError( + f'Subsampling {main_table_name} with {num_rows} rows leads to some empty tables. ' + 'Please try again with a bigger number of rows.' + ) + + return result + + +def _print_subsample_summary(data_before, data_after): + """Print the summary of the subsampled data.""" + tables = sorted(data_before.keys()) + summary = pd.DataFrame({ + 'Table Name': tables, + '# Rows (Before)': [len(data_before[table]) for table in tables], + '# Rows (After)': [ + len(data_after[table]) if table in data_after else 0 for table in tables + ] + }) + subsample_rows = 100 * (1 - summary['# Rows (After)'].sum() / summary['# Rows (Before)'].sum()) + message = [f'Success! 
Your subset has {round(subsample_rows)}% less rows than the original.\n'] + message.append(summary.to_string(index=False)) + print('\n'.join(message)) # noqa: T001 diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index deb5dfba9..139303a98 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -1,5 +1,6 @@ """Utility functions.""" import sys +from copy import deepcopy import pandas as pd @@ -8,8 +9,8 @@ from sdv.metadata.errors import InvalidMetadataError from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS from sdv.multi_table.utils import ( - _get_relationships_for_child, _get_rows_to_drop, _get_total_estimated_columns, - _print_simplified_schema_summary, _simplify_data, _simplify_metadata) + _drop_rows, _get_total_estimated_columns, _print_simplified_schema_summary, + _print_subsample_summary, _simplify_data, _simplify_metadata, _subsample_data) def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True): @@ -54,23 +55,8 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr return data except (InvalidDataError, SynthesizerInputError): - result = data.copy() - table_to_idx_to_drop = _get_rows_to_drop(metadata, result) - for table in table_names: - idx_to_drop = table_to_idx_to_drop[table] - result[table] = result[table].drop(idx_to_drop) - if drop_missing_values: - relationships = _get_relationships_for_child(metadata.relationships, table) - for relationship in relationships: - child_column = relationship['child_foreign_key'] - result[table] = result[table].dropna(subset=[child_column]) - - if result[table].empty: - raise InvalidDataError([ - f"All references in table '{table}' are unknown and must be dropped." - 'Try providing different data for this table.' - ]) - + result = deepcopy(data) + _drop_rows(result, metadata, drop_missing_values) if verbose: summary_table['# Invalid Rows'] = [ len(data[table]) - len(result[table]) for table in table_names @@ -83,7 +69,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr return result -def simplify_schema(data, metadata): +def simplify_schema(data, metadata, verbose=True): """Simplify the schema of the data and metadata. This function simplifies the schema of the data and metadata by: @@ -99,6 +85,9 @@ def simplify_schema(data, metadata): table (pandas.DataFrame). metadata (MultiTableMetadata): Metadata of the datasets. + verbose (bool): + If True, print information about the simplification process. + Defaults to True. Returns: tuple: @@ -127,6 +116,71 @@ def simplify_schema(data, metadata): simple_metadata = _simplify_metadata(metadata) simple_data = _simplify_data(data, simple_metadata) - _print_simplified_schema_summary(data, simple_data) + if verbose: + _print_simplified_schema_summary(data, simple_data) return simple_data, simple_metadata + + +def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True): + """Subsample multi-table table based on a table and a number of rows. + + The strategy is to: + - Subsample the disconnected roots tables by keeping a similar proportion of data + than the main table. Ensure referential integrity. + - Subsample the main table and its descendants to ensure referential integrity. + - Subsample the ancestors of the main table by removing primary key rows that are no longer + referenced by the descendants and drop also some unreferenced rows. + + Args: + data (dict): + Dictionary that maps each table name (string) to the data for that + table (pandas.DataFrame). 
+ metadata (MultiTableMetadata): + Metadata of the datasets. + main_table_name (str): + Name of the main table. + num_rows (int): + Number of rows to keep in the main table. + verbose (bool): + If True, print information about the subsampling process. + Defaults to True. + + Returns: + dict: + Dictionary with the subsampled dataframes. + """ + try: + error_message = ( + 'The provided data/metadata combination is not valid.' + ' Please make sure that the data/metadata combination is valid' + ' before trying to simplify the schema.' + ) + metadata.validate() + metadata.validate_data(data) + except InvalidMetadataError as error: + raise InvalidMetadataError(error_message) from error + except InvalidDataError as error: + raise InvalidDataError([error_message]) from error + + error_message_num_rows = ( + '``num_rows`` must be a positive integer.' + ) + if not isinstance(num_rows, (int, float)) or num_rows != int(num_rows): + raise ValueError(error_message_num_rows) + + if num_rows <= 0: + raise ValueError(error_message_num_rows) + + if len(data[main_table_name]) <= num_rows: + if verbose: + _print_subsample_summary(data, data) + + return data + + result = _subsample_data(data, metadata, main_table_name, num_rows) + if verbose: + _print_subsample_summary(data, result) + + metadata.validate_data(result) + return result diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 981b8798e..7aec94243 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -1,5 +1,6 @@ import re from copy import deepcopy +from unittest.mock import Mock import numpy as np import pandas as pd @@ -10,7 +11,7 @@ from sdv.metadata import MultiTableMetadata from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS, HMASynthesizer from sdv.multi_table.utils import _get_total_estimated_columns -from sdv.utils.poc import drop_unknown_references, simplify_schema +from sdv.utils.poc import drop_unknown_references, get_random_subset, simplify_schema @pytest.fixture @@ -257,3 +258,93 @@ def test_simplify_schema_big_demo_datasets(): estimate_column_after = _get_total_estimated_columns(metadata_simplify) assert estimate_column_before > MAX_NUMBER_OF_COLUMNS assert estimate_column_after <= MAX_NUMBER_OF_COLUMNS + + +@pytest.mark.parametrize( + ('dataset_name', 'main_table_1', 'main_table_2', 'num_rows_1', 'num_rows_2'), + [ + ('AustralianFootball_v1', 'matches', 'players', 1000, 1000), + ('MuskSmall_v1', 'molecule', 'conformation', 50, 150), + ('NBA_v1', 'Team', 'Actions', 10, 200), + ('NCAA_v1', 'tourney_slots', 'tourney_compact_results', 1000, 1000), + ] +) +def test_get_random_subset(dataset_name, main_table_1, main_table_2, num_rows_1, num_rows_2): + """Test ``get_random_subset`` end to end. + + The goal here is test that the function works for various schema and also by subsampling + different main tables. + + For `AustralianFootball_v1` (parent with child and grandparent): + - main table 1 = `matches` which is the child of `teams` and the parent of `match_stats`. + - main table 2 = `players` which is the parent of `matches`. + + For `MuskSmall_v1` (1 parent - 1 child relationship): + - main table 1 = `molecule` which is the parent of `conformation`. + - main table 2 = `conformation` which is the child of `molecule`. + + For `NBA_v1` (child with parents and grandparent): + - main table 1 = `Team` which is the root table. + - main table 2 = `Actions` which is the last child. It has relationships with `Game` and `Team` + and `Player`. 
+ + For `NCAA_v1` (child with multiple parents): + - main table 1 = `tourney_slots` which is only the child of `seasons`. + - main table 2 = `tourney_compact_results` which is the child of `teams` with two relationships + and of `seasons` with one relationship. + """ + # Setup + real_data, metadata = download_demo('multi_table', dataset_name) + + # Run + result_1 = get_random_subset(real_data, metadata, main_table_1, num_rows_1, verbose=False) + result_2 = get_random_subset(real_data, metadata, main_table_2, num_rows_2, verbose=False) + + # Assert + assert len(result_1[main_table_1]) == num_rows_1 + assert len(result_2[main_table_2]) == num_rows_2 + + +def test_get_random_subset_disconnected_schema(): + """Test ``get_random_subset`` end to end for a disconnected schema. + + Here we break the schema so there is only parent-child relationships between + `Player`-`Action` and `Team`-`Game`. + The part that is not connected to the main table (`Player`) should be subsampled also + in a similar proportion. + """ + # Setup + real_data, metadata = download_demo('multi_table', 'NBA_v1') + metadata.remove_relationship('Game', 'Actions') + metadata.remove_relationship('Team', 'Actions') + metadata.validate = Mock() + metadata.validate_data = Mock() + proportion_to_keep = 0.6 + num_rows_to_keep = int(len(real_data['Player']) * proportion_to_keep) + + # Run + result = get_random_subset(real_data, metadata, 'Player', num_rows_to_keep, verbose=False) + + # Assert + assert len(result['Player']) == num_rows_to_keep + assert len(result['Team']) == int(len(real_data['Team']) * proportion_to_keep) + + +def test_get_random_subset_with_missing_values(metadata, data): + """Test ``get_random_subset`` when there is missing values in the foreign keys.""" + # Setup + data = deepcopy(data) + data['child'].loc[4, 'parent_id'] = np.nan + expected_warning = re.escape( + 'The data contains null values in foreign key columns. ' + 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc' + ' to drop these rows before using ``get_random_subset``.' 
+ ) + + # Run + with pytest.warns(UserWarning, match=expected_warning): + cleaned_data = get_random_subset(data, metadata, 'child', 3) + + # Assert + assert len(cleaned_data['child']) == 3 + assert not pd.isna(cleaned_data['child']['parent_id']).any() diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index cbf090822..3917a993f 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -3,15 +3,21 @@ from copy import deepcopy from unittest.mock import Mock, call, patch +import numpy as np import pandas as pd +import pytest +from sdv.errors import InvalidDataError, SamplingError from sdv.metadata import MultiTableMetadata from sdv.multi_table.utils import ( - _get_all_descendant_per_root_at_order_n, _get_columns_to_drop_child, _get_n_order_descendants, - _get_num_column_to_drop, _get_relationships_for_child, _get_relationships_for_parent, - _get_rows_to_drop, _get_total_estimated_columns, _print_simplified_schema_summary, + _drop_rows, _get_all_descendant_per_root_at_order_n, _get_ancestors, + _get_columns_to_drop_child, _get_disconnected_roots_from_table, _get_n_order_descendants, + _get_nan_fk_indices_table, _get_num_column_to_drop, _get_primary_keys_referenced, + _get_relationships_for_child, _get_relationships_for_parent, _get_rows_to_drop, + _get_total_estimated_columns, _print_simplified_schema_summary, _print_subsample_summary, _simplify_child, _simplify_children, _simplify_data, _simplify_grandchildren, - _simplify_metadata, _simplify_relationships_and_tables) + _simplify_metadata, _simplify_relationships_and_tables, _subsample_ancestors, _subsample_data, + _subsample_disconnected_roots, _subsample_parent, _subsample_table_and_descendants) def test__get_relationships_for_child(): @@ -116,7 +122,7 @@ def test__get_rows_to_drop(): } # Run - result = _get_rows_to_drop(metadata, data) + result = _get_rows_to_drop(data, metadata) # Assert expected_result = defaultdict(set, { @@ -127,6 +133,335 @@ def test__get_rows_to_drop(): assert result == expected_result +def test__get_nan_fk_indices_table(): + """Test the ``_get_nan_fk_indices_table`` method.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + data = { + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [np.nan, 1, 2, 2, np.nan], + 'child_foreign_key': [9, np.nan, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No'] + }) + } + + # Run + result = _get_nan_fk_indices_table(data, relationships, 'grandchild') + + # Assert + assert result == {0, 1, 4} + + +@patch('sdv.multi_table.utils._get_rows_to_drop') +def test__drop_rows(mock_get_rows_to_drop): + """Test the ``_drop_rows`` method.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 
'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = { + 'parent': Mock(primary_key='id_parent'), + 'child': Mock(primary_key='id_child'), + 'grandchild': Mock(primary_key='id_grandchild') + } + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5], + 'id_child': [5, 6, 7, 8, 9], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6], + 'child_foreign_key': [9, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No'] + }) + } + + mock_get_rows_to_drop.return_value = defaultdict(set, { + 'child': {4}, + 'grandchild': {0, 2, 4} + }) + + # Run + _drop_rows(data, metadata, False) + + # Assert + mock_get_rows_to_drop.assert_called_once_with(data, metadata) + expected_result = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2], + 'id_child': [5, 6, 7, 8], + 'B': ['Yes', 'No', 'No', 'No'] + }, index=[0, 1, 2, 3]), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [1, 2], + 'child_foreign_key': [5, 6], + 'C': ['No', 'No'] + }, index=[1, 3]) + } + for table_name, table in data.items(): + pd.testing.assert_frame_equal(table, expected_result[table_name]) + + +@patch('sdv.multi_table.utils._get_rows_to_drop') +def test_drop_unknown_references_with_nan(mock_get_rows_to_drop): + """Test ``drop_unknown_references`` whith NaNs and drop_missing_values True.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5, None], + 'id_child': [5, 6, 7, 8, 9, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6, 4], + 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }) + } + mock_get_rows_to_drop.return_value = defaultdict(set, { + 'child': {4}, + 'grandchild': {0, 3, 4} + }) + + # Run + _drop_rows(data, metadata, True) + + # Assert + mock_get_rows_to_drop.assert_called_once() + expected_result = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0., 1., 2., 2.], + 'id_child': [5, 6, 7, 8], + 'B': ['Yes', 'No', 'No', 'No'] + }, index=[0, 1, 2, 3]), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [2, 4], + 'child_foreign_key': [5., 4.], + 'C': ['No', 'No'] + }, index=[2, 5]) + } + for table_name, table in data.items(): + pd.testing.assert_frame_equal(table, expected_result[table_name]) + + +@patch('sdv.multi_table.utils._get_rows_to_drop') +def 
test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop): + """Test ``drop_unknown_references`` with NaNs and drop_missing_values False.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + metadata.validate_data.side_effect = InvalidDataError('Invalid data') + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5, None], + 'id_child': [5, 6, 7, 8, 9, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6, 4], + 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }) + } + mock_get_rows_to_drop.return_value = defaultdict(set, { + 'child': {4}, + 'grandchild': {0, 3, 4} + }) + + # Run + _drop_rows(data, metadata, drop_missing_values=False) + + # Assert + mock_get_rows_to_drop.assert_called_once() + expected_result = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0., 1., 2., 2., None], + 'id_child': [5, 6, 7, 8, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }, index=[0, 1, 2, 3, 5]), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [1, 2, 4], + 'child_foreign_key': [np.nan, 5, 4.], + 'C': ['No', 'No', 'No'] + }, index=[1, 2, 5]) + } + for table_name, table in data.items(): + pd.testing.assert_frame_equal(table, expected_result[table_name]) + + +@patch('sdv.multi_table.utils._get_rows_to_drop') +def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): + """Test ``drop_unknown_references`` when all rows are dropped.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + metadata.validate_data.side_effect = InvalidDataError('Invalid data') + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5], + 'id_child': [5, 6, 7, 8, 9], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6], + 'child_foreign_key': [9, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No'] + }) + } + + mock_get_rows_to_drop.return_value = defaultdict(set, { + 'child': {0, 1, 2, 3, 4} + }) + + # Run and Assert + expected_message = 
re.escape( + 'The provided data does not match the metadata:\n' + "All references in table 'child' are unknown and must be dropped." + 'Try providing different data for this table.' + ) + with pytest.raises(InvalidDataError, match=expected_message): + _drop_rows(data, metadata, False) + + def test__get_n_order_descendants(): """Test the ``_get_n_order_descendants`` method.""" # Setup @@ -194,19 +529,74 @@ def test__get_all_descendant_per_root_at_order_n(): assert result == expected_result -def test__simplify_relationships_and_tables(): - """Test the ``_simplify_relationships`` method.""" +@pytest.mark.parametrize(('table_name', 'expected_result'), [ + ('grandchild', {'child', 'parent', 'grandparent', 'other_root'}), + ('child', {'parent', 'grandparent', 'other_root'}), + ('parent', {'grandparent'}), + ('other_table', {'grandparent'}), + ('grandparent', set()), + ('other_root', set()), +]) +def test__get_ancestors(table_name, expected_result): + """Test the ``_get_ancestors`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ - 'tables': { - 'grandparent': {'columns': {'col_1': {'sdtype': 'numerical'}}}, - 'parent': {'columns': {'col_2': {'sdtype': 'numerical'}}}, - 'child': {'columns': {'col_3': {'sdtype': 'numerical'}}}, - 'grandchild': {'columns': {'col_4': {'sdtype': 'numerical'}}}, - 'other_table': {'columns': {'col_5': {'sdtype': 'numerical'}}}, - 'other_root': {'columns': {'col_6': {'sdtype': 'numerical'}}}, - }, - 'relationships': [ + relationships = [ + {'parent_table_name': 'grandparent', 'child_table_name': 'parent'}, + {'parent_table_name': 'parent', 'child_table_name': 'child'}, + {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, + {'parent_table_name': 'grandparent', 'child_table_name': 'other_table'}, + {'parent_table_name': 'other_root', 'child_table_name': 'child'}, + ] + + # Run + result = _get_ancestors(relationships, table_name) + + # Assert + assert result == expected_result + + +@pytest.mark.parametrize(('table_name', 'expected_result'), [ + ('grandchild', {'disconnected_root'}), + ('child', {'disconnected_root'}), + ('parent', {'disconnected_root'}), + ('other_table', {'disconnected_root', 'other_root'}), + ('grandparent', {'disconnected_root'}), + ('other_root', {'disconnected_root'}), + ('disconnected_root', {'grandparent', 'other_root'}), + ('disconnect_child', {'grandparent', 'other_root'}), +]) +def test__get_disconnected_roots_from_table(table_name, expected_result): + """Test the ``_get_disconnected_roots_from_table`` method.""" + # Setup + relationships = [ + {'parent_table_name': 'grandparent', 'child_table_name': 'parent'}, + {'parent_table_name': 'parent', 'child_table_name': 'child'}, + {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, + {'parent_table_name': 'grandparent', 'child_table_name': 'other_table'}, + {'parent_table_name': 'other_root', 'child_table_name': 'child'}, + {'parent_table_name': 'disconnected_root', 'child_table_name': 'disconnect_child'}, + ] + + # Run + result = _get_disconnected_roots_from_table(relationships, table_name) + + # Assert + assert result == expected_result + + +def test__simplify_relationships_and_tables(): + """Test the ``_simplify_relationships`` method.""" + # Setup + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'grandparent': {'columns': {'col_1': {'sdtype': 'numerical'}}}, + 'parent': {'columns': {'col_2': {'sdtype': 'numerical'}}}, + 'child': {'columns': {'col_3': {'sdtype': 'numerical'}}}, + 'grandchild': {'columns': {'col_4': {'sdtype': 
'numerical'}}}, + 'other_table': {'columns': {'col_5': {'sdtype': 'numerical'}}}, + 'other_root': {'columns': {'col_6': {'sdtype': 'numerical'}}}, + }, + 'relationships': [ {'parent_table_name': 'grandparent', 'child_table_name': 'parent'}, {'parent_table_name': 'parent', 'child_table_name': 'child'}, {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, @@ -857,3 +1247,785 @@ def test__print_simplified_schema_summary(capsys): r'Table 3\s*1\s*0' ) assert expected_output.match(captured.out.strip()) + + +@patch('sdv.multi_table.utils._get_disconnected_roots_from_table') +@patch('sdv.multi_table.utils._drop_rows') +def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roots_from_table): + """Test the ``_subsample_disconnected_roots`` method.""" + # Setup + data = { + 'disconnected_root': pd.DataFrame({ + 'col_1': [1, 2, 3, 4, 5], + 'col_2': [6, 7, 8, 9, 10], + }), + 'grandparent': pd.DataFrame({ + 'col_3': [1, 2, 3, 4, 5], + 'col_4': [6, 7, 8, 9, 10], + }), + 'other_root': pd.DataFrame({ + 'col_5': [1, 2, 3, 4, 5], + 'col_6': [6, 7, 8, 9, 10], + }), + 'child': pd.DataFrame({ + 'col_7': [1, 2, 3, 4, 5], + 'col_8': [6, 7, 8, 9, 10], + }), + 'other_table': pd.DataFrame({ + 'col_9': [1, 2, 3, 4, 5], + 'col_10': [6, 7, 8, 9, 10], + }), + 'parent': pd.DataFrame({ + 'col_11': [1, 2, 3, 4, 5], + 'col_12': [6, 7, 8, 9, 10], + }), + } + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'disconnected_root': { + 'columns': { + 'col_1': {'sdtype': 'numerical'}, + 'col_2': {'sdtype': 'numerical'}, + }, + }, + 'grandparent': { + 'columns': { + 'col_3': {'sdtype': 'numerical'}, + 'col_4': {'sdtype': 'numerical'}, + }, + }, + 'other_root': { + 'columns': { + 'col_5': {'sdtype': 'numerical'}, + 'col_6': {'sdtype': 'numerical'}, + }, + }, + 'child': { + 'columns': { + 'col_7': {'sdtype': 'numerical'}, + 'col_8': {'sdtype': 'numerical'}, + }, + }, + 'other_table': { + 'columns': { + 'col_9': {'sdtype': 'numerical'}, + 'col_10': {'sdtype': 'numerical'}, + }, + }, + 'parent': { + 'columns': { + 'col_11': {'sdtype': 'numerical'}, + 'col_12': {'sdtype': 'numerical'}, + }, + }, + }, + 'relationships': [ + {'parent_table_name': 'grandparent', 'child_table_name': 'parent'}, + {'parent_table_name': 'parent', 'child_table_name': 'child'}, + {'parent_table_name': 'child', 'child_table_name': 'grandchild'}, + {'parent_table_name': 'grandparent', 'child_table_name': 'other_table'}, + {'parent_table_name': 'other_root', 'child_table_name': 'child'}, + {'parent_table_name': 'disconnected_root', 'child_table_name': 'disconnect_child'}, + ] + }) + mock_get_disconnected_roots_from_table.return_value = {'grandparent', 'other_root'} + ratio_to_keep = 0.6 + expected_result = deepcopy(data) + + # Run + _subsample_disconnected_roots(data, metadata, 'disconnected_root', ratio_to_keep) + + # Assert + mock_get_disconnected_roots_from_table.assert_called_once_with( + metadata.relationships, 'disconnected_root' + ) + mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) + for table_name in metadata.tables: + if table_name not in {'grandparent', 'other_root'}: + pd.testing.assert_frame_equal(data[table_name], expected_result[table_name]) + else: + assert len(data[table_name]) == 3 + + +@patch('sdv.multi_table.utils._drop_rows') +@patch('sdv.multi_table.utils._get_nan_fk_indices_table') +def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, + mock_drop_rows): + """Test the ``_subsample_table_and_descendants`` method.""" + # Setup + data = { + 
'grandparent': pd.DataFrame({ + 'col_1': [1, 2, 3, 4, 5], + 'col_2': [6, 7, 8, 9, 10], + }), + 'parent': pd.DataFrame({ + 'col_3': [1, 2, 3, 4, 5], + 'col_4': [6, 7, 8, 9, 10], + }), + 'child': pd.DataFrame({ + 'col_5': [1, 2, 3, 4, 5], + 'col_6': [6, 7, 8, 9, 10], + }), + 'grandchild': pd.DataFrame({ + 'col_7': [1, 2, 3, 4, 5], + 'col_8': [6, 7, 8, 9, 10], + }), + } + mock_get_nan_fk_indices_table.return_value = {0} + metadata = Mock() + metadata.relationships = Mock() + + # Run + _subsample_table_and_descendants(data, metadata, 'parent', 3) + + # Assert + mock_get_nan_fk_indices_table.assert_called_once_with( + data, metadata.relationships, 'parent' + ) + mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) + assert len(data['parent']) == 3 + + +@patch('sdv.multi_table.utils._get_nan_fk_indices_table') +def test__subsample_table_and_descendants_nan_fk(mock_get_nan_fk_indices_table): + """Test the ``_subsample_table_and_descendants`` when there are too many NaN foreign keys.""" + # Setup + data = {'parent': [1, 2, 3, 4, 5, 6]} + mock_get_nan_fk_indices_table.return_value = {0, 1, 2, 3, 4} + metadata = Mock() + metadata.relationships = Mock() + expected_message = re.escape( + "Referential integrity cannot be reached for table 'parent' while keeping " + '3 rows. Please try again with a bigger number of rows.' + ) + + # Run + with pytest.raises(SamplingError, match=expected_message): + _subsample_table_and_descendants(data, metadata, 'parent', 3) + + # Assert + mock_get_nan_fk_indices_table.assert_called_once_with( + data, metadata.relationships, 'parent' + ) + + +def test__get_primary_keys_referenced(): + """Test the ``_get_primary_keys_referenced`` method.""" + data = { + 'grandparent': pd.DataFrame({ + 'pk_gp': [1, 2, 3, 4, 5], + 'col_1': [6, 7, 8, 9, 10], + }), + 'parent': pd.DataFrame({ + 'fk_gp': [1, 2, 2, 3, 1], + 'pk_p': [11, 12, 13, 14, 15], + 'col_2': [16, 17, 18, 19, 20], + }), + 'child': pd.DataFrame({ + 'fk_gp': [5, 2, 2, 3, 1], + 'fk_p_1': [11, 11, 11, 11, 11], + 'fk_p_2': [12, 12, 12, 12, 12], + 'pk_c': [21, 22, 23, 24, 25], + 'col_3': [26, 27, 28, 29, 30], + }), + 'grandchild': pd.DataFrame({ + 'fk_p_3': [13, 14, 13, 13, 13], + 'fk_p_4': [14, 13, 14, 14, 14], + 'fk_c': [21, 22, 23, 24, 25], + 'col_4': [36, 37, 38, 39, 40], + }), + } + + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'grandparent': { + 'columns': { + 'pk_gp': {'type': 'id'}, + 'col_1': {'type': 'numerical'}, + }, + 'primary_key': 'pk_gp' + }, + 'parent': { + 'columns': { + 'fk_gp': {'type': 'id'}, + 'pk_p': {'type': 'id'}, + 'col_2': {'type': 'numerical'}, + }, + 'primary_key': 'pk_p' + }, + 'child': { + 'columns': { + 'fk_gp': {'type': 'id'}, + 'fk_p_1': {'type': 'id'}, + 'fk_p_2': {'type': 'id'}, + 'pk_c': {'type': 'id'}, + 'col_3': {'type': 'numerical'}, + }, + 'primary_key': 'pk_c' + }, + 'grandchild': { + 'columns': { + 'fk_p_3': {'type': 'id'}, + 'fk_p_4': {'type': 'id'}, + 'fk_c': {'type': 'id'}, + 'col_4': {'type': 'numerical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'grandparent', + 'child_table_name': 'parent', + 'parent_primary_key': 'pk_gp', + 'child_foreign_key': 'fk_gp', + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_1', + }, + { + 'parent_table_name': 'grandparent', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_gp', + + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 
'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_2', + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'pk_c', + 'child_foreign_key': 'fk_c', + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_3', + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_4', + } + ] + }) + + # Run + result = _get_primary_keys_referenced(data, metadata) + + # Assert + expected_result = { + 'grandparent': {1, 2, 3, 5}, + 'parent': {11, 12, 13, 14}, + 'child': {21, 22, 23, 24, 25}, + } + assert result == expected_result + + +def test__subsample_parent_all_reeferenced_before(): + """Test the ``_subsample_parent`` when all primary key were referenced before. + + Here the primary keys ``4`` and ``5`` are no longer referenced and should be dropped. + """ + # Setup + data = { + 'parent': pd.DataFrame({ + 'pk_p': [1, 2, 3, 4, 5], + 'col_2': [16, 17, 18, 19, 20], + }), + 'child': pd.DataFrame({ + 'fk_p_1': [1, 2, 2, 2, 3], + }), + } + + pk_referenced_before = defaultdict(set) + pk_referenced_before['parent'] = {1, 2, 3, 4, 5} + unreferenced_pk = {4, 5} + + # Run + data['parent'] = _subsample_parent( + data['parent'], 'pk_p', pk_referenced_before['parent'], unreferenced_pk + ) + + # Assert + expected_result = { + 'parent': pd.DataFrame({ + 'pk_p': [1, 2, 3], + 'col_2': [16, 17, 18], + }), + 'child': pd.DataFrame({ + 'fk_p_1': [1, 2, 2, 2, 3], + }), + } + pd.testing.assert_frame_equal(data['parent'], expected_result['parent']) + pd.testing.assert_frame_equal(data['child'], expected_result['child']) + + +def test__subsample_parent_not_all_referenced_before(): + """Test the ``_subsample_parent`` when not all primary key were referenced before. + + In this example: + - The primary key ``5`` is no longer referenced and should be dropped. + - One unreferenced primary key must be dropped to keep the same ratio of + referenced/unreferenced primary keys. 
+ """ + # Setup + data = { + 'parent': pd.DataFrame({ + 'pk_p': [1, 2, 3, 4, 5, 6, 7, 8], + 'col_2': [16, 17, 18, 19, 20, 21, 22, 23], + }), + 'child': pd.DataFrame({ + 'fk_p_1': [1, 2, 2, 2, 3], + }), + } + + pk_referenced_before = defaultdict(set) + pk_referenced_before['parent'] = {1, 2, 3, 5} + unreferenced_pk = {5} + + # Run + data['parent'] = _subsample_parent( + data['parent'], 'pk_p', pk_referenced_before['parent'], unreferenced_pk + ) + + # Assert + assert len(data['parent']) == 6 + assert set(data['parent']['pk_p']).issubset({ + 1, 2, 3, 4, 6, 7, 8 + }) + + +def test__subsample_ancestors(): + """Test the ``_subsample_ancestors`` method.""" + # Setup + data = { + 'grandparent': pd.DataFrame({ + 'pk_gp': [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 + ], + 'col_1': [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', + 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't' + ], + }), + 'parent': pd.DataFrame({ + 'fk_gp': [1, 2, 3, 4, 9, 6], + 'pk_p': [11, 12, 13, 14, 15, 16], + 'col_2': ['k', 'l', 'm', 'n', 'o', 'p'], + }), + 'child': pd.DataFrame({ + 'fk_gp': [4, 5, 6, 7, 8], + 'fk_p_1': [11, 11, 11, 14, 11], + 'fk_p_2': [12, 12, 12, 12, 15], + 'pk_c': [21, 22, 23, 24, 25], + 'col_3': ['q', 'r', 's', 't', 'u'], + }), + 'grandchild': pd.DataFrame({ + 'fk_p_3': [11, 12, 13, 11, 13], + 'fk_c': [21, 22, 23, 21, 22], + 'col_4': [36, 37, 38, 39, 40], + }), + } + + primary_key_referenced = { + 'grandparent': {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 'parent': {11, 12, 13, 14, 15}, + 'child': {21, 22, 23, 24, 25}, + } + + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'grandparent': { + 'columns': { + 'pk_gp': {'type': 'id'}, + 'col_1': {'type': 'numerical'}, + }, + 'primary_key': 'pk_gp' + }, + 'parent': { + 'columns': { + 'fk_gp': {'type': 'id'}, + 'pk_p': {'type': 'id'}, + 'col_2': {'type': 'numerical'}, + }, + 'primary_key': 'pk_p' + }, + 'child': { + 'columns': { + 'fk_gp': {'type': 'id'}, + 'fk_p_1': {'type': 'id'}, + 'fk_p_2': {'type': 'id'}, + 'pk_c': {'type': 'id'}, + 'col_3': {'type': 'numerical'}, + }, + 'primary_key': 'pk_c' + }, + 'grandchild': { + 'columns': { + 'fk_p_3': {'type': 'id'}, + 'fk_p_4': {'type': 'id'}, + 'fk_c': {'type': 'id'}, + 'col_4': {'type': 'numerical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'grandparent', + 'child_table_name': 'parent', + 'parent_primary_key': 'pk_gp', + 'child_foreign_key': 'fk_gp', + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_1', + }, + { + 'parent_table_name': 'grandparent', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_gp', + + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_2', + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'pk_c', + 'child_foreign_key': 'fk_c', + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_3', + } + ] + }) + + # Run + _subsample_ancestors(data, metadata, 'grandchild', primary_key_referenced) + + # Assert + expected_result = { + 'parent': pd.DataFrame({ + 'fk_gp': [1, 2, 3, 6], + 'pk_p': [11, 12, 13, 16], + 'col_2': ['k', 'l', 'm', 'p'], + }, index=[0, 1, 2, 5]), + 'child': pd.DataFrame({ + 'fk_gp': [4, 5, 6], + 'fk_p_1': [11, 11, 11], + 'fk_p_2': [12, 12, 12], + 'pk_c': [21, 22, 23], + 'col_3': ['q', 
'r', 's'], + }, index=[0, 1, 2]), + 'grandchild': pd.DataFrame({ + 'fk_p_3': [11, 12, 13, 11, 13], + 'fk_c': [21, 22, 23, 21, 22], + 'col_4': [36, 37, 38, 39, 40], + }, index=[0, 1, 2, 3, 4]), + } + assert len(data['grandparent']) == 14 + assert set(data['grandparent']['pk_gp']).issubset( + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20 + } + ) + for table_name in ['parent', 'child', 'grandchild']: + pd.testing.assert_frame_equal(data[table_name], expected_result[table_name]) + + +def test__subsample_ancestors_schema_diamond_shape(): + """Test the ``_subsample_ancestors`` method with a diamond shape schema.""" + # Setup + data = { + 'grandparent': pd.DataFrame({ + 'pk_gp': [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 + ], + 'col_1': [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', + 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't' + ], + }), + 'parent_1': pd.DataFrame({ + 'fk_gp': [1, 2, 3, 4, 5, 6], + 'pk_p': [21, 22, 23, 24, 25, 26], + 'col_2': ['k', 'l', 'm', 'n', 'o', 'p'], + }), + 'parent_2': pd.DataFrame({ + 'fk_gp': [7, 8, 9, 10, 11], + 'pk_p': [31, 32, 33, 34, 35], + 'col_3': ['k', 'l', 'm', 'n', 'o'], + }), + 'child': pd.DataFrame({ + 'fk_p_1': [21, 22, 23, 23, 23], + 'fk_p_2': [31, 32, 33, 34, 34], + 'col_4': ['q', 'r', 's', 't', 'u'], + }) + } + + primary_key_referenced = { + 'grandparent': {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + 'parent_1': {21, 22, 23, 24, 25}, + 'parent_2': {31, 32, 33, 34, 35}, + } + + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'grandparent': { + 'columns': { + 'pk_gp': {'type': 'id'}, + 'col_1': {'type': 'numerical'}, + }, + 'primary_key': 'pk_gp' + }, + 'parent_1': { + 'columns': { + 'fk_gp': {'type': 'id'}, + 'pk_p': {'type': 'id'}, + 'col_2': {'type': 'numerical'}, + }, + 'primary_key': 'pk_p' + }, + 'parent_2': { + 'columns': { + 'fk_gp': {'type': 'id'}, + 'pk_p': {'type': 'id'}, + 'col_3': {'type': 'numerical'}, + }, + 'primary_key': 'pk_p' + }, + 'child': { + 'columns': { + 'fk_p_1': {'type': 'id'}, + 'fk_p_2': {'type': 'id'}, + 'col_4': {'type': 'numerical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'grandparent', + 'child_table_name': 'parent_1', + 'parent_primary_key': 'pk_gp', + 'child_foreign_key': 'fk_gp', + }, + { + 'parent_table_name': 'grandparent', + 'child_table_name': 'parent_2', + 'parent_primary_key': 'pk_gp', + 'child_foreign_key': 'fk_gp', + }, + { + 'parent_table_name': 'parent_1', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_1', + }, + { + 'parent_table_name': 'parent_2', + 'child_table_name': 'child', + 'parent_primary_key': 'pk_p', + 'child_foreign_key': 'fk_p_2', + }, + ] + }) + + # Run + _subsample_ancestors(data, metadata, 'child', primary_key_referenced) + + # Assert + expected_result = { + 'parent_1': pd.DataFrame({ + 'fk_gp': [1, 2, 3, 6], + 'pk_p': [21, 22, 23, 26], + 'col_2': ['k', 'l', 'm', 'p'], + }, index=[0, 1, 2, 5]), + 'parent_2': pd.DataFrame({ + 'fk_gp': [7, 8, 9, 10], + 'pk_p': [31, 32, 33, 34], + 'col_3': ['k', 'l', 'm', 'n'], + }, index=[0, 1, 2, 3]), + 'child': pd.DataFrame({ + 'fk_p_1': [21, 22, 23, 23, 23], + 'fk_p_2': [31, 32, 33, 34, 34], + 'col_4': ['q', 'r', 's', 't', 'u'], + }, index=[0, 1, 2, 3, 4]), + } + assert len(data['grandparent']) == 14 + assert set(data['grandparent']['pk_gp']).issubset( + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20 + } + ) + for table_name in ['parent_1', 'parent_2', 'child']: + 
pd.testing.assert_frame_equal(data[table_name], expected_result[table_name]) + + +@patch('sdv.multi_table.utils._subsample_disconnected_roots') +@patch('sdv.multi_table.utils._subsample_table_and_descendants') +@patch('sdv.multi_table.utils._subsample_ancestors') +@patch('sdv.multi_table.utils._get_primary_keys_referenced') +@patch('sdv.multi_table.utils._drop_rows') +@patch('sdv.multi_table.utils._validate_foreign_keys_not_null') +def test__subsample_data( + mock_validate_foreign_keys_not_null, + mock_drop_rows, + mock_get_primary_keys_referenced, + mock_subsample_ancestors, + mock_subsample_table_and_descendants, + mock_subsample_disconnected_roots +): + """Test the ``_subsample_data`` method.""" + # Setup + data = { + 'main_table': [1] * 10, + } + metadata = Mock() + num_rows = 5 + main_table = 'main_table' + primary_key_reference = { + 'main_table': {1, 2, 4} + } + mock_get_primary_keys_referenced.return_value = primary_key_reference + + # Run + result = _subsample_data(data, metadata, main_table, num_rows) + + # Assert + mock_validate_foreign_keys_not_null.assert_called_once_with(metadata, data) + mock_get_primary_keys_referenced.assert_called_once_with(data, metadata) + mock_subsample_disconnected_roots.assert_called_once_with(data, metadata, main_table, 0.5) + mock_subsample_table_and_descendants.assert_called_once_with( + data, metadata, main_table, num_rows + ) + mock_subsample_ancestors.assert_called_once_with( + data, metadata, main_table, primary_key_reference + ) + mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) + assert result == data + + +@patch('sdv.multi_table.utils._subsample_disconnected_roots') +@patch('sdv.multi_table.utils._get_primary_keys_referenced') +@patch('sdv.multi_table.utils._validate_foreign_keys_not_null') +def test__subsample_data_empty_dataset( + mock_validate_foreign_keys_not_null, + mock_get_primary_keys_referenced, + mock_subsample_disconnected_roots +): + """Test the ``subsample_data`` method when a dataset is empty.""" + # Setup + data = { + 'main_table': [1] * 10, + } + metadata = Mock() + num_rows = 5 + main_table = 'main_table' + mock_subsample_disconnected_roots.side_effect = InvalidDataError('All references in table') + + # Run and Assert + expected_message = re.escape( + 'Subsampling main_table with 5 rows leads to some empty tables. ' + 'Please try again with a bigger number of rows.' 
+ ) + with pytest.raises(SamplingError, match=expected_message): + _subsample_data(data, metadata, main_table, num_rows) + + +def test__print_subsample_summary(capsys): + """Test the ``_print_subsample_summary`` method.""" + # Setup + data_before = { + 'grandparent': pd.DataFrame({ + 'pk_gp': [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 + ], + 'col_1': [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', + 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't' + ], + }), + 'parent_1': pd.DataFrame({ + 'fk_gp': [1, 2, 3, 4, 5, 6], + 'pk_p': [21, 22, 23, 24, 25, 26], + 'col_2': ['k', 'l', 'm', 'n', 'o', 'p'], + }), + 'parent_2': pd.DataFrame({ + 'fk_gp': [7, 8, 9, 10, 11], + 'pk_p': [31, 32, 33, 34, 35], + 'col_3': ['k', 'l', 'm', 'n', 'o'], + }), + 'child': pd.DataFrame({ + 'fk_p_1': [21, 22, 23, 23, 23], + 'fk_p_2': [31, 32, 33, 34, 34], + 'col_4': ['q', 'r', 's', 't', 'u'], + }) + } + + data_after = { + 'grandparent': pd.DataFrame({ + 'pk_gp': [ + 1, 2, 3, 6, 7, 8, 9, 10, 14, 15, + 16, 17, 18, 20 + ], + 'col_1': [ + 'a', 'b', 'c', 'f', 'g', 'h', 'i', 'j', 'n', 'o', + 'p', 'q', 'r', 't' + ], + }, index=[0, 1, 2, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 19]), + 'parent_1': pd.DataFrame({ + 'fk_gp': [1, 2, 3, 6], + 'pk_p': [21, 22, 23, 26], + 'col_2': ['k', 'l', 'm', 'p'], + }, index=[0, 1, 2, 5]), + 'parent_2': pd.DataFrame({ + 'fk_gp': [7, 8, 9, 10], + 'pk_p': [31, 32, 33, 34], + 'col_3': ['k', 'l', 'm', 'n'], + }, index=[0, 1, 2, 3]), + 'child': pd.DataFrame({ + 'fk_p_1': [21, 22, 23, 23, 23], + 'fk_p_2': [31, 32, 33, 34, 34], + 'col_4': ['q', 'r', 's', 't', 'u'], + }, index=[0, 1, 2, 3, 4]), + } + + # Run + _print_subsample_summary(data_before, data_after) + captured = capsys.readouterr() + + # Assert + expected_output = re.compile( + r'Success! Your subset has 25% less rows than the original\.\s*' + r'Table Name\s*#\s*Rows \(Before\)\s*#\s*Rows \(After\)\s*' + r'child\s*5\s*5\s*' + r'grandparent\s*20\s*14\s*' + r'parent_1\s*6\s*4\s*' + r'parent_2\s*5\s*4' + ) + assert expected_output.match(captured.out.strip()) diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index ba93205eb..089ef9c4e 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -9,12 +9,11 @@ from sdv.errors import InvalidDataError from sdv.metadata import MultiTableMetadata from sdv.metadata.errors import InvalidMetadataError -from sdv.utils.poc import drop_unknown_references, simplify_schema +from sdv.utils.poc import drop_unknown_references, get_random_subset, simplify_schema -@patch('sys.stdout.write') -@patch('sdv.utils.poc._get_rows_to_drop') -def test_drop_unknown_references(mock_get_rows_to_drop, mock_stdout_write): +@patch('sdv.utils.poc._drop_rows') +def test_drop_unknown_references(mock_drop_rows): """Test ``drop_unknown_references``.""" # Setup relationships = [ @@ -59,27 +58,17 @@ def test_drop_unknown_references(mock_get_rows_to_drop, mock_stdout_write): 'C': ['Yes', 'No', 'No', 'No', 'No'] }) } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 2, 4} - }) + + def _drop_rows(data, metadata, drop_missing_values): + data['child'] = data['child'].iloc[:4] + data['grandchild'] = data['grandchild'].iloc[[1, 3]] + + mock_drop_rows.side_effect = _drop_rows # Run result = drop_unknown_references(data, metadata) # Assert - expected_pattern = re.compile( - r'Success! 
All foreign keys have referential integrity\.\s*' - r'Table Name\s*#\s*Rows \(Original\)\s*#\s*Invalid Rows\s*#\s*Rows \(New\)\s*' - r'child\s*5\s*1\s*4\s*' - r'grandchild\s*5\s*3\s*2\s*' - r'parent\s*5\s*0\s*5' - ) - output = mock_stdout_write.call_args[0][0] - assert expected_pattern.match(output) - metadata.validate.assert_called_once() - metadata.validate_data.assert_called_once_with(data) - mock_get_rows_to_drop.assert_called_once() expected_result = { 'parent': pd.DataFrame({ 'id_parent': [0, 1, 2, 3, 4], @@ -96,6 +85,9 @@ def test_drop_unknown_references(mock_get_rows_to_drop, mock_stdout_write): 'C': ['No', 'No'] }, index=[1, 3]) } + metadata.validate.assert_called_once() + metadata.validate_data.assert_called_once_with(data) + mock_drop_rows.assert_called_once_with(result, metadata, True) for table_name, table in result.items(): pd.testing.assert_frame_equal(table, expected_result[table_name]) @@ -145,7 +137,7 @@ def test_drop_unknown_references_valid_data_mock(mock_stdout_write): pd.testing.assert_frame_equal(table, data[table_name]) -@patch('sdv.utils.poc._get_rows_to_drop') +@patch('sdv.multi_table.utils._get_rows_to_drop') @patch('sdv.utils.poc._validate_foreign_keys_not_null') def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_rows_to_drop): """Test ``drop_unknown_references`` whith NaNs and drop_missing_values True.""" @@ -226,7 +218,7 @@ def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_r pd.testing.assert_frame_equal(table, expected_result[table_name]) -@patch('sdv.utils.poc._get_rows_to_drop') +@patch('sdv.multi_table.utils._get_rows_to_drop') def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop): """Test ``drop_unknown_references`` with NaNs and drop_missing_values False.""" # Setup @@ -302,7 +294,7 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop pd.testing.assert_frame_equal(table, expected_result[table_name]) -@patch('sdv.utils.poc._get_rows_to_drop') +@patch('sdv.multi_table.utils._get_rows_to_drop') def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): """Test ``drop_unknown_references`` when all rows are dropped.""" # Setup @@ -484,3 +476,139 @@ def test_simplify_schema_invalid_data(): ) with pytest.raises(InvalidDataError, match=expected_message): simplify_schema(real_data, metadata) + + +def test_get_random_subset_invalid_metadata(): + """Test ``get_random_subset`` when the metadata is invalid.""" + # Setup + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'table1': { + 'columns': { + 'column1': {'sdtype': 'categorical'} + } + } + }, + 'relationships': [ + { + 'parent_table_name': 'table1', + 'child_table_name': 'table2', + 'parent_primary_key': 'column1', + 'child_foreign_key': 'column2' + } + ] + }) + real_data = { + 'table1': pd.DataFrame({'column1': [1, 2, 3]}), + 'table2': pd.DataFrame({'column2': [4, 5, 6]}), + } + + # Run and Assert + expected_message = re.escape( + 'The provided data/metadata combination is not valid. Please make sure that the' + ' data/metadata combination is valid before trying to simplify the schema.' 
+ ) + with pytest.raises(InvalidMetadataError, match=expected_message): + get_random_subset(real_data, metadata, 'table1', 2) + + +def test_get_random_subset_invalid_data(): + """Test ``get_random_subset`` when the data is not valid.""" + # Setup + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'table1': { + 'columns': { + 'column1': {'sdtype': 'id'} + }, + 'primary_key': 'column1' + }, + 'table2': { + 'columns': { + 'column2': {'sdtype': 'id'} + }, + } + }, + 'relationships': [ + { + 'parent_table_name': 'table1', + 'child_table_name': 'table2', + 'parent_primary_key': 'column1', + 'child_foreign_key': 'column2' + } + ] + }) + real_data = { + 'table1': pd.DataFrame({'column1': [np.nan, 1, 2]}), + 'table2': pd.DataFrame({'column2': [1, 1, 2]}), + } + + # Run and Assert + expected_message = re.escape( + 'The provided data/metadata combination is not valid. Please make sure that the' + ' data/metadata combination is valid before trying to simplify the schema.' + ) + with pytest.raises(InvalidDataError, match=expected_message): + get_random_subset(real_data, metadata, 'table1', 2) + + +def test_get_random_subset_invalid_num_rows(): + """Test ``get_random_subset`` when ``num_rows`` is invalid.""" + # Setup + data = Mock() + metadata = Mock() + + # Run and Assert + with pytest.raises(ValueError, match='``num_rows`` must be a positive integer.'): + get_random_subset(data, metadata, 'table1', -1) + with pytest.raises(ValueError, match='``num_rows`` must be a positive integer.'): + get_random_subset(data, metadata, 'table1', 0) + with pytest.raises(ValueError, match='``num_rows`` must be a positive integer.'): + get_random_subset(data, metadata, 'table1', 0.5) + + +def test_get_random_subset_nothing_to_sample(): + """Test ``get_random_subset`` when there is nothing to sample.""" + # Setup + data = { + 'table1': pd.DataFrame({'column1': [1, 2, 3]}), + 'table2': pd.DataFrame({'column2': [4, 5, 6]}), + } + metadata = Mock() + + # Run + result = get_random_subset(data, metadata, 'table1', 5) + + # Assert + pd.testing.assert_frame_equal(result['table1'], data['table1']) + pd.testing.assert_frame_equal(result['table2'], data['table2']) + + +@patch('sdv.utils.poc._subsample_data') +@patch('sdv.utils.poc._print_subsample_summary') +def test_get_random_subset(mock_print_summary, mock_subsample_data): + """Test ``get_random_subset``.""" + # Setup + data = { + 'table1': pd.DataFrame({'column1': [1, 2, 3, 4, 5]}), + 'table2': pd.DataFrame({'column2': [6, 7, 8, 9, 10]}), + } + metadata = Mock() + output = { + 'table1': pd.DataFrame({'column1': [1, 2, 3]}), + 'table2': pd.DataFrame({'column2': [6, 7, 8]}), + } + mock_subsample_data.return_value = output + + # Run + get_random_subset(data, metadata, 'table1', 3) + result = get_random_subset(data, metadata, 'table2', 3, verbose=False) + + # Assert + pd.testing.assert_frame_equal(result['table1'], output['table1']) + pd.testing.assert_frame_equal(result['table2'], output['table2']) + mock_subsample_data.assert_has_calls([ + ((data, metadata, 'table1', 3),), + ((data, metadata, 'table2', 3),), + ]) + mock_print_summary.call_count == 1 From 25388141122ab3b0a2a4948cbc3d4e9a7214723b Mon Sep 17 00:00:00 2001 From: SDV Team <98988753+sdv-team@users.noreply.github.com> Date: Tue, 30 Apr 2024 15:35:54 -0400 Subject: [PATCH 12/32] Automated Latest Dependency Updates (#1971) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- latest_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/latest_requirements.txt b/latest_requirements.txt index 7c9f63bac..9da60a07f 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -5,7 +5,7 @@ deepecho==0.6.0 graphviz==0.20.3 numpy==1.26.4 pandas==2.2.2 +platformdirs==4.2.1 rdt==1.12.0 sdmetrics==0.14.0 tqdm==4.66.2 -platformdirs==4.2.0 From b94cf94915c8f8ddb94306b712b4b8fe0b9e041a Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Wed, 1 May 2024 16:49:12 +0100 Subject: [PATCH 13/32] Move function `drop_unknown_references` from `poc` to be directly under `utils` (#1969) --- sdv/metadata/multi_table.py | 4 +- sdv/multi_table/utils.py | 2 +- sdv/utils/__init__.py | 2 +- sdv/utils/poc.py | 76 +---- sdv/utils/utils.py | 65 +++++ tests/integration/utils/test_poc.py | 97 +------ tests/integration/utils/test_utils.py | 153 ++++++++++ tests/unit/metadata/test_multi_table.py | 6 +- tests/unit/multi_table/test_base.py | 2 +- tests/unit/utils/test_poc.py | 345 +---------------------- tests/unit/utils/test_utils.py | 353 ++++++++++++++++++++++++ 11 files changed, 604 insertions(+), 501 deletions(-) create mode 100644 sdv/utils/utils.py create mode 100644 tests/integration/utils/test_utils.py create mode 100644 tests/unit/utils/test_utils.py diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index b7ccb7046..c7aa10d30 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -821,8 +821,8 @@ def _validate_foreign_keys(self, data): errors.append( f"Error: foreign key column '{relation['child_foreign_key']}' contains " - f'unknown references: {message}. Please use the utility method' - " 'drop_unknown_references' to clean the data." + f'unknown references: {message}. Please use the method' + " 'drop_unknown_references' from sdv.utils to clean the data." ) if errors: diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index cf3bb30e1..a44574157 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -638,7 +638,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows): except SynthesizerInputError: warnings.warn( 'The data contains null values in foreign key columns. ' - 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc' + 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils' ' to drop these rows before using ``get_random_subset``.' 
) diff --git a/sdv/utils/__init__.py b/sdv/utils/__init__.py index bce5a07fd..981fca3a5 100644 --- a/sdv/utils/__init__.py +++ b/sdv/utils/__init__.py @@ -1,6 +1,6 @@ """Utils module.""" -from sdv.utils.poc import drop_unknown_references +from sdv.utils.utils import drop_unknown_references __all__ = ( 'drop_unknown_references', diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index 139303a98..40f3944e9 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -1,72 +1,22 @@ -"""Utility functions.""" -import sys -from copy import deepcopy +"""POC functions to use HMASynthesizer succesfully.""" +import warnings -import pandas as pd - -from sdv._utils import _validate_foreign_keys_not_null -from sdv.errors import InvalidDataError, SynthesizerInputError +from sdv.errors import InvalidDataError from sdv.metadata.errors import InvalidMetadataError from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS from sdv.multi_table.utils import ( - _drop_rows, _get_total_estimated_columns, _print_simplified_schema_summary, - _print_subsample_summary, _simplify_data, _simplify_metadata, _subsample_data) - - -def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True): - """Drop rows with unknown foreign keys. + _get_total_estimated_columns, _print_simplified_schema_summary, _print_subsample_summary, + _simplify_data, _simplify_metadata, _subsample_data) +from sdv.utils.utils import drop_unknown_references as utils_drop_unknown_references - Args: - data (dict): - Dictionary that maps each table name (string) to the data for that - table (pandas.DataFrame). - metadata (MultiTableMetadata): - Metadata of the datasets. - drop_missing_values (bool): - Boolean describing whether or not to also drop foreign keys with missing values - If True, drop rows with missing values in the foreign keys. - Defaults to True. - verbose (bool): - If True, print information about the rows that are dropped. - Defaults to True. - Returns: - dict: - Dictionary with the dataframes ensuring referential integrity. - """ - success_message = 'Success! All foreign keys have referential integrity.' 
- table_names = sorted(metadata.tables) - summary_table = pd.DataFrame({ - 'Table Name': table_names, - '# Rows (Original)': [len(data[table]) for table in table_names], - '# Invalid Rows': [0] * len(table_names), - '# Rows (New)': [len(data[table]) for table in table_names] - }) - metadata.validate() - try: - metadata.validate_data(data) - if drop_missing_values: - _validate_foreign_keys_not_null(metadata, data) - - if verbose: - sys.stdout.write( - '\n'.join([success_message, '', summary_table.to_string(index=False)]) - ) - - return data - except (InvalidDataError, SynthesizerInputError): - result = deepcopy(data) - _drop_rows(result, metadata, drop_missing_values) - if verbose: - summary_table['# Invalid Rows'] = [ - len(data[table]) - len(result[table]) for table in table_names - ] - summary_table['# Rows (New)'] = [len(result[table]) for table in table_names] - sys.stdout.write('\n'.join([ - success_message, '', summary_table.to_string(index=False) - ])) - - return result +def drop_unknown_references(data, metadata): + """Wrap the drop_unknown_references function from the utils module.""" + warnings.warn( + "Please access the 'drop_unknown_references' function directly from the sdv.utils module" + 'instead of sdv.utils.poc.', FutureWarning + ) + return utils_drop_unknown_references(data, metadata) def simplify_schema(data, metadata, verbose=True): diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py new file mode 100644 index 000000000..5b3589b1f --- /dev/null +++ b/sdv/utils/utils.py @@ -0,0 +1,65 @@ +"""Utils module.""" +import sys +from copy import deepcopy + +import pandas as pd + +from sdv._utils import _validate_foreign_keys_not_null +from sdv.errors import InvalidDataError, SynthesizerInputError +from sdv.multi_table.utils import _drop_rows + + +def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True): + """Drop rows with unknown foreign keys. + + Args: + data (dict): + Dictionary that maps each table name (string) to the data for that + table (pandas.DataFrame). + metadata (MultiTableMetadata): + Metadata of the datasets. + drop_missing_values (bool): + Boolean describing whether or not to also drop foreign keys with missing values + If True, drop rows with missing values in the foreign keys. + Defaults to True. + verbose (bool): + If True, print information about the rows that are dropped. + Defaults to True. + + Returns: + dict: + Dictionary with the dataframes ensuring referential integrity. + """ + success_message = 'Success! All foreign keys have referential integrity.' 
+ table_names = sorted(metadata.tables) + summary_table = pd.DataFrame({ + 'Table Name': table_names, + '# Rows (Original)': [len(data[table]) for table in table_names], + '# Invalid Rows': [0] * len(table_names), + '# Rows (New)': [len(data[table]) for table in table_names] + }) + metadata.validate() + try: + metadata.validate_data(data) + if drop_missing_values: + _validate_foreign_keys_not_null(metadata, data) + + if verbose: + sys.stdout.write( + '\n'.join([success_message, '', summary_table.to_string(index=False)]) + ) + + return data + except (InvalidDataError, SynthesizerInputError): + result = deepcopy(data) + _drop_rows(result, metadata, drop_missing_values) + if verbose: + summary_table['# Invalid Rows'] = [ + len(data[table]) - len(result[table]) for table in table_names + ] + summary_table['# Rows (New)'] = [len(result[table]) for table in table_names] + sys.stdout.write('\n'.join([ + success_message, '', summary_table.to_string(index=False) + ])) + + return result diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 7aec94243..8f4c5fc37 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -7,11 +7,10 @@ import pytest from sdv.datasets.demo import download_demo -from sdv.errors import InvalidDataError from sdv.metadata import MultiTableMetadata from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS, HMASynthesizer from sdv.multi_table.utils import _get_total_estimated_columns -from sdv.utils.poc import drop_unknown_references, get_random_subset, simplify_schema +from sdv.utils.poc import get_random_subset, simplify_schema @pytest.fixture @@ -65,98 +64,6 @@ def data(): } -def test_drop_unknown_references(metadata, data, capsys): - """Test ``drop_unknown_references`` end to end.""" - # Run - expected_message = re.escape( - 'The provided data does not match the metadata:\n' - 'Relationships:\n' - "Error: foreign key column 'parent_id' contains unknown references: (5)" - ". Please use the utility method 'drop_unknown_references' to clean the data." - ) - with pytest.raises(InvalidDataError, match=expected_message): - metadata.validate_data(data) - - cleaned_data = drop_unknown_references(data, metadata) - metadata.validate_data(cleaned_data) - captured = capsys.readouterr() - - # Assert - pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) - pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4]) - assert len(cleaned_data['child']) == 4 - expected_output = ( - 'Success! All foreign keys have referential integrity.\n\n' - 'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n' - ' child 5 1 4\n' - ' parent 5 0 5' - ) - assert captured.out.strip() == expected_output - - -def test_drop_unknown_references_valid_data(metadata, data, capsys): - """Test ``drop_unknown_references`` when data has referential integrity.""" - # Setup - data = deepcopy(data) - data['child'].loc[4, 'parent_id'] = 2 - - # Run - result = drop_unknown_references(data, metadata) - captured = capsys.readouterr() - - # Assert - pd.testing.assert_frame_equal(result['parent'], data['parent']) - pd.testing.assert_frame_equal(result['child'], data['child']) - expected_message = ( - 'Success! 
All foreign keys have referential integrity.\n\n' - 'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n' - ' child 5 0 5\n' - ' parent 5 0 5' - ) - assert captured.out.strip() == expected_message - - -def test_drop_unknown_references_drop_missing_values(metadata, data, capsys): - """Test ``drop_unknown_references`` when there is missing values in the foreign keys.""" - # Setup - data = deepcopy(data) - data['child'].loc[4, 'parent_id'] = np.nan - - # Run - cleaned_data = drop_unknown_references(data, metadata) - metadata.validate_data(cleaned_data) - captured = capsys.readouterr() - - # Assert - pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) - pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4]) - assert len(cleaned_data['child']) == 4 - expected_output = ( - 'Success! All foreign keys have referential integrity.\n\n' - 'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n' - ' child 5 1 4\n' - ' parent 5 0 5' - ) - assert captured.out.strip() == expected_output - - -def test_drop_unknown_references_not_drop_missing_values(metadata, data): - """Test ``drop_unknown_references`` when the missing values in the foreign keys are kept.""" - # Setup - data['child'].loc[3, 'parent_id'] = np.nan - - # Run - cleaned_data = drop_unknown_references( - data, metadata, drop_missing_values=False, verbose=False - ) - - # Assert - pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) - pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4]) - assert pd.isna(cleaned_data['child']['parent_id']).any() - assert len(cleaned_data['child']) == 4 - - def test_simplify_schema(capsys): """Test ``simplify_schema`` end to end.""" # Setup @@ -337,7 +244,7 @@ def test_get_random_subset_with_missing_values(metadata, data): data['child'].loc[4, 'parent_id'] = np.nan expected_warning = re.escape( 'The data contains null values in foreign key columns. ' - 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc' + 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils' ' to drop these rows before using ``get_random_subset``.' 
) diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py new file mode 100644 index 000000000..35b0f28d2 --- /dev/null +++ b/tests/integration/utils/test_utils.py @@ -0,0 +1,153 @@ +import re +from copy import deepcopy + +import numpy as np +import pandas as pd +import pytest + +from sdv.errors import InvalidDataError +from sdv.metadata import MultiTableMetadata +from sdv.utils import drop_unknown_references + + +@pytest.fixture +def metadata(): + return MultiTableMetadata.load_from_dict( + { + 'tables': { + 'parent': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'A': {'sdtype': 'categorical'}, + 'B': {'sdtype': 'numerical'} + }, + 'primary_key': 'id' + }, + 'child': { + 'columns': { + 'parent_id': {'sdtype': 'id'}, + 'C': {'sdtype': 'categorical'} + } + } + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id', + 'child_foreign_key': 'parent_id' + } + ] + } + ) + + +@pytest.fixture +def data(): + parent = pd.DataFrame(data={ + 'id': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + 'B': [0.434, 0.312, 0.212, 0.339, 0.491] + }) + + child = pd.DataFrame(data={ + 'parent_id': [0, 1, 2, 2, 5], + 'C': ['Yes', 'No', 'Maye', 'No', 'No'] + }) + + return { + 'parent': parent, + 'child': child + } + + +def test_drop_unknown_references(metadata, data, capsys): + """Test ``drop_unknown_references`` end to end.""" + # Run + expected_message = re.escape( + 'The provided data does not match the metadata:\n' + 'Relationships:\n' + "Error: foreign key column 'parent_id' contains unknown references: (5)" + ". Please use the method 'drop_unknown_references' from sdv.utils to clean the data." + ) + with pytest.raises(InvalidDataError, match=expected_message): + metadata.validate_data(data) + + cleaned_data = drop_unknown_references(data, metadata) + metadata.validate_data(cleaned_data) + captured = capsys.readouterr() + + # Assert + pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) + pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4]) + assert len(cleaned_data['child']) == 4 + expected_output = ( + 'Success! All foreign keys have referential integrity.\n\n' + 'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n' + ' child 5 1 4\n' + ' parent 5 0 5' + ) + assert captured.out.strip() == expected_output + + +def test_drop_unknown_references_valid_data(metadata, data, capsys): + """Test ``drop_unknown_references`` when data has referential integrity.""" + # Setup + data = deepcopy(data) + data['child'].loc[4, 'parent_id'] = 2 + + # Run + result = drop_unknown_references(data, metadata) + captured = capsys.readouterr() + + # Assert + pd.testing.assert_frame_equal(result['parent'], data['parent']) + pd.testing.assert_frame_equal(result['child'], data['child']) + expected_message = ( + 'Success! 
All foreign keys have referential integrity.\n\n' + 'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n' + ' child 5 0 5\n' + ' parent 5 0 5' + ) + assert captured.out.strip() == expected_message + + +def test_drop_unknown_references_drop_missing_values(metadata, data, capsys): + """Test ``drop_unknown_references`` when there is missing values in the foreign keys.""" + # Setup + data = deepcopy(data) + data['child'].loc[4, 'parent_id'] = np.nan + + # Run + cleaned_data = drop_unknown_references(data, metadata) + metadata.validate_data(cleaned_data) + captured = capsys.readouterr() + + # Assert + pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) + pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4]) + assert len(cleaned_data['child']) == 4 + expected_output = ( + 'Success! All foreign keys have referential integrity.\n\n' + 'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n' + ' child 5 1 4\n' + ' parent 5 0 5' + ) + assert captured.out.strip() == expected_output + + +def test_drop_unknown_references_not_drop_missing_values(metadata, data): + """Test ``drop_unknown_references`` when the missing values in the foreign keys are kept.""" + # Setup + data['child'].loc[3, 'parent_id'] = np.nan + + # Run + cleaned_data = drop_unknown_references( + data, metadata, drop_missing_values=False, verbose=False + ) + + # Assert + pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent']) + pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4]) + assert pd.isna(cleaned_data['child']['parent_id']).any() + assert len(cleaned_data['child']) == 4 diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 34004325c..58d0b6975 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1289,9 +1289,9 @@ def test__validate_foreign_keys_missing_keys(self): 'Relationships:\n' "Error: foreign key column 'upravna_enota' contains unknown references: " '(10, 11, 12, 13, 14, + more). ' - "Please use the utility method 'drop_unknown_references' to clean the data.\n" + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data.\n" "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9)." - " Please use the utility method 'drop_unknown_references' to clean the data." + " Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ] assert result == missing_upravna_enota @@ -1404,7 +1404,7 @@ def test_validate_data_missing_foreign_keys(self): 'The provided data does not match the metadata:\n' 'Relationships:\n' "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " - "Please use the utility method 'drop_unknown_references' to clean the data." + "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=error_msg): metadata.validate_data(data) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 0cad058b0..fe7dc11ab 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -498,7 +498,7 @@ def test_validate_missing_foreign_keys(self): 'The provided data does not match the metadata:\n' 'Relationships:\n' "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " - "Please use the utility method 'drop_unknown_references' to clean the data." 
+ "Please use the method 'drop_unknown_references' from sdv.utils to clean the data." ) with pytest.raises(InvalidDataError, match=error_msg): instance.validate(data) diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index 089ef9c4e..1961d7c34 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -1,5 +1,4 @@ import re -from collections import defaultdict from unittest.mock import Mock, patch import numpy as np @@ -12,347 +11,23 @@ from sdv.utils.poc import drop_unknown_references, get_random_subset, simplify_schema -@patch('sdv.utils.poc._drop_rows') -def test_drop_unknown_references(mock_drop_rows): - """Test ``drop_unknown_references``.""" - # Setup - relationships = [ - { - 'parent_table_name': 'parent', - 'child_table_name': 'child', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - }, - { - 'parent_table_name': 'child', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' - }, - { - 'parent_table_name': 'parent', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } - ] - - metadata = Mock() - metadata.relationships = relationships - metadata.tables = {'parent', 'child', 'grandchild'} - metadata.validate_data.side_effect = InvalidDataError('Invalid data') - - data = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 5], - 'id_child': [5, 6, 7, 8, 9], - 'B': ['Yes', 'No', 'No', 'No', 'No'] - }), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 6], - 'child_foreign_key': [9, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No'] - }) - } - - def _drop_rows(data, metadata, drop_missing_values): - data['child'] = data['child'].iloc[:4] - data['grandchild'] = data['grandchild'].iloc[[1, 3]] - - mock_drop_rows.side_effect = _drop_rows - - # Run - result = drop_unknown_references(data, metadata) - - # Assert - expected_result = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2], - 'id_child': [5, 6, 7, 8], - 'B': ['Yes', 'No', 'No', 'No'] - }, index=[0, 1, 2, 3]), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [1, 2], - 'child_foreign_key': [5, 6], - 'C': ['No', 'No'] - }, index=[1, 3]) - } - metadata.validate.assert_called_once() - metadata.validate_data.assert_called_once_with(data) - mock_drop_rows.assert_called_once_with(result, metadata, True) - for table_name, table in result.items(): - pd.testing.assert_frame_equal(table, expected_result[table_name]) - - -@patch('sys.stdout.write') -def test_drop_unknown_references_valid_data_mock(mock_stdout_write): - """Test ``drop_unknown_references`` when data has referential integrity.""" +@patch('sdv.utils.poc.utils_drop_unknown_references') +def test_drop_unknown_references(mock_drop_unknown_references): + """Test ``drop_unknown_references`` raise a FutureWarning when called from sdv.utils.poc.""" # Setup + data = Mock() metadata = Mock() - metadata._get_all_foreign_keys.side_effect = [ - [], ['parent_foreign_key'], ['child_foreign_key', 'parent_foreign_key'] - ] - metadata.tables = {'parent', 'child', 'grandchild'} - data = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 
'parent_foreign_key': [0, 1, 2, 2, 3], - 'id_child': [5, 6, 7, 8, 9], - 'B': ['Yes', 'No', 'No', 'No', 'No'] - }), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 3], - 'child_foreign_key': [6, 5, 7, 6, 9], - 'C': ['Yes', 'No', 'No', 'No', 'No'] - }) - } - - # Run - result = drop_unknown_references(data, metadata) - - # Assert - expected_pattern = re.compile( - r'Success! All foreign keys have referential integrity\.\s*' - r'Table Name\s*#\s*Rows \(Original\)\s*#\s*Invalid Rows\s*#\s*Rows \(New\)\s*' - r'child\s*5\s*0\s*5\s*' - r'grandchild\s*5\s*0\s*5\s*' - r'parent\s*5\s*0\s*5' + expected_message = re.escape( + "Please access the 'drop_unknown_references' function directly from the sdv.utils module" + 'instead of sdv.utils.poc.' ) - output = mock_stdout_write.call_args[0][0] - assert expected_pattern.match(output) - metadata.validate.assert_called_once() - metadata.validate_data.assert_called_once_with(data) - for table_name, table in result.items(): - pd.testing.assert_frame_equal(table, data[table_name]) - - -@patch('sdv.multi_table.utils._get_rows_to_drop') -@patch('sdv.utils.poc._validate_foreign_keys_not_null') -def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_rows_to_drop): - """Test ``drop_unknown_references`` whith NaNs and drop_missing_values True.""" - # Setup - relationships = [ - { - 'parent_table_name': 'parent', - 'child_table_name': 'child', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - }, - { - 'parent_table_name': 'child', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' - }, - { - 'parent_table_name': 'parent', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } - ] - - metadata = Mock() - metadata.relationships = relationships - metadata.tables = {'parent', 'child', 'grandchild'} - mock_validate_foreign_keys.side_effect = InvalidDataError('Invalid data') - - data = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 5, None], - 'id_child': [5, 6, 7, 8, 9, 10], - 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] - }), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 6, 4], - 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] - }) - } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 3, 4} - }) - - # Run - result = drop_unknown_references(data, metadata, verbose=False) - - # Assert - metadata.validate.assert_called_once() - metadata.validate_data.assert_called_once_with(data) - mock_validate_foreign_keys.assert_called_once_with(metadata, data) - mock_validate_foreign_keys.assert_called_once_with(metadata, data) - mock_get_rows_to_drop.assert_called_once() - expected_result = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0., 1., 2., 2.], - 'id_child': [5, 6, 7, 8], - 'B': ['Yes', 'No', 'No', 'No'] - }, index=[0, 1, 2, 3]), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [2, 4], - 'child_foreign_key': [5., 4.], - 'C': ['No', 'No'] - }, index=[2, 5]) - } - for table_name, table in result.items(): - pd.testing.assert_frame_equal(table, expected_result[table_name]) - - -@patch('sdv.multi_table.utils._get_rows_to_drop') -def 
test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop): - """Test ``drop_unknown_references`` with NaNs and drop_missing_values False.""" - # Setup - relationships = [ - { - 'parent_table_name': 'parent', - 'child_table_name': 'child', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - }, - { - 'parent_table_name': 'child', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' - }, - { - 'parent_table_name': 'parent', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } - ] - - metadata = Mock() - metadata.relationships = relationships - metadata.tables = {'parent', 'child', 'grandchild'} - metadata.validate_data.side_effect = InvalidDataError('Invalid data') - - data = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 5, None], - 'id_child': [5, 6, 7, 8, 9, 10], - 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] - }), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 6, 4], - 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] - }) - } - mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {4}, - 'grandchild': {0, 3, 4} - }) # Run - result = drop_unknown_references(data, metadata, drop_missing_values=False, verbose=False) + with pytest.warns(FutureWarning, match=expected_message): + drop_unknown_references(data, metadata) # Assert - mock_get_rows_to_drop.assert_called_once() - expected_result = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0., 1., 2., 2., None], - 'id_child': [5, 6, 7, 8, 10], - 'B': ['Yes', 'No', 'No', 'No', 'No'] - }, index=[0, 1, 2, 3, 5]), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [1, 2, 4], - 'child_foreign_key': [np.nan, 5, 4.], - 'C': ['No', 'No', 'No'] - }, index=[1, 2, 5]) - } - for table_name, table in result.items(): - pd.testing.assert_frame_equal(table, expected_result[table_name]) - - -@patch('sdv.multi_table.utils._get_rows_to_drop') -def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): - """Test ``drop_unknown_references`` when all rows are dropped.""" - # Setup - relationships = [ - { - 'parent_table_name': 'parent', - 'child_table_name': 'child', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - }, - { - 'parent_table_name': 'child', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_child', - 'child_foreign_key': 'child_foreign_key' - }, - { - 'parent_table_name': 'parent', - 'child_table_name': 'grandchild', - 'parent_primary_key': 'id_parent', - 'child_foreign_key': 'parent_foreign_key' - } - ] - - metadata = Mock() - metadata.relationships = relationships - metadata.tables = {'parent', 'child', 'grandchild'} - metadata.validate_data.side_effect = InvalidDataError('Invalid data') - - data = { - 'parent': pd.DataFrame({ - 'id_parent': [0, 1, 2, 3, 4], - 'A': [True, True, False, True, False], - }), - 'child': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 5], - 'id_child': [5, 6, 7, 8, 9], - 'B': ['Yes', 'No', 'No', 'No', 'No'] - }), - 'grandchild': pd.DataFrame({ - 'parent_foreign_key': [0, 1, 2, 2, 6], - 'child_foreign_key': [9, 5, 11, 6, 4], - 'C': ['Yes', 'No', 'No', 'No', 'No'] - }) - } - - 
mock_get_rows_to_drop.return_value = defaultdict(set, { - 'child': {0, 1, 2, 3, 4} - }) - - # Run and Assert - expected_message = re.escape( - 'The provided data does not match the metadata:\n' - "All references in table 'child' are unknown and must be dropped." - 'Try providing different data for this table.' - ) - with pytest.raises(InvalidDataError, match=expected_message): - drop_unknown_references(data, metadata) + mock_drop_unknown_references.assert_called_once_with(data, metadata) @patch('sdv.utils.poc._get_total_estimated_columns') diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py new file mode 100644 index 000000000..261eb6443 --- /dev/null +++ b/tests/unit/utils/test_utils.py @@ -0,0 +1,353 @@ +import re +from collections import defaultdict +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd +import pytest + +from sdv.errors import InvalidDataError +from sdv.utils.utils import drop_unknown_references + + +@patch('sdv.utils.utils._drop_rows') +def test_drop_unknown_references(mock_drop_rows): + """Test ``drop_unknown_references``.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + metadata.validate_data.side_effect = InvalidDataError('Invalid data') + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5], + 'id_child': [5, 6, 7, 8, 9], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6], + 'child_foreign_key': [9, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No'] + }) + } + + def _drop_rows(data, metadata, drop_missing_values): + data['child'] = data['child'].iloc[:4] + data['grandchild'] = data['grandchild'].iloc[[1, 3]] + + mock_drop_rows.side_effect = _drop_rows + + # Run + result = drop_unknown_references(data, metadata) + + # Assert + expected_result = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2], + 'id_child': [5, 6, 7, 8], + 'B': ['Yes', 'No', 'No', 'No'] + }, index=[0, 1, 2, 3]), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [1, 2], + 'child_foreign_key': [5, 6], + 'C': ['No', 'No'] + }, index=[1, 3]) + } + metadata.validate.assert_called_once() + metadata.validate_data.assert_called_once_with(data) + mock_drop_rows.assert_called_once_with(result, metadata, True) + for table_name, table in result.items(): + pd.testing.assert_frame_equal(table, expected_result[table_name]) + + +@patch('sys.stdout.write') +def test_drop_unknown_references_valid_data_mock(mock_stdout_write): + """Test ``drop_unknown_references`` when data has referential integrity.""" + # Setup + metadata = Mock() + metadata._get_all_foreign_keys.side_effect = [ + [], ['parent_foreign_key'], ['child_foreign_key', 'parent_foreign_key'] + ] + metadata.tables = 
{'parent', 'child', 'grandchild'} + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 3], + 'id_child': [5, 6, 7, 8, 9], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 3], + 'child_foreign_key': [6, 5, 7, 6, 9], + 'C': ['Yes', 'No', 'No', 'No', 'No'] + }) + } + + # Run + result = drop_unknown_references(data, metadata) + + # Assert + expected_pattern = re.compile( + r'Success! All foreign keys have referential integrity\.\s*' + r'Table Name\s*#\s*Rows \(Original\)\s*#\s*Invalid Rows\s*#\s*Rows \(New\)\s*' + r'child\s*5\s*0\s*5\s*' + r'grandchild\s*5\s*0\s*5\s*' + r'parent\s*5\s*0\s*5' + ) + output = mock_stdout_write.call_args[0][0] + assert expected_pattern.match(output) + metadata.validate.assert_called_once() + metadata.validate_data.assert_called_once_with(data) + for table_name, table in result.items(): + pd.testing.assert_frame_equal(table, data[table_name]) + + +@patch('sdv.multi_table.utils._get_rows_to_drop') +@patch('sdv.utils.utils._validate_foreign_keys_not_null') +def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_rows_to_drop): + """Test ``drop_unknown_references`` whith NaNs and drop_missing_values True.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + mock_validate_foreign_keys.side_effect = InvalidDataError('Invalid data') + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5, None], + 'id_child': [5, 6, 7, 8, 9, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6, 4], + 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }) + } + mock_get_rows_to_drop.return_value = defaultdict(set, { + 'child': {4}, + 'grandchild': {0, 3, 4} + }) + + # Run + result = drop_unknown_references(data, metadata, verbose=False) + + # Assert + metadata.validate.assert_called_once() + metadata.validate_data.assert_called_once_with(data) + mock_validate_foreign_keys.assert_called_once_with(metadata, data) + mock_validate_foreign_keys.assert_called_once_with(metadata, data) + mock_get_rows_to_drop.assert_called_once() + expected_result = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0., 1., 2., 2.], + 'id_child': [5, 6, 7, 8], + 'B': ['Yes', 'No', 'No', 'No'] + }, index=[0, 1, 2, 3]), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [2, 4], + 'child_foreign_key': [5., 4.], + 'C': ['No', 'No'] + }, index=[2, 5]) + } + for table_name, table in result.items(): + pd.testing.assert_frame_equal(table, expected_result[table_name]) + + 
+@patch('sdv.multi_table.utils._get_rows_to_drop') +def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop): + """Test ``drop_unknown_references`` with NaNs and drop_missing_values False.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + metadata.validate_data.side_effect = InvalidDataError('Invalid data') + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5, None], + 'id_child': [5, 6, 7, 8, 9, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6, 4], + 'child_foreign_key': [9, np.nan, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No', 'No'] + }) + } + mock_get_rows_to_drop.return_value = defaultdict(set, { + 'child': {4}, + 'grandchild': {0, 3, 4} + }) + + # Run + result = drop_unknown_references(data, metadata, drop_missing_values=False, verbose=False) + + # Assert + mock_get_rows_to_drop.assert_called_once() + expected_result = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0., 1., 2., 2., None], + 'id_child': [5, 6, 7, 8, 10], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }, index=[0, 1, 2, 3, 5]), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [1, 2, 4], + 'child_foreign_key': [np.nan, 5, 4.], + 'C': ['No', 'No', 'No'] + }, index=[1, 2, 5]) + } + for table_name, table in result.items(): + pd.testing.assert_frame_equal(table, expected_result[table_name]) + + +@patch('sdv.multi_table.utils._get_rows_to_drop') +def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop): + """Test ``drop_unknown_references`` when all rows are dropped.""" + # Setup + relationships = [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + }, + { + 'parent_table_name': 'child', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_child', + 'child_foreign_key': 'child_foreign_key' + }, + { + 'parent_table_name': 'parent', + 'child_table_name': 'grandchild', + 'parent_primary_key': 'id_parent', + 'child_foreign_key': 'parent_foreign_key' + } + ] + + metadata = Mock() + metadata.relationships = relationships + metadata.tables = {'parent', 'child', 'grandchild'} + metadata.validate_data.side_effect = InvalidDataError('Invalid data') + + data = { + 'parent': pd.DataFrame({ + 'id_parent': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + }), + 'child': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 5], + 'id_child': [5, 6, 7, 8, 9], + 'B': ['Yes', 'No', 'No', 'No', 'No'] + }), + 'grandchild': pd.DataFrame({ + 'parent_foreign_key': [0, 1, 2, 2, 6], + 'child_foreign_key': [9, 5, 11, 6, 4], + 'C': ['Yes', 'No', 'No', 'No', 'No'] + }) + } + + mock_get_rows_to_drop.return_value = 
defaultdict(set, { + 'child': {0, 1, 2, 3, 4} + }) + + # Run and Assert + expected_message = re.escape( + 'The provided data does not match the metadata:\n' + "All references in table 'child' are unknown and must be dropped." + 'Try providing different data for this table.' + ) + with pytest.raises(InvalidDataError, match=expected_message): + drop_unknown_references(data, metadata) From 60ae5559877a6a1a004df9e702bd3854ca4a1f9c Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Fri, 3 May 2024 00:02:33 +0200 Subject: [PATCH 14/32] Add ExcelHandler (#1962) --- pyproject.toml | 4 +- sdv/io/local/__init__.py | 5 +- sdv/io/local/local.py | 98 +++++++++- tests/integration/io/local/test_local.py | 73 +++++++- tests/unit/io/local/test_local.py | 219 ++++++++++++++++++++++- 5 files changed, 389 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index af250ec3b..3e89f63ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'deepecho>=0.6.0', 'rdt>=1.12.0', 'sdmetrics>=0.14.0', - 'platformdirs>=4.0' + 'platformdirs>=4.0', ] [project.urls] @@ -51,7 +51,9 @@ dependencies = [ sdv = { main = 'sdv.cli.__main__:main' } [project.optional-dependencies] +excel = ['pandas[excel]'] test = [ + 'sdv[excel]', 'pytest>=3.4.2', 'pytest-cov>=2.6.0', 'pytest-rerunfailures>=10.3,<15', diff --git a/sdv/io/local/__init__.py b/sdv/io/local/__init__.py index a233b25be..bd3c2ba5b 100644 --- a/sdv/io/local/__init__.py +++ b/sdv/io/local/__init__.py @@ -1,8 +1,9 @@ """Local I/O module.""" -from sdv.io.local.local import BaseLocalHandler, CSVHandler +from sdv.io.local.local import BaseLocalHandler, CSVHandler, ExcelHandler __all__ = ( 'BaseLocalHandler', - 'CSVHandler' + 'CSVHandler', + 'ExcelHandler' ) diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py index 0d81ab634..1ceaba195 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -33,7 +33,7 @@ def _infer_metadata(self, data): return metadata def read(self): - """Read data from files and returns it along with metadata. + """Read data from files and return it along with metadata. This method must be implemented by subclasses. @@ -91,7 +91,7 @@ def __init__(self, sep=',', encoding='UTF', decimal='.', float_format=None, self.quoting = quoting def read(self, folder_name, file_names=None): - """Read data from CSV files and returns it along with metadata. + """Read data from CSV files and return it along with metadata. Args: folder_name (str): @@ -192,3 +192,97 @@ def write(self, synthetic_data, folder_name, file_name_suffix=None, mode='x'): quoting=self.quoting, mode=mode, ) + + +class ExcelHandler(BaseLocalHandler): + """A class for handling Excel files.""" + + def _read_excel(self, file_path, sheet_names=None): + """Read data from Excel File and return just the data as a dictionary.""" + data = {} + if sheet_names is None: + xl_file = pd.ExcelFile(file_path) + sheet_names = xl_file.sheet_names + + for sheet_name in sheet_names: + data[sheet_name] = pd.read_excel( + file_path, + sheet_name=sheet_name, + parse_dates=False, + decimal=self.decimal, + index_col=None + ) + + return data + + def read(self, file_path, sheet_names=None): + """Read data from Excel files and return it along with metadata. + + Args: + file_path (str): + The path to the Excel file to read. + sheet_names (list of str, optional): + The names of sheets to read. If None, all sheets are read. + + Returns: + tuple: + A tuple containing the data as a dictionary and metadata. 
The dictionary maps + table names to pandas DataFrames. The metadata is an object describing the data. + """ + metadata = MultiTableMetadata() + if sheet_names is not None and not isinstance(sheet_names, list): + raise ValueError("'sheet_names' must be None or a list of strings.") + + data = self._read_excel(file_path, sheet_names) + metadata = self._infer_metadata(data) + return data, metadata + + def write(self, synthetic_data, file_name, sheet_name_suffix=None, mode='w'): + """Write synthetic data to an Excel File. + + Args: + synthetic_data (dict): + A dictionary mapping table names to pandas DataFrames containing synthetic data. + file_name (str): + The name of the Excel file to write. + sheet_name_suffix (str, optional): + A suffix to add to each sheet name. + mode (str, optional): + The mode of writing to use. Defaults to 'w'. + 'w': Write sheets to a new file, clearing any existing file that may exist. + 'a': Append new sheets within the existing file. + Note: You cannot append data to existing sheets. + """ + temp_data = synthetic_data + suffix_added = False + + if mode == 'a': + temp_data = self._read_excel(file_name) + for table_name, table in synthetic_data.items(): + sheet_name = table_name + if sheet_name_suffix: + sheet_name = f'{table_name}{sheet_name_suffix}' + suffix_added = True + + if temp_data.get(sheet_name) is not None: + temp_data[sheet_name] = pd.concat( + [temp_data[sheet_name], synthetic_data[sheet_name]], + ignore_index=True + ) + + else: + temp_data[sheet_name] = table + + writer = pd.ExcelWriter(file_name) + for table_name, table_data in temp_data.items(): + if sheet_name_suffix and not suffix_added: + table_name += sheet_name_suffix + + table_data.to_excel( + writer, + sheet_name=table_name, + float_format=self.float_format, + index=False + ) + + writer.close() diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py index 87b3c80ea..934350deb 100644 --- a/tests/integration/io/local/test_local.py +++ b/tests/integration/io/local/test_local.py @@ -1,13 +1,13 @@ import pandas as pd -from sdv.io.local import CSVHandler +from sdv.io.local import CSVHandler, ExcelHandler from sdv.metadata import MultiTableMetadata class TestCSVHandler: - def test_integration_read_write(self, tmpdir): - """Test end to end the read and write methods of ``CSVHandler``.""" + def test_integration_write_and_read(self, tmpdir): + """Test end to end the write and read methods of ``CSVHandler``.""" # Prepare synthetic data synthetic_data = { 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), @@ -30,3 +30,70 @@ def test_integration_read_write(self, tmpdir): # Check if the dataframes match the original synthetic data pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) pd.testing.assert_frame_equal(data['table2'], synthetic_data['table2']) + + +class TestExcelHandler: + + def test_integration_write_and_read(self, tmpdir): + """Test end to end the write and read methods of ``ExcelHandler``.""" + # Prepare synthetic data + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + # Write synthetic data to xslx files + handler = ExcelHandler() + handler.write(synthetic_data, tmpdir / 'excel.xslx') + + # Read data from xslx file + data, metadata = handler.read(tmpdir / 'excel.xslx') + + # Check if data was read correctly + assert len(data) == 2 + assert 'table1' in data + assert 'table2' in data + assert 
isinstance(metadata, MultiTableMetadata) is True + + # Check if the dataframes match the original synthetic data + pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) + pd.testing.assert_frame_equal(data['table2'], synthetic_data['table2']) + + def test_integration_write_and_read_append_mode(self, tmpdir): + """Test end to end the write and read methods of ``ExcelHandler``.""" + # Prepare synthetic data + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + # Write synthetic data to xslx files + handler = ExcelHandler() + handler.write(synthetic_data, tmpdir / 'excel.xslx') + + # Read data from xslx file + data, metadata = handler.read(tmpdir / 'excel.xslx') + + # Write using append mode + handler.write(synthetic_data, tmpdir / 'excel.xslx', mode='a') + + # Read data from xslx file + data, metadata = handler.read(tmpdir / 'excel.xslx') + + # Check if data was read correctly + assert len(data) == 2 + assert 'table1' in data + assert 'table2' in data + assert isinstance(metadata, MultiTableMetadata) is True + + # Check if the dataframes match the original synthetic data + expected_table_one = pd.concat( + [synthetic_data['table1'], synthetic_data['table1']], + ignore_index=True + ) + expected_table_two = pd.concat( + [synthetic_data['table2'], synthetic_data['table2']], + ignore_index=True + ) + pd.testing.assert_frame_equal(data['table1'], expected_table_one) + pd.testing.assert_frame_equal(data['table2'], expected_table_two) diff --git a/tests/unit/io/local/test_local.py b/tests/unit/io/local/test_local.py index e69d18636..363754da5 100644 --- a/tests/unit/io/local/test_local.py +++ b/tests/unit/io/local/test_local.py @@ -1,12 +1,12 @@ """Unit tests for local file handlers.""" import os from pathlib import Path -from unittest.mock import patch +from unittest.mock import Mock, call, patch import pandas as pd import pytest -from sdv.io.local.local import CSVHandler +from sdv.io.local.local import CSVHandler, ExcelHandler from sdv.metadata.multi_table import MultiTableMetadata @@ -210,3 +210,218 @@ def test_write_file_exists_mode_is_w(self, tmpdir): 'col2': ['a', 'b', 'c'] }) pd.testing.assert_frame_equal(dataframe, expected_dataframe) + + +class TestExcelHandler: + + def test___init__(self): + """Test the init parameters with default values.""" + # Run + instance = ExcelHandler() + + # Assert + assert instance.decimal == '.' 
+ assert instance.float_format is None + + def test___init___custom(self): + """Test custom initialization of the class.""" + # Run + instance = ExcelHandler(decimal=',', float_format='%.2f') + + # Assert + assert instance.decimal == ',' + assert instance.float_format == '%.2f' + + @patch('sdv.io.local.local.pd') + def test_read(self, mock_pd): + """Test the read method of ExcelHandler class.""" + # Setup + file_path = 'test_file.xlsx' + mock_pd.ExcelFile.return_value = Mock(sheet_names=['Sheet1', 'Sheet2']) + mock_pd.read_excel.side_effect = [ + pd.DataFrame({'A': [1, 2], 'B': [3, 4]}), + pd.DataFrame({'C': [5, 6], 'D': [7, 8]}) + ] + + instance = ExcelHandler() + + # Run + data, metadata = instance.read(file_path) + + # Assert + sheet_1_call = call( + 'test_file.xlsx', + sheet_name='Sheet1', + parse_dates=False, + decimal='.', + index_col=None + ) + sheet_2_call = call( + 'test_file.xlsx', + sheet_name='Sheet2', + parse_dates=False, + decimal='.', + index_col=None + ) + pd.testing.assert_frame_equal( + data['Sheet1'], + pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + ) + pd.testing.assert_frame_equal( + data['Sheet2'], + pd.DataFrame({'C': [5, 6], 'D': [7, 8]}) + ) + assert isinstance(metadata, MultiTableMetadata) + assert mock_pd.read_excel.call_args_list == [sheet_1_call, sheet_2_call] + + @patch('sdv.io.local.local.pd') + def test_read_sheet_names(self, mock_pd): + """Test the read method when provided sheet names.""" + # Setup + file_path = 'test_file.xlsx' + sheet_names = ['Sheet1'] + mock_pd.ExcelFile.return_value = Mock(sheet_names=['Sheet1', 'Sheet2']) + mock_pd.read_excel.side_effect = [ + pd.DataFrame({'A': [1, 2], 'B': [3, 4]}), + pd.DataFrame({'C': [5, 6], 'D': [7, 8]}) + ] + + instance = ExcelHandler() + + # Run + data, metadata = instance.read(file_path, sheet_names) + + # Assert + sheet_1_call = call( + 'test_file.xlsx', + sheet_name='Sheet1', + parse_dates=False, + decimal='.', + index_col=None + ) + pd.testing.assert_frame_equal( + data['Sheet1'], + pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + ) + assert isinstance(metadata, MultiTableMetadata) + assert mock_pd.read_excel.call_args_list == [sheet_1_call] + assert list(data) == ['Sheet1'] + + def test_read_sheet_names_string(self): + """Test the read method when provided sheet names but they are string.""" + # Setup + file_path = 'test_file.xlsx' + sheet_names = 'Sheet1' + instance = ExcelHandler() + + # Run and Assert + error_msg = "'sheet_names' must be None or a list of strings." 
+ with pytest.raises(ValueError, match=error_msg): + instance.read(file_path, sheet_names) + + @patch('sdv.io.local.local.pd') + def test_write(self, mock_pd): + """Test the write functionality of the ExcelHandler.""" + # Setup + sheet_one = Mock() + sheet_two = Mock() + synthetic_data = {'Sheet1': sheet_one, 'Sheet2': sheet_two} + + file_name = 'output_file.xlsx' + sheet_name_suffix = '_synthetic' + instance = ExcelHandler() + + # Run + instance.write(synthetic_data, file_name, sheet_name_suffix) + + # Assert + sheet_one.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet1_synthetic', + float_format=None, + index=False + ) + sheet_two.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet2_synthetic', + float_format=None, + index=False + ) + mock_pd.ExcelWriter.return_value.close.assert_called_once_with() + + @patch('sdv.io.local.local.pd') + def test_write_mode_append(self, mock_pd): + """Test the write functionality of the ExcelHandler when mode is `a``.""" + # Setup + sheet_one = Mock() + sheet_two = Mock() + synth_sheet_one = Mock() + synth_sheet_two = Mock() + synthetic_data = {'Sheet1': synth_sheet_one, 'Sheet2': synth_sheet_two} + + file_name = 'output_file.xlsx' + sheet_name_suffix = '_synthetic' + instance = ExcelHandler() + instance._read_excel = Mock(return_value={'Sheet1': sheet_one, 'Sheet2': sheet_two}) + + # Run + instance.write(synthetic_data, file_name, sheet_name_suffix, mode='a') + + # Assert + sheet_one.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet1', + float_format=None, + index=False + ) + sheet_two.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet2', + float_format=None, + index=False + ) + synth_sheet_one.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet1_synthetic', + float_format=None, + index=False + ) + synth_sheet_two.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet2_synthetic', + float_format=None, + index=False + ) + mock_pd.ExcelWriter.return_value.close.assert_called_once_with() + + @patch('sdv.io.local.local.pd') + def test_write_mode_append_no_suffix(self, mock_pd): + """Test the write functionality of the ExcelHandler when mode is `a`` and no suffix.""" + # Setup + sheet_one = Mock() + sheet_two = Mock() + synth_sheet_one = Mock() + synthetic_data = {'Sheet1': synth_sheet_one} + file_name = 'output_file.xlsx' + instance = ExcelHandler() + instance._read_excel = Mock(return_value={'Sheet1': sheet_one, 'Sheet2': sheet_two}) + + # Run + instance.write(synthetic_data, file_name, mode='a') + + # Assert + mock_pd.concat.assert_called_once_with([sheet_one, synth_sheet_one], ignore_index=True) + mock_pd.concat.return_value.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet1', + float_format=None, + index=False + ) + + sheet_two.to_excel.assert_called_once_with( + mock_pd.ExcelWriter.return_value, + sheet_name='Sheet2', + float_format=None, + index=False + ) + mock_pd.ExcelWriter.return_value.close.assert_called_once_with() From 46ef1c755730e1ae1fa08feb0c670298e6b697fa Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Fri, 3 May 2024 13:33:32 -0500 Subject: [PATCH 15/32] Fixing test assertion for updated metadata (#1982) --- tests/integration/data_processing/test_data_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index 00f933f39..358415625 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -6,7 +6,7 @@ import pandas as pd import pytest from rdt.transformers import ( - AnonymizedFaker, BinaryEncoder, FloatFormatter, IDGenerator, RegexGenerator, UniformEncoder, + AnonymizedFaker, BinaryEncoder, FloatFormatter, IDGenerator, UniformEncoder, UnixTimestampEncoder) from sdv._utils import _get_datetime_format @@ -252,7 +252,7 @@ def test_prepare_for_fitting(self): 'mba_spec': UniformEncoder, 'employability_perc': FloatFormatter, 'placed': UniformEncoder, - 'student_id': RegexGenerator, + 'student_id': AnonymizedFaker, 'experience_years': FloatFormatter, 'duration': UniformEncoder, 'salary': FloatFormatter, From 248bc1a4c96c11edbdac97105e082ebda820ae39 Mon Sep 17 00:00:00 2001 From: SDV Team <98988753+sdv-team@users.noreply.github.com> Date: Mon, 6 May 2024 09:14:40 -0400 Subject: [PATCH 16/32] Automated Latest Dependency Updates (#1986) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- latest_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/latest_requirements.txt b/latest_requirements.txt index 9da60a07f..3c64a8cdd 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -8,4 +8,4 @@ pandas==2.2.2 platformdirs==4.2.1 rdt==1.12.0 sdmetrics==0.14.0 -tqdm==4.66.2 +tqdm==4.66.4 From 364b40a74f1c6a1144b29474c402a0a922cfa286 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Tue, 7 May 2024 17:07:44 +0100 Subject: [PATCH 17/32] Improve error message when trying to sample before fitting (single table) (#1992) --- sdv/single_table/base.py | 9 ++++++++- tests/integration/single_table/test_base.py | 18 +++++++++++++++++- tests/unit/single_table/test_base.py | 16 +++++++++++++++- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 60656a4a2..0b6fa521b 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -23,7 +23,8 @@ _groupby_list, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor -from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.errors import ( + ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError) from sdv.logging.utils import get_sdv_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path @@ -871,6 +872,12 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file pandas.DataFrame: Sampled data. """ + if not self._fitted: + raise SamplingError( + 'This synthesizer has not been fitted. Please fit your synthesizer first before' + ' sampling synthetic data.' 
+ ) + sample_timestamp = datetime.datetime.now() has_constraints = bool(self._data_processor._constraints) has_batches = batch_size is not None and batch_size != num_rows diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 8c7ea2601..d0c105987 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -13,7 +13,7 @@ from sdv import version from sdv.datasets.demo import download_demo -from sdv.errors import SynthesizerInputError, VersionError +from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.metadata import SingleTableMetadata from sdv.sampling import Condition from sdv.single_table import ( @@ -855,3 +855,19 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): ' Total number of columns: 3\n' ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' ) + + +@pytest.mark.parametrize('synthesizer', SYNTHESIZERS) +def test_sample_not_fitted(synthesizer): + """Test that a synthesizer raises an error when trying to sample without fitting.""" + # Setup + metadata = SingleTableMetadata() + synthesizer = synthesizer.__class__(metadata) + expected_message = re.escape( + 'This synthesizer has not been fitted. Please fit your synthesizer first before' + ' sampling synthetic data.' + ) + + # Run and Assert + with pytest.raises(SamplingError, match=expected_message): + synthesizer.sample(10) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 197141e69..932fa52da 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -12,7 +12,7 @@ from sdv import version from sdv.constraints.errors import AggregateConstraintsError -from sdv.errors import ConstraintsNotMetError, SynthesizerInputError, VersionError +from sdv.errors import ConstraintsNotMetError, SamplingError, SynthesizerInputError, VersionError from sdv.metadata.single_table import SingleTableMetadata from sdv.sampling.tabular import Condition from sdv.single_table import ( @@ -1399,6 +1399,20 @@ def test__sample_with_progress_bar_removing_temp_file( mock_os.remove.assert_called_once_with('.sample.csv.temp') mock_os.path.exists.assert_called_once_with('.sample.csv.temp') + def test_sample_not_fitted(self): + """Test that ``sample`` raises an error when the synthesizer is not fitted.""" + # Setup + instance = Mock() + instance._fitted = False + expected_message = re.escape( + 'This synthesizer has not been fitted. Please fit your synthesizer first before' + ' sampling synthetic data.' 
+ ) + + # Run and Assert + with pytest.raises(SamplingError, match=expected_message): + BaseSingleTableSynthesizer.sample(instance, 10) + @patch('sdv.single_table.base.datetime') def test_sample(self, mock_datetime, caplog): """Test that we use ``_sample_with_progress_bar`` in this method.""" From 87ceead40567a859d4a040c5da9aa863947cf1b3 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Tue, 7 May 2024 18:42:22 +0200 Subject: [PATCH 18/32] Improve error message when trying to sample before fitting (MultiTableSynthesizers) (#1993) --- sdv/multi_table/base.py | 9 ++++++++- tests/integration/multi_table/test_hma.py | 17 ++++++++++++++++- tests/unit/multi_table/test_base.py | 19 ++++++++++++++++++- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 114f40739..b9e705947 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -15,7 +15,8 @@ from sdv._utils import ( _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) -from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.errors import ( + ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError) from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -467,6 +468,12 @@ def sample(self, scale=1.0): If ``scale`` is lower than ``1.0`` create fewer rows by the factor of ``scale`` than the original tables. Defaults to ``1.0``. """ + if not self._fitted: + raise SamplingError( + 'This synthesizer has not been fitted. Please fit your synthesizer first before ' + 'sampling synthetic data.' + ) + if not type(scale) in (float, int) or not scale > 0: raise SynthesizerInputError( f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.") diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index f7c499d28..33c728ec9 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -16,7 +16,7 @@ from sdv import version from sdv.datasets.demo import download_demo from sdv.datasets.local import load_csvs -from sdv.errors import SynthesizerInputError, VersionError +from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot from sdv.metadata.multi_table import MultiTableMetadata from sdv.multi_table import HMASynthesizer @@ -1669,6 +1669,21 @@ def test_hma_relationship_validity(): assert report.get_details('Relationship Validity')['Score'].mean() == 1.0 +def test_hma_not_fit_raises_sampling_error(): + """Test that ``HMA`` will raise a ``SamplingError`` if it wasn't fit.""" + # Setup + data, metadata = download_demo('multi_table', 'Dunur_v1') + synthesizer = HMASynthesizer(metadata) + + # Run and Assert + error_msg = ( + 'This synthesizer has not been fitted. Please fit your synthesizer first before ' + 'sampling synthetic data.' 
+ ) + with pytest.raises(SamplingError, match=error_msg): + synthesizer.sample(1) + + @patch('sdv.multi_table.base.generate_synthesizer_id') @patch('sdv.multi_table.base.datetime') def test_synthesizer_logger(mock_datetime, mock_generate_id): diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index fe7dc11ab..9f5548330 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -11,7 +11,8 @@ from sdv import version from sdv.errors import ( - ConstraintsNotMetError, InvalidDataError, NotFittedError, SynthesizerInputError, VersionError) + ConstraintsNotMetError, InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError, + VersionError) from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer @@ -1016,6 +1017,7 @@ def test_sample_validate_input(self): # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) + instance._fitted = True instance._sample = Mock() scales = ['Test', True, -1.2, np.nan] @@ -1038,6 +1040,20 @@ def test_sample_validate_input(self): with pytest.raises(SynthesizerInputError, match=msg): instance.sample(scale=scale) + def test_sample_raises_sampling_error(self): + """Test that ``sample`` will raise ``SamplingError`` when not fitted.""" + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + + # Run and Assert + error_msg = ( + 'This synthesizer has not been fitted. Please fit your synthesizer first before ' + 'sampling synthetic data.' + ) + with pytest.raises(SamplingError, match=error_msg): + instance.sample(1) + @patch('sdv.multi_table.base.datetime') def test_sample(self, mock_datetime, caplog): """Test that ``sample`` calls the ``_sample`` with the given arguments.""" @@ -1045,6 +1061,7 @@ def test_sample(self, mock_datetime, caplog): mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) + instance._fitted = True data = { 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) From 270d5e597015795efb7d1cc12911390ca9248214 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Tue, 7 May 2024 20:43:41 +0200 Subject: [PATCH 19/32] Split out metadata creation from data (#1988) --- sdv/io/local/local.py | 27 ++++------ setup.cfg | 4 +- tests/integration/io/local/test_local.py | 17 +++++-- tests/unit/io/local/test_local.py | 64 +++++++++++++++++++++--- 4 files changed, 81 insertions(+), 31 deletions(-) diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py index 1ceaba195..830cf4512 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -16,7 +16,7 @@ def __init__(self, decimal='.', float_format=None): self.decimal = decimal self.float_format = float_format - def _infer_metadata(self, data): + def create_metadata(self, data): """Detect the metadata for all tables in a dictionary of dataframes. Args: @@ -38,9 +38,8 @@ def read(self): This method must be implemented by subclasses. Returns: - tuple: - A tuple containing the read data as a dictionary and metadata. The dictionary maps - table names to pandas DataFrames. The metadata is an object describing the data. + dict: + The dictionary maps table names to pandas DataFrames. 
""" raise NotImplementedError() @@ -101,17 +100,14 @@ def read(self, folder_name, file_names=None): in the folder are read. Returns: - tuple: - A tuple containing the data as a dictionary and metadata. The dictionary maps - table names to pandas DataFrames. The metadata is an object describing the data. + dict: + The dictionary maps table names to pandas DataFrames. Raises: FileNotFoundError: If the specified files do not exist in the folder. """ data = {} - metadata = MultiTableMetadata() - folder_path = Path(folder_name) if file_names is None: @@ -156,8 +152,7 @@ def read(self, folder_name, file_names=None): **kwargs ) - metadata = self._infer_metadata(data) - return data, metadata + return data def write(self, synthetic_data, folder_name, file_name_suffix=None, mode='x'): """Write synthetic data to CSV files. @@ -225,17 +220,13 @@ def read(self, file_path, sheet_names=None): The names of sheets to read. If None, all sheets are read. Returns: - tuple: - A tuple containing the data as a dictionary and metadata. The dictionary maps - table names to pandas DataFrames. The metadata is an object describing the data. + dict: + The dictionary maps table names to pandas DataFrames. """ - metadata = MultiTableMetadata() if sheet_names is not None and not isinstance(sheet_names, list): raise ValueError("'sheet_names' must be None or a list of strings.") - data = self._read_excel(file_path, sheet_names) - metadata = self._infer_metadata(data) - return data, metadata + return self._read_excel(file_path, sheet_names) def write(self, synthetic_data, file_name, sheet_name_suffix=None, mode='w'): """Write synthetic data to an Excel File. diff --git a/setup.cfg b/setup.cfg index 1659a6e09..c2b017843 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,7 +14,9 @@ extend-ignore = # TokenError: unterminated string literal E902, # Mutable default arg of type List - M511 + M511, + # Logging and IO shadowing python's builtins + A005 [aliases] test = pytest diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py index 934350deb..c133e1373 100644 --- a/tests/integration/io/local/test_local.py +++ b/tests/integration/io/local/test_local.py @@ -19,7 +19,10 @@ def test_integration_write_and_read(self, tmpdir): handler.write(synthetic_data, tmpdir) # Read data from CSV files - data, metadata = handler.read(tmpdir) + data = handler.read(tmpdir) + + # Detect metadata + metadata = handler.create_metadata(data) # Check if data was read correctly assert len(data) == 2 @@ -47,7 +50,10 @@ def test_integration_write_and_read(self, tmpdir): handler.write(synthetic_data, tmpdir / 'excel.xslx') # Read data from xslx file - data, metadata = handler.read(tmpdir / 'excel.xslx') + data = handler.read(tmpdir / 'excel.xslx') + + # Detect metadata + metadata = handler.create_metadata(data) # Check if data was read correctly assert len(data) == 2 @@ -72,13 +78,16 @@ def test_integration_write_and_read_append_mode(self, tmpdir): handler.write(synthetic_data, tmpdir / 'excel.xslx') # Read data from xslx file - data, metadata = handler.read(tmpdir / 'excel.xslx') + data = handler.read(tmpdir / 'excel.xslx') # Write using append mode handler.write(synthetic_data, tmpdir / 'excel.xslx', mode='a') # Read data from xslx file - data, metadata = handler.read(tmpdir / 'excel.xslx') + data = handler.read(tmpdir / 'excel.xslx') + + # Detect metadata + metadata = handler.create_metadata(data) # Check if data was read correctly assert len(data) == 2 diff --git a/tests/unit/io/local/test_local.py 
b/tests/unit/io/local/test_local.py index 363754da5..e6cbcc119 100644 --- a/tests/unit/io/local/test_local.py +++ b/tests/unit/io/local/test_local.py @@ -6,10 +6,61 @@ import pandas as pd import pytest -from sdv.io.local.local import CSVHandler, ExcelHandler +from sdv.io.local.local import BaseLocalHandler, CSVHandler, ExcelHandler from sdv.metadata.multi_table import MultiTableMetadata +class TestBaseLocalHandler: + + def test___init__(self): + """Test the default initialization of the class.""" + # Run + instance = BaseLocalHandler() + + # Assert + assert instance.decimal == '.' + assert instance.float_format is None + + def test_create_metadata(self): + """Test that ``create_metadata`` will infer the metadata.""" + # Setup + data = { + 'hotel': pd.DataFrame({ + 'hotel_id': [1, 2, 3, 4, 5], + 'stars': [3, 4, 5, 3, 4] + }), + 'guests': pd.DataFrame({ + 'guest_id': [1, 2, 3, 4, 5], + 'hotel_id': [1, 1, 3, 2, 3] + }) + } + instance = BaseLocalHandler() + + # Run + metadata = instance.create_metadata(data) + + # Assert + assert isinstance(metadata, MultiTableMetadata) + assert metadata.to_dict() == { + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + 'relationships': [], + 'tables': { + 'guests': { + 'columns': { + 'guest_id': {'sdtype': 'numerical'}, + 'hotel_id': {'sdtype': 'numerical'} + } + }, + 'hotel': { + 'columns': { + 'hotel_id': {'sdtype': 'numerical'}, + 'stars': {'sdtype': 'numerical'} + } + } + } + } + + class TestCSVHandler: def test___init__(self): @@ -69,13 +120,12 @@ def test_read(self, mock_read_csv, mock_glob): handler = CSVHandler() # Run - data, metadata = handler.read('/path/to/data') + data = handler.read('/path/to/data') # Assert assert len(data) == 2 assert 'parent' in data assert 'child' in data - assert isinstance(metadata, MultiTableMetadata) assert mock_read_csv.call_count == 2 pd.testing.assert_frame_equal( data['parent'], @@ -102,7 +152,7 @@ def test_read_files(self, tmpdir): handler = CSVHandler() # Run - data, metadata = handler.read(tmpdir, file_names=['parent.csv']) + data = handler.read(tmpdir, file_names=['parent.csv']) # Assert assert 'parent' in data @@ -246,7 +296,7 @@ def test_read(self, mock_pd): instance = ExcelHandler() # Run - data, metadata = instance.read(file_path) + data = instance.read(file_path) # Assert sheet_1_call = call( @@ -271,7 +321,6 @@ def test_read(self, mock_pd): data['Sheet2'], pd.DataFrame({'C': [5, 6], 'D': [7, 8]}) ) - assert isinstance(metadata, MultiTableMetadata) assert mock_pd.read_excel.call_args_list == [sheet_1_call, sheet_2_call] @patch('sdv.io.local.local.pd') @@ -289,7 +338,7 @@ def test_read_sheet_names(self, mock_pd): instance = ExcelHandler() # Run - data, metadata = instance.read(file_path, sheet_names) + data = instance.read(file_path, sheet_names) # Assert sheet_1_call = call( @@ -303,7 +352,6 @@ def test_read_sheet_names(self, mock_pd): data['Sheet1'], pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) ) - assert isinstance(metadata, MultiTableMetadata) assert mock_pd.read_excel.call_args_list == [sheet_1_call] assert list(data) == ['Sheet1'] From aa5188db39f2eeec3ba80fc17e63f5f67dccc1df Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Wed, 8 May 2024 16:04:25 +0200 Subject: [PATCH 20/32] Update the logger settings (#1981) --- sdv/logging/__init__.py | 3 +- sdv/logging/logger.py | 62 ++++++++++++++++ sdv/logging/sdv_logger_config.yml | 2 +- sdv/logging/utils.py | 69 +++--------------- sdv/single_table/base.py | 2 +- tests/integration/multi_table/test_hma.py | 74 
------------------- tests/integration/single_table/test_base.py | 78 --------------------- tests/unit/logging/test_logger.py | 34 +++++++++ tests/unit/logging/test_utils.py | 31 +------- 9 files changed, 112 insertions(+), 243 deletions(-) create mode 100644 sdv/logging/logger.py create mode 100644 tests/unit/logging/test_logger.py diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index 436a1a442..c15348231 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -1,6 +1,7 @@ """Module for configuring loggers within the SDV library.""" -from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config +from sdv.logging.logger import get_sdv_logger +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger_config __all__ = ( 'disable_single_table_logger', diff --git a/sdv/logging/logger.py b/sdv/logging/logger.py new file mode 100644 index 000000000..52a178638 --- /dev/null +++ b/sdv/logging/logger.py @@ -0,0 +1,62 @@ +"""SDV Logger.""" + +import logging +from functools import lru_cache + +from sdv.logging.utils import get_sdv_logger_config + + +@lru_cache() +def get_sdv_logger(logger_name): + """Get a logger instance with the specified name and configuration. + + This function retrieves or creates a logger instance with the specified name + and applies configuration settings based on the logger's name and the logging + configuration. + + Args: + logger_name (str): + The name of the logger to retrieve or create. + + Returns: + logging.Logger: + A logger instance configured according to the logging configuration + and the specific settings for the given logger name. + """ + logger_conf = get_sdv_logger_config() + logger = logging.getLogger(logger_name) + if logger_conf.get('log_registry') is None: + # Return a logger without any extra settings and avoid writing into files or other streams + return logger + + if logger_conf.get('log_registry') == 'local': + for handler in logger.handlers: + # Remove handlers that could exist previously + logger.removeHandler(handler) + + if logger_name in logger_conf.get('loggers'): + formatter = None + config = logger_conf.get('loggers').get(logger_name) + log_level = getattr(logging, config.get('level', 'INFO')) + if config.get('format'): + formatter = logging.Formatter(config.get('format')) + + logger.setLevel(log_level) + logger.propagate = config.get('propagate', False) + handler = config.get('handlers') + handlers = handler.get('class') + handlers = [handlers] if isinstance(handlers, str) else handlers + for handler_class in handlers: + if handler_class == 'logging.FileHandler': + logfile = handler.get('filename') + file_handler = logging.FileHandler(logfile) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'): + ch = logging.StreamHandler() + ch.setLevel(log_level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logger diff --git a/sdv/logging/sdv_logger_config.yml b/sdv/logging/sdv_logger_config.yml index 64104495f..4b01b0c65 100644 --- a/sdv/logging/sdv_logger_config.yml +++ b/sdv/logging/sdv_logger_config.yml @@ -1,4 +1,4 @@ -log_registry: 'local' +log_registry: null version: 1 loggers: SingleTableSynthesizer: diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 2f6a13be4..471870649 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -2,7 +2,7 @@ import contextlib import logging -from functools import lru_cache 
+import shutil from pathlib import Path import platformdirs @@ -11,13 +11,18 @@ def get_sdv_logger_config(): """Return a dictionary with the logging configuration.""" - logging_path = Path(__file__).parent - with open(logging_path / 'sdv_logger_config.yml', 'r') as f: - logger_conf = yaml.safe_load(f) - - # Logfile to be in this same directory store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) store_path.mkdir(parents=True, exist_ok=True) + config_path = Path(__file__).parent / 'sdv_logger_config.yml' + + if (store_path / 'sdv_logger_config.yml').exists(): + config_path = store_path / 'sdv_logger_config.yml' + else: + shutil.copyfile(config_path, store_path / 'sdv_logger_config.yml') + + with open(config_path, 'r') as f: + logger_conf = yaml.safe_load(f) + for logger in logger_conf.get('loggers', {}).values(): handler = logger.get('handlers', {}) if handler.get('filename') == 'sdv_logs.log': @@ -44,55 +49,3 @@ def disable_single_table_logger(): finally: for handler in handlers: single_table_logger.addHandler(handler) - - -@lru_cache() -def get_sdv_logger(logger_name): - """Get a logger instance with the specified name and configuration. - - This function retrieves or creates a logger instance with the specified name - and applies configuration settings based on the logger's name and the logging - configuration. - - Args: - logger_name (str): - The name of the logger to retrieve or create. - - Returns: - logging.Logger: - A logger instance configured according to the logging configuration - and the specific settings for the given logger name. - """ - logger_conf = get_sdv_logger_config() - if logger_conf.get('log_registry') is None: - # Return a logger without any extra settings and avoid writing into files or other streams - return logging.getLogger(logger_name) - - if logger_conf.get('log_registry') == 'local': - logger = logging.getLogger(logger_name) - if logger_name in logger_conf.get('loggers'): - formatter = None - config = logger_conf.get('loggers').get(logger_name) - log_level = getattr(logging, config.get('level', 'INFO')) - if config.get('format'): - formatter = logging.Formatter(config.get('format')) - - logger.setLevel(log_level) - logger.propagate = config.get('propagate', False) - handler = config.get('handlers') - handlers = handler.get('class') - handlers = [handlers] if isinstance(handlers, str) else handlers - for handler_class in handlers: - if handler_class == 'logging.FileHandler': - logfile = handler.get('filename') - file_handler = logging.FileHandler(logfile) - file_handler.setLevel(log_level) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'): - ch = logging.StreamHandler() - ch.setLevel(log_level) - ch.setFormatter(formatter) - logger.addHandler(ch) - - return logger diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 0b6fa521b..df3cc448e 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -25,7 +25,7 @@ from sdv.data_processing.data_processor import DataProcessor from sdv.errors import ( ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError) -from sdv.logging.utils import get_sdv_logger +from sdv.logging import get_sdv_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path LOGGER = logging.getLogger(__name__) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 33c728ec9..7a7d21ddb 100644 --- 
a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2,12 +2,9 @@ import importlib.metadata import re import warnings -from pathlib import Path -from unittest.mock import patch import numpy as np import pandas as pd -import platformdirs import pytest from faker import Faker from rdt.transformers import FloatFormatter @@ -1682,74 +1679,3 @@ def test_hma_not_fit_raises_sampling_error(): ) with pytest.raises(SamplingError, match=error_msg): synthesizer.sample(1) - - -@patch('sdv.multi_table.base.generate_synthesizer_id') -@patch('sdv.multi_table.base.datetime') -def test_synthesizer_logger(mock_datetime, mock_generate_id): - """Test that the synthesizer logger logs the expected messages.""" - # Setup - store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) - file_name = 'sdv_logs.log' - - synth_id = 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - mock_generate_id.return_value = synth_id - mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' - data, metadata = download_demo('multi_table', 'fake_hotels') - - # Run - instance = HMASynthesizer(metadata) - - # Assert - with open(store_path / file_name) as f: - instance_lines = f.readlines()[-4:] - - assert ''.join(instance_lines) == ( - 'Instance:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: HMASynthesizer\n' - ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - # Run - instance.fit(data) - - # Assert - with open(store_path / file_name) as f: - fit_lines = f.readlines()[-17:] - - assert ''.join(fit_lines) == ( - 'Fit:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: HMASynthesizer\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 2\n' - ' Total number of rows: 668\n' - ' Total number of columns: 15\n' - ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - '\nFit processed data:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: HMASynthesizer\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 2\n' - ' Total number of rows: 668\n' - ' Total number of columns: 11\n' - ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - # Run - instance.sample(1) - with open(store_path / file_name) as f: - sample_lines = f.readlines()[-8:] - - # Assert - assert ''.join(sample_lines) == ( - 'Sample:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: HMASynthesizer\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 2\n' - ' Total number of rows: 668\n' - ' Total number of columns: 15\n' - ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index d0c105987..19fe6207a 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -2,12 +2,10 @@ import importlib.metadata import re import warnings -from pathlib import Path from unittest.mock import patch import numpy as np import pandas as pd -import platformdirs import pytest from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder @@ -781,82 +779,6 @@ def test_fit_raises_version_error(): instance.fit(data) -@patch('sdv.single_table.base.generate_synthesizer_id') -@patch('sdv.single_table.base.datetime') -def test_synthesizer_logger(mock_datetime, mock_generate_id): - """Test that the 
synthesizer logger logs the expected messages.""" - # Setup - store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) - file_name = 'sdv_logs.log' - - synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - mock_generate_id.return_value = synth_id - mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' - data = pd.DataFrame({ - 'col 1': [1, 2, 3], - 'col 2': [4, 5, 6], - 'col 3': ['a', 'b', 'c'], - }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - - # Run - instance = GaussianCopulaSynthesizer(metadata) - - # Assert - with open(store_path / file_name) as f: - instance_lines = f.readlines()[-4:] - - assert ''.join(instance_lines) == ( - 'Instance:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - # Run - instance.fit(data) - - # Assert - with open(store_path / file_name) as f: - fit_lines = f.readlines()[-17:] - - assert ''.join(fit_lines) == ( - 'Fit:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 3\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - '\nFit processed data:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 3\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - # Run - instance.sample(100) - with open(store_path / file_name) as f: - sample_lines = f.readlines()[-8:] - - assert ''.join(sample_lines) == ( - 'Sample:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 100\n' - ' Total number of columns: 3\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - @pytest.mark.parametrize('synthesizer', SYNTHESIZERS) def test_sample_not_fitted(synthesizer): """Test that a synthesizer raises an error when trying to sample without fitting.""" diff --git a/tests/unit/logging/test_logger.py b/tests/unit/logging/test_logger.py new file mode 100644 index 000000000..4770d39cd --- /dev/null +++ b/tests/unit/logging/test_logger.py @@ -0,0 +1,34 @@ +"""Test ``SDV`` logger.""" +import logging +from unittest.mock import Mock, patch + +from sdv.logging.logger import get_sdv_logger + + +@patch('sdv.logging.logger.logging.StreamHandler') +@patch('sdv.logging.logger.logging.getLogger') +@patch('sdv.logging.logger.get_sdv_logger_config') +def test_get_sdv_logger(mock_get_sdv_logger_config, mock_getlogger, mock_streamhandler): + # Setup + mock_logger_conf = { + 'log_registry': 'local', + 'loggers': { + 'test_logger': { + 'level': 'DEBUG', + 'handlers': { + 'class': 'logging.StreamHandler' + } + } + } + } + mock_get_sdv_logger_config.return_value = mock_logger_conf + mock_logger_instance = Mock() + mock_logger_instance.handlers = [] + mock_getlogger.return_value = mock_logger_instance + + # Run + get_sdv_logger('test_logger') + + # Assert + mock_logger_instance.setLevel.assert_called_once_with(logging.DEBUG) + 
mock_logger_instance.addHandler.assert_called_once() diff --git a/tests/unit/logging/test_utils.py b/tests/unit/logging/test_utils.py index 316ae9083..b585cb880 100644 --- a/tests/unit/logging/test_utils.py +++ b/tests/unit/logging/test_utils.py @@ -1,8 +1,7 @@ """Test ``SDV`` logging utilities.""" -import logging from unittest.mock import Mock, mock_open, patch -from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger_config def test_get_sdv_logger_config(): @@ -55,31 +54,3 @@ def test_disable_single_table_logger(mock_getlogger): # Assert assert len(mock_logger.handlers) == 1 - - -@patch('sdv.logging.utils.logging.StreamHandler') -@patch('sdv.logging.utils.logging.getLogger') -@patch('sdv.logging.utils.get_sdv_logger_config') -def test_get_sdv_logger(mock_get_sdv_logger_config, mock_getlogger, mock_streamhandler): - # Setup - mock_logger_conf = { - 'log_registry': 'local', - 'loggers': { - 'test_logger': { - 'level': 'DEBUG', - 'handlers': { - 'class': 'logging.StreamHandler' - } - } - } - } - mock_get_sdv_logger_config.return_value = mock_logger_conf - mock_logger_instance = Mock() - mock_getlogger.return_value = mock_logger_instance - - # Run - get_sdv_logger('test_logger') - - # Assert - mock_logger_instance.setLevel.assert_called_once_with(logging.DEBUG) - mock_logger_instance.addHandler.assert_called_once() From e1f787e77b5294aba797a7e2c6db9905fe58cf80 Mon Sep 17 00:00:00 2001 From: John La Date: Wed, 8 May 2024 11:00:26 -0500 Subject: [PATCH 21/32] Convert integer column names to strings to allow for default column names (#1976) --- sdv/_utils.py | 2 + sdv/metadata/multi_table.py | 7 +- sdv/metadata/single_table.py | 6 ++ sdv/multi_table/base.py | 24 +++++++ sdv/single_table/base.py | 26 ++++++- sdv/single_table/utils.py | 2 +- tests/integration/multi_table/test_hma.py | 50 +++++++++++++ tests/integration/single_table/test_base.py | 38 ++++++++++ tests/unit/metadata/test_multi_table.py | 79 +++++++++++++++++++++ tests/unit/metadata/test_single_table.py | 28 ++++++++ tests/unit/multi_table/test_base.py | 71 ++++++++++++++++++ tests/unit/single_table/test_base.py | 33 ++++++++- 12 files changed, 360 insertions(+), 6 deletions(-) diff --git a/sdv/_utils.py b/sdv/_utils.py index 577600b8d..5db3955ec 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -214,6 +214,8 @@ def _validate_foreign_keys_not_null(metadata, data): invalid_tables = defaultdict(list) for table_name, table_data in data.items(): for foreign_key in metadata._get_all_foreign_keys(table_name): + if foreign_key not in table_data and int(foreign_key) in table_data: + foreign_key = int(foreign_key) if table_data[foreign_key].isna().any(): invalid_tables[table_name].append(foreign_key) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index c7aa10d30..8aebf65c5 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1028,7 +1028,12 @@ def _set_metadata_dict(self, metadata): self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict) for relationship in metadata.get('relationships', []): - self.relationships.append(relationship) + type_safe_relationships = { + key: str(value) + if not isinstance(value, str) + else value for key, value in relationship.items() + } + self.relationships.append(type_safe_relationships) @classmethod def load_from_dict(cls, metadata_dict): diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 
4f8b1db94..806bc3561 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1238,6 +1238,12 @@ def load_from_dict(cls, metadata_dict): for key in instance._KEYS: value = deepcopy(metadata_dict.get(key)) if value: + if key == 'columns': + value = { + str(key) + if not isinstance(key, str) + else key: col for key, col in value.items() + } setattr(instance, f'{key}', value) return instance diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index b9e705947..779366cb1 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -100,6 +100,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self.extended_columns = defaultdict(dict) self._table_synthesizers = {} self._table_parameters = defaultdict(dict) + self._original_table_columns = {} if synthesizer_kwargs is not None: warn_message = ( 'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not ' @@ -326,6 +327,19 @@ def update_transformers(self, table_name, column_name_to_transformer): self._validate_table_name(table_name) self._table_synthesizers[table_name].update_transformers(column_name_to_transformer) + def _store_and_convert_original_cols(self, data): + list_of_changed_tables = [] + for table, dataframe in data.items(): + self._original_table_columns[table] = dataframe.columns + for column in dataframe.columns: + if isinstance(column, int): + dataframe.columns = dataframe.columns.astype(str) + list_of_changed_tables.append(table) + break + + data[table] = dataframe + return list_of_changed_tables + def preprocess(self, data): """Transform the raw data to numerical space. @@ -337,6 +351,8 @@ def preprocess(self, data): dict: A dictionary with the preprocessed data. """ + list_of_changed_tables = self._store_and_convert_original_cols(data) + self.validate(data) if self._fitted: warnings.warn( @@ -351,6 +367,9 @@ def preprocess(self, data): self._assign_table_transformers(synthesizer, table_name, table_data) processed_data[table_name] = synthesizer._preprocess(table_data) + for table in list_of_changed_tables: + data[table].columns = self._original_table_columns[table] + return processed_data def _model_tables(self, augmented_data): @@ -487,6 +506,11 @@ def sample(self, scale=1.0): total_rows += len(table) total_columns += len(table.columns) + table_columns = getattr(self, '_original_table_columns', {}) + for table in sampled_data: + if table in table_columns: + sampled_data[table].columns = table_columns[table] + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index df3cc448e..3b3b93122 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -103,6 +103,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, enforce_min_max_values=self.enforce_min_max_values, locales=self.locales, ) + self._original_columns = pd.Index([]) self._fitted = False self._random_state_set = False self._update_default_transformers() @@ -367,6 +368,16 @@ def _preprocess(self, data): self._data_processor.fit(data) return self._data_processor.transform(data) + def _store_and_convert_original_cols(self, data): + # Transform in place to avoid possible large copy of data + for column in data.columns: + if isinstance(column, int): + self._original_columns = data.columns + data.columns = data.columns.astype(str) + return True + + return False + def preprocess(self, data): """Transform the raw data to numerical space. 
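The hunk above stores the original column index and stringifies integer column names in place; the hunks that follow restore the original names after preprocessing and after sampling. A minimal end-to-end sketch of the behavior this patch enables, modeled on the integration test added further down (the synthesizer choice and column counts here are illustrative, not part of the diff):

import numpy as np
import pandas as pd

from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

# A DataFrame with integer column names, e.g. the default names produced by
# pd.read_csv(..., header=None).
data = pd.DataFrame({i: np.random.randint(0, 100, size=50) for i in range(3)})

# Integer keys in the metadata dict are cast to strings by load_from_dict
# (see the single_table.py metadata change in this patch).
metadata = SingleTableMetadata.load_from_dict({
    'columns': {i: {'sdtype': 'numerical'} for i in range(3)}
})

synthesizer = GaussianCopulaSynthesizer(metadata)
synthesizer.fit(data)              # column names are stringified internally
synthetic = synthesizer.sample(10)

# Sampled output gets the original integer column names back.
assert synthetic.columns.tolist() == data.columns.tolist()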
@@ -384,7 +395,14 @@ def preprocess(self, data): "please refit the model using 'fit' or 'fit_processed_data'." ) - return self._preprocess(data) + is_converted = self._store_and_convert_original_cols(data) + + preprocess_data = self._preprocess(data) + + if is_converted: + data.columns = self._original_columns + + return preprocess_data def _fit(self, processed_data): """Fit the model to the table. @@ -455,7 +473,7 @@ def fit(self, data): self._fitted = False self._data_processor.reset_sampling() self._random_state_set = False - processed_data = self._preprocess(data) + processed_data = self.preprocess(data) self.fit_processed_data(processed_data) def save(self, filepath): @@ -891,6 +909,10 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) + original_columns = getattr(self, '_original_columns', pd.Index([])) + if not original_columns.empty: + sampled_data.columns = self._original_columns + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 11043ca2f..d9bdb72d7 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -303,7 +303,7 @@ def unflatten_dict(flat): else: subdict = unflattened.setdefault(key, {}) - if subkey.isdigit(): + if subkey.isdigit() and key != 'univariates': subkey = int(subkey) inner = subdict.setdefault(subkey, {}) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 7a7d21ddb..9a4345a33 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1679,3 +1679,53 @@ def test_hma_not_fit_raises_sampling_error(): ) with pytest.raises(SamplingError, match=error_msg): synthesizer.sample(1) + + +def test_fit_and_sample_numerical_col_names(): + """Test fitting/sampling when column names are integers""" + # Setup + num_rows = 50 + num_cols = 10 + num_tables = 2 + data = {} + for i in range(num_tables): + values = {j: np.random.randint(0, 100, size=num_rows) for j in range(num_cols)} + data[str(i)] = pd.DataFrame(values) + + primary_key = pd.DataFrame({1: range(num_rows)}) + primary_key_2 = pd.DataFrame({2: range(num_rows)}) + data['0'][1] = primary_key + data['1'][1] = primary_key + data['1'][2] = primary_key_2 + metadata = MultiTableMetadata() + metadata_dict = {'tables': {}} + for table_idx in range(num_tables): + metadata_dict['tables'][str(table_idx)] = {'columns': {}} + for i in range(num_cols): + metadata_dict['tables'][str(table_idx)]['columns'][i] = {'sdtype': 'numerical'} + metadata_dict['tables']['0']['columns'][1] = {'sdtype': 'id'} + metadata_dict['tables']['1']['columns'][2] = {'sdtype': 'id'} + metadata_dict['relationships'] = [ + { + 'parent_table_name': '0', + 'parent_primary_key': 1, + 'child_table_name': '1', + 'child_foreign_key': 2 + } + ] + metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata.set_primary_key('0', '1') + + # Run + synth = HMASynthesizer(metadata) + synth.fit(data) + first_sample = synth.sample() + second_sample = synth.sample() + assert first_sample['0'].columns.tolist() == data['0'].columns.tolist() + assert first_sample['1'].columns.tolist() == data['1'].columns.tolist() + assert second_sample['0'].columns.tolist() == data['0'].columns.tolist() + assert second_sample['1'].columns.tolist() == data['1'].columns.tolist() + + # Assert + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(first_sample['0'], second_sample['0']) diff --git 
a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 19fe6207a..1abb62f05 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -779,6 +779,44 @@ def test_fit_raises_version_error(): instance.fit(data) +SYNTHESIZERS_CLASSES = [ + pytest.param(CTGANSynthesizer, id='CTGANSynthesizer'), + pytest.param(TVAESynthesizer, id='TVAESynthesizer'), + pytest.param(GaussianCopulaSynthesizer, id='GaussianCopulaSynthesizer'), + pytest.param(CopulaGANSynthesizer, id='CopulaGANSynthesizer'), +] + + +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_fit_and_sample_numerical_col_names(synthesizer_class): + """Test fitting/sampling when column names are integers""" + # Setup + num_rows = 50 + num_cols = 10 + values = { + i: np.random.randint(0, 100, size=num_rows) for i in range(num_cols) + } + data = pd.DataFrame(values) + metadata = SingleTableMetadata() + metadata_dict = {'columns': {}} + for i in range(num_cols): + metadata_dict['columns'][i] = {'sdtype': 'numerical'} + metadata = SingleTableMetadata.load_from_dict(metadata_dict) + + # Run + synth = synthesizer_class(metadata) + synth.fit(data) + sample_1 = synth.sample(10) + sample_2 = synth.sample(10) + + assert sample_1.columns.tolist() == data.columns.tolist() + assert sample_2.columns.tolist() == data.columns.tolist() + + # Assert + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(sample_1, sample_2) + + @pytest.mark.parametrize('synthesizer', SYNTHESIZERS) def test_sample_not_fitted(synthesizer): """Test that a synthesizer raises an error when trying to sample without fitting.""" diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 58d0b6975..c32ed7185 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1720,6 +1720,85 @@ def test_load_from_dict(self, mock_singletablemetadata): } ] + @patch('sdv.metadata.multi_table.SingleTableMetadata') + def test_load_from_dict_integer(self, mock_singletablemetadata): + """Test that ``load_from_dict`` returns a instance of ``MultiTableMetadata``. + + Test that when calling the ``load_from_dict`` method a new instance with the passed + python ``dict`` details should be created. Make sure that integers passed in are + turned into strings to ensure metadata is properly typed. + + Setup: + - A dict representing a ``MultiTableMetadata``. + + Mock: + - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table`` + + Output: + - ``instance`` that contains ``instance.tables`` and ``instance.relationships``. + + Side Effects: + - ``SingleTableMetadata.load_from_dict`` has been called. 
+ """ + # Setup + multitable_metadata = { + 'tables': { + 'accounts': { + 1: {'sdtype': 'numerical'}, + 2: {'sdtype': 'numerical'}, + 'amount': {'sdtype': 'numerical'}, + 'start_date': {'sdtype': 'datetime'}, + 'owner': {'sdtype': 'id'}, + }, + 'branches': { + 1: {'sdtype': 'numerical'}, + 'name': {'sdtype': 'id'}, + } + }, + 'relationships': [ + { + 'parent_table_name': 'accounts', + 'parent_primary_key': 1, + 'child_table_name': 'branches', + 'child_foreign_key': 1, + } + ] + } + + single_table_accounts = { + '1': {'sdtype': 'numerical'}, + '2': {'sdtype': 'numerical'}, + 'amount': {'sdtype': 'numerical'}, + 'start_date': {'sdtype': 'datetime'}, + 'owner': {'sdtype': 'id'}, + } + single_table_branches = { + '1': {'sdtype': 'numerical'}, + 'name': {'sdtype': 'id'}, + } + mock_singletablemetadata.load_from_dict.side_effect = [ + single_table_accounts, + single_table_branches + ] + + # Run + instance = MultiTableMetadata.load_from_dict(multitable_metadata) + + # Assert + assert instance.tables == { + 'accounts': single_table_accounts, + 'branches': single_table_branches + } + + assert instance.relationships == [ + { + 'parent_table_name': 'accounts', + 'parent_primary_key': '1', + 'child_table_name': 'branches', + 'child_foreign_key': '1', + } + ] + @patch('sdv.metadata.multi_table.json') def test___repr__(self, mock_json): """Test that the ``__repr__`` method. diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index d51b08ecc..8fb8eb058 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -2695,6 +2695,34 @@ def test_load_from_dict(self): assert instance.sequence_index is None assert instance._version == 'SINGLE_TABLE_V1' + def test_load_from_dict_integer(self): + """Test that ``load_from_dict`` returns a instance with the ``dict`` updated objects. + + If the metadata dict contains columns with integers for certain reasons + (e.g. due to missing column names from CSV) make sure they are correctly typed + to strings to ensure metadata is parsed properly. + """ + # Setup + my_metadata = { + 'columns': {1: 'value'}, + 'primary_key': 'pk', + 'alternate_keys': [], + 'sequence_key': None, + 'sequence_index': None, + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + } + + # Run + instance = SingleTableMetadata.load_from_dict(my_metadata) + + # Assert + assert instance.columns == {'1': 'value'} + assert instance.primary_key == 'pk' + assert instance.sequence_key is None + assert instance.alternate_keys == [] + assert instance.sequence_index is None + assert instance._version == 'SINGLE_TABLE_V1' + @patch('sdv.metadata.utils.Path') def test_load_from_json_path_does_not_exist(self, mock_path): """Test the ``load_from_json`` method. diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 9f5548330..ffcd63148 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -780,6 +780,77 @@ def test_preprocess(self): synth_upravna_enota._preprocess.assert_called_once_with(data['upravna_enota']) synth_upravna_enota.update_transformers.assert_called_once_with({'a': None, 'b': None}) + def test_preprocess_int_columns(self): + """Test the preprocess method. + + Ensure that data with column names as integers are not changed by + preprocess. 
+ """ + # Setup + metadata_dict = { + 'tables': { + 'first_table': { + 'primary_key': '1', + 'columns': { + '1': {'sdtype': 'id'}, + '2': {'sdtype': 'categorical'}, + 'str': {'sdtype': 'categorical'} + } + }, + 'second_table': { + 'columns': { + '3': {'sdtype': 'id'}, + 'str': {'sdtype': 'categorical'} + } + } + }, + 'relationships': [ + { + 'parent_table_name': 'first_table', + 'parent_primary_key': '1', + 'child_table_name': 'second_table', + 'child_foreign_key': '3' + } + ] + } + metadata = MultiTableMetadata.load_from_dict(metadata_dict) + instance = BaseMultiTableSynthesizer(metadata) + instance.validate = Mock() + instance._table_synthesizers = { + 'first_table': Mock(), + 'second_table': Mock() + } + multi_data = { + 'first_table': pd.DataFrame({ + 1: ['abc', 'def', 'ghi'], + 2: ['x', 'a', 'b'], + 'str': ['John', 'Doe', 'John Doe'], + }), + 'second_table': pd.DataFrame({ + 3: ['abc', 'def', 'ghi'], + 'another': ['John', 'Doe', 'John Doe'], + }), + } + + # Run + instance.preprocess(multi_data) + + # Assert + corrected_frame = { + 'first_table': pd.DataFrame({ + 1: ['abc', 'def', 'ghi'], + 2: ['x', 'a', 'b'], + 'str': ['John', 'Doe', 'John Doe'], + }), + 'second_table': pd.DataFrame({ + 3: ['abc', 'def', 'ghi'], + 'another': ['John', 'Doe', 'John Doe'], + }), + } + + pd.testing.assert_frame_equal(multi_data['first_table'], corrected_frame['first_table']) + pd.testing.assert_frame_equal(multi_data['second_table'], corrected_frame['second_table']) + @patch('sdv.multi_table.base.warnings') def test_preprocess_warning(self, mock_warnings): """Test that ``preprocess`` warns the user if the model has already been fitted.""" diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 932fa52da..dbaf35018 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -324,6 +324,7 @@ def test_preprocess(self, mock_warnings): # Setup instance = Mock() instance._fitted = True + instance._store_and_convert_original_cols.return_value = False data = pd.DataFrame({ 'name': ['John', 'Doe', 'John Doe'] }) @@ -339,6 +340,34 @@ def test_preprocess(self, mock_warnings): mock_warnings.warn.assert_called_once_with(expected_warning) instance._preprocess.assert_called_once_with(data) + def test_preprocess_int_columns(self): + """Test the preprocess method. + + Ensure that data with column names as integers are not changed by + preprocess. 
+ """ + # Setup + instance = Mock() + instance._fitted = False + instance._original_columns = pd.Index([1, 2, 'str']) + data = pd.DataFrame({ + 1: ['John', 'Doe', 'John Doe'], + 2: ['John', 'Doe', 'John Doe'], + 'str': ['John', 'Doe', 'John Doe'], + }) + + # Run + BaseSingleTableSynthesizer.preprocess(instance, data) + + # Assert + corrected_frame = pd.DataFrame({ + 1: ['John', 'Doe', 'John Doe'], + 2: ['John', 'Doe', 'John Doe'], + 'str': ['John', 'Doe', 'John Doe'], + }) + + pd.testing.assert_frame_equal(data, corrected_frame) + @patch('sdv.single_table.base.DataProcessor') def test__fit(self, mock_data_processor): """Test that ``NotImplementedError`` is being raised.""" @@ -429,8 +458,8 @@ def test_fit(self, mock_datetime, caplog): # Assert assert instance._random_state_set is False instance._data_processor.reset_sampling.assert_called_once_with() - instance._preprocess.assert_called_once_with(data) - instance.fit_processed_data.assert_called_once_with(instance._preprocess.return_value) + instance.preprocess.assert_called_once_with(data) + instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() assert caplog.messages[0] == ( '\nFit:\n' From 172e71256b76d550470d26d5a9cf1a89e7eb8e21 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Thu, 9 May 2024 00:04:41 +0200 Subject: [PATCH 22/32] Issue 1995 update code to remove futurewarning related to enforce uniqueness (#1997) --- sdv/data_processing/data_processor.py | 26 +++++---- .../data_processing/test_data_processor.py | 56 +++++++------------ 2 files changed, 35 insertions(+), 47 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 5c1bff886..638a922ce 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -412,7 +412,7 @@ def _update_transformers_by_sdtypes(self, sdtype, transformer): self._transformers_by_sdtype[sdtype] = transformer @staticmethod - def create_anonymized_transformer(sdtype, column_metadata, enforce_uniqueness, + def create_anonymized_transformer(sdtype, column_metadata, cardinality_rule, locales=['en_US']): """Create an instance of an ``AnonymizedFaker``. @@ -424,9 +424,11 @@ def create_anonymized_transformer(sdtype, column_metadata, enforce_uniqueness, Sematic data type or a ``Faker`` function name. column_metadata (dict): A dictionary representing the rest of the metadata for the given ``sdtype``. - enforce_uniqueness (bool): - If ``True`` overwrite ``enforce_uniqueness`` with ``True`` to ensure unique - generation for primary keys. + cardinality_rule (str): + If ``'unique'`` enforce that every created value is unique. + If ``'match'`` match the cardinality of the data seen during fit. + If ``None`` do not consider cardinality. + Defaults to ``None``. locales (str or list): Locale or list of locales to use for the AnonymizedFaker transfomer. Defaults to ['en_US']. @@ -434,14 +436,14 @@ def create_anonymized_transformer(sdtype, column_metadata, enforce_uniqueness, Returns: Instance of ``rdt.transformers.pii.AnonymizedFaker``. 
""" - kwargs = {'locales': locales} + kwargs = { + 'locales': locales, + 'cardinality_rule': cardinality_rule + } for key, value in column_metadata.items(): if key not in ['pii', 'sdtype']: kwargs[key] = value - if enforce_uniqueness: - kwargs['enforce_uniqueness'] = True - try: transformer = get_anonymized_transformer(sdtype, kwargs) except AttributeError as error: @@ -494,7 +496,7 @@ def _get_transformer_instance(self, sdtype, column_metadata): is_baseprovider = transformer.provider_name == 'BaseProvider' if is_lexify and is_baseprovider: # Default settings return self.create_anonymized_transformer( - sdtype, column_metadata, False, self._locales + sdtype, column_metadata, None, self._locales ) kwargs = { @@ -598,11 +600,11 @@ def _create_config(self, data, columns_created_by_constraints): elif pii: sdtypes[column] = 'pii' - enforce_uniqueness = bool(column in self._keys) + cardinality_rule = 'unique' if bool(column in self._keys) else None transformers[column] = self.create_anonymized_transformer( sdtype, column_metadata, - enforce_uniqueness, + cardinality_rule, self._locales ) @@ -614,7 +616,7 @@ def _create_config(self, data, columns_created_by_constraints): transformers[column] = self.create_anonymized_transformer( sdtype=sdtype, column_metadata=column_metadata, - enforce_uniqueness=True, + cardinality_rule='unique', locales=self._locales ) diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 50cd2751d..6fec93773 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -987,40 +987,23 @@ def test_create_regex_generator_regex_generator(self, mock_rdt): ) @patch('sdv.data_processing.data_processor.get_anonymized_transformer') - def test_create_anonymized_transformer_enforce_uniqueness(self, - mock_get_anonymized_transformer): - """Test the ``create_regex_generator`` method. - - Test that when given an ``sdtype`` and ``column_metadata`` that does not contain a - ``regex_format`` this calls ``create_anonymized_transformer`` with ``enforce_uniqueness`` - set to ``True``. - - Input: - - String representing an ``sdtype``. - - Dictionary with ``column_metadata`` that contains ``sdtype``. - - Mock: - - Mock the ``create_anonymized_transformer``. + def test_create_anonymized_transformer_cardinality_rule_unique( + self, mock_get_anonymized_transformer): + """Test the ``create_anonymized_transformer`` method. - Output: - - The return value of ``create_anonymized_transformer``. + Test that when calling with ``cardinality_rule`` set to ``'unique'``, this + calls ``get_anonymized_transformer`` with the given parameters. 
""" # Setup sdtype = 'ssn' - column_metadata = { - 'sdtype': 'ssn', - } + column_metadata = {'sdtype': 'ssn'} # Run - output = DataProcessor.create_anonymized_transformer( - sdtype, - column_metadata, - True - ) + output = DataProcessor.create_anonymized_transformer(sdtype, column_metadata, 'unique') # Assert mock_get_anonymized_transformer.assert_called_once_with( - 'ssn', {'enforce_uniqueness': True, 'locales': ['en_US']} + 'ssn', {'cardinality_rule': 'unique', 'locales': ['en_US']} ) assert output == mock_get_anonymized_transformer.return_value @@ -1033,21 +1016,19 @@ def test_create_anonymized_transformer_locales(self, mock_get_anonymized_transfo """ # Setup sdtype = 'ssn' - column_metadata = { - 'sdtype': 'ssn', - } + column_metadata = {'sdtype': 'ssn'} # Run output = DataProcessor.create_anonymized_transformer( sdtype, column_metadata, - False, + None, locales=['en_US', 'en_CA'] ) # Assert mock_get_anonymized_transformer.assert_called_once_with( - 'ssn', {'locales': ['en_US', 'en_CA']} + 'ssn', {'locales': ['en_US', 'en_CA'], 'cardinality_rule': None} ) assert output == mock_get_anonymized_transformer.return_value @@ -1069,7 +1050,7 @@ def test_create_anonymized_transformer_locales_missing_attribute(self): DataProcessor.create_anonymized_transformer( sdtype, column_metadata, - False, + None, locales=['en_UK'] ) @@ -1099,13 +1080,18 @@ def test_create_anonymized_transformer(self, mock_get_anonymized_transformer): } # Run - output = DataProcessor.create_anonymized_transformer(sdtype, column_metadata, False) + output = DataProcessor.create_anonymized_transformer(sdtype, column_metadata, 'unique') # Assert assert output == mock_get_anonymized_transformer.return_value - mock_get_anonymized_transformer.assert_called_once_with( - 'email', {'function_kwargs': {'domain': 'gmail.com'}, 'locales': ['en_US']} - ) + expected_kwargs = { + 'function_kwargs': { + 'domain': 'gmail.com' + }, + 'locales': ['en_US'], + 'cardinality_rule': 'unique' + } + mock_get_anonymized_transformer.assert_called_once_with('email', expected_kwargs) def test__get_transformer_instance_no_kwargs(self): """Test the ``_get_transformer_instance`` without keyword args. From bfa6e7242445bc19d037512621d33a914a64ed03 Mon Sep 17 00:00:00 2001 From: John La Date: Thu, 9 May 2024 18:40:55 -0500 Subject: [PATCH 23/32] Convert metadata columns from integer to columns to ensure that SDV works properly (#1989) --- sdv/metadata/multi_table.py | 6 ++ sdv/metadata/single_table.py | 7 +- tests/integration/multi_table/test_hma.py | 59 +++++++++++++++ tests/integration/single_table/test_base.py | 21 ++++++ tests/unit/metadata/test_single_table.py | 81 +++++++++++++++++++++ 5 files changed, 173 insertions(+), 1 deletion(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 8aebf65c5..a8ef5e174 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -524,6 +524,9 @@ def _detect_relationships(self): def detect_table_from_dataframe(self, table_name, data): """Detect the metadata for a table from a dataframe. + This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``, + for a specified table. All data column names are converted to strings. + Args: table_name (str): Name of the table to detect. @@ -539,6 +542,9 @@ def detect_table_from_dataframe(self, table_name, data): def detect_from_dataframes(self, data): """Detect the metadata for all tables in a dictionary of dataframes. + This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``. 
+ All data column names are converted to strings. + Args: data (dict): Dictionary of table names to dataframes. diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 806bc3561..a91a02064 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -524,6 +524,8 @@ def _detect_columns(self, data): data (pandas.DataFrame): The data to be analyzed. """ + old_columns = data.columns + data.columns = data.columns.astype(str) first_pii_field = None for field in data: column_data = data[field] @@ -573,11 +575,13 @@ def _detect_columns(self, data): self.primary_key = first_pii_field self._updated = True + data.columns = old_columns def detect_from_dataframe(self, data): """Detect the metadata from a ``pd.DataFrame`` object. This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``. + All data column names are converted to strings. Args: data (pandas.DataFrame): @@ -1232,7 +1236,8 @@ def load_from_dict(cls, metadata_dict): Python dictionary representing a ``SingleTableMetadata`` object. Returns: - Instance of ``SingleTableMetadata``. + Instance of ``SingleTableMetadata``. Column names are converted to + string type. """ instance = cls() for key in instance._KEYS: diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 9a4345a33..bff97912e 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1729,3 +1729,62 @@ def test_fit_and_sample_numerical_col_names(): # Assert with pytest.raises(AssertionError): pd.testing.assert_frame_equal(first_sample['0'], second_sample['0']) + + +def test_detect_from_dataframe_numerical_col(): + """Test that metadata detection of integer columns work.""" + # Setup + parent_data = pd.DataFrame({ + 1: [1000, 1001, 1002], + 2: [2, 3, 4], + 'categorical_col': ['a', 'b', 'a'], + }) + child_data = pd.DataFrame({ + 3: [1000, 1001, 1000], + 4: [1, 2, 3] + }) + data = { + 'parent_data': parent_data, + 'child_data': child_data, + } + metadata = MultiTableMetadata() + metadata.detect_table_from_dataframe('parent_data', parent_data) + metadata.detect_table_from_dataframe('child_data', child_data) + metadata.update_column('parent_data', '1', sdtype='id') + metadata.update_column('child_data', '3', sdtype='id') + metadata.update_column('child_data', '4', sdtype='id') + metadata.set_primary_key('parent_data', '1') + metadata.set_primary_key('child_data', '4') + metadata.add_relationship( + parent_primary_key='1', + parent_table_name='parent_data', + child_foreign_key='3', + child_table_name='child_data' + ) + + test_metadata = MultiTableMetadata() + test_metadata.detect_from_dataframes(data) + test_metadata.update_column('parent_data', '1', sdtype='id') + test_metadata.update_column('child_data', '3', sdtype='id') + test_metadata.update_column('child_data', '4', sdtype='id') + test_metadata.set_primary_key('parent_data', '1') + test_metadata.set_primary_key('child_data', '4') + test_metadata.add_relationship( + parent_primary_key='1', + parent_table_name='parent_data', + child_foreign_key='3', + child_table_name='child_data' + ) + + # Run + instance = HMASynthesizer(metadata) + instance.fit(data) + sample = instance.sample(5) + + # Assert + assert test_metadata.to_dict() == metadata.to_dict() + assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist() + assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist() + + test_metadata = MultiTableMetadata() + 
test_metadata.detect_from_dataframes(data) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 1abb62f05..a9579775a 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -831,3 +831,24 @@ def test_sample_not_fitted(synthesizer): # Run and Assert with pytest.raises(SamplingError, match=expected_message): synthesizer.sample(10) + + +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_detect_from_dataframe_numerical_col(synthesizer_class): + """Test that metadata detection of integer columns work.""" + # Setup + data = pd.DataFrame({ + 1: [1, 2, 3], + 2: [4, 5, 6], + 3: ['a', 'b', 'c'], + }) + metadata = SingleTableMetadata() + + # Run + metadata.detect_from_dataframe(data) + instance = synthesizer_class(metadata) + instance.fit(data) + sample = instance.sample(5) + + # Assert + assert sample.columns.tolist() == data.columns.tolist() diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index 8fb8eb058..428f5926f 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -1296,6 +1296,87 @@ def test_detect_from_dataframe(self, mock_log): ] mock_log.info.assert_has_calls(expected_log_calls) + @patch('sdv.metadata.single_table.LOGGER') + def test_detect_from_dataframe_numerical_columns(self, mock_log): + """Test the detect from dataframe with columns that are integers""" + # Setup + num_rows = 100 + num_cols = 20 + values = {i + 1: np.random.randint(0, 100, size=num_rows) for i in range(num_cols)} + data = pd.DataFrame(values) + correct_metadata = { + 'columns': { + '1': { + 'sdtype': 'numerical' + }, + '2': { + 'sdtype': 'numerical' + }, + '3': { + 'sdtype': 'numerical' + }, + '4': { + 'sdtype': 'numerical' + }, + '5': { + 'sdtype': 'numerical' + }, + '6': { + 'sdtype': 'numerical' + }, + '7': { + 'sdtype': 'numerical' + }, + '8': { + 'sdtype': 'numerical' + }, + '9': { + 'sdtype': 'numerical' + }, + '10': { + 'sdtype': 'numerical' + }, + '11': { + 'sdtype': 'numerical' + }, + '12': { + 'sdtype': 'numerical' + }, + '13': { + 'sdtype': 'numerical' + }, + '14': { + 'sdtype': 'numerical' + }, + '15': { + 'sdtype': 'numerical' + }, + '16': { + 'sdtype': 'numerical' + }, + '17': { + 'sdtype': 'numerical' + }, + '18': { + 'sdtype': 'numerical' + }, + '19': { + 'sdtype': 'numerical' + }, + '20': { + 'sdtype': 'numerical' + } + }, + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + } + + # Run + metadata = SingleTableMetadata() + metadata.detect_from_dataframe(data) + + # Assert + assert correct_metadata == metadata.to_dict() + def test_detect_from_csv_raises_error(self): """Test the ``detect_from_csv`` method. 
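
A minimal usage sketch of the integer-column handling introduced by the patch above (illustrative only, not part of the patch; assumes an SDV build that includes this change, where metadata detection casts integer column names to strings while leaving the data untouched):

# Sketch: detect metadata for a DataFrame whose column names are integers.
import pandas as pd
from sdv.metadata import SingleTableMetadata

data = pd.DataFrame({
    1: [10, 20, 30],            # integer column names, e.g. from a headerless CSV
    2: [0.1, 0.2, 0.3],
    'label': ['a', 'b', 'c'],
})

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

print(list(metadata.columns))   # ['1', '2', 'label'] -- detected keys are strings
print(list(data.columns))       # [1, 2, 'label'] -- the DataFrame itself is unchanged
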
From 4d67b9b41b342c27476ea9b7087f0087f8ddd4df Mon Sep 17 00:00:00 2001 From: SDV Team <98988753+sdv-team@users.noreply.github.com> Date: Mon, 13 May 2024 11:01:11 -0400 Subject: [PATCH 24/32] Automated Latest Dependency Updates (#1999) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- latest_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/latest_requirements.txt b/latest_requirements.txt index 3c64a8cdd..f8024233b 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -6,6 +6,6 @@ graphviz==0.20.3 numpy==1.26.4 pandas==2.2.2 platformdirs==4.2.1 -rdt==1.12.0 +rdt==1.12.1 sdmetrics==0.14.0 tqdm==4.66.4 From 0840c7f414836865fc1a6ae74f0df8af8b33c2c9 Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Mon, 13 May 2024 12:32:29 -0500 Subject: [PATCH 25/32] Adding target to Makefile for pushing to git --- Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Makefile b/Makefile index c953b15b6..ff5871ffa 100644 --- a/Makefile +++ b/Makefile @@ -235,6 +235,10 @@ ifeq ($(CHANGELOG_LINES),0) $(error Please insert the release notes in HISTORY.md before releasing) endif +.PHONY: git-push +git-push: ## Simply push the repository to github + git push + .PHONY: check-release check-release: check-clean check-main check-history ## Check if the release can be made @echo "A new release can be made" From 1762452bf3bf104b723a3d9b562f9be4fc6af98e Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Mon, 13 May 2024 19:43:53 -0500 Subject: [PATCH 26/32] =?UTF-8?q?Bump=20version:=201.12.2.dev0=20=E2=86=92?= =?UTF-8?q?=201.13.0.dev0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- sdv/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3e89f63ff..ee7525c96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,7 @@ namespaces = false version = {attr = 'sdv.__version__'} [tool.bumpversion] -current_version = "1.12.2.dev0" +current_version = "1.13.0.dev0" parse = '(?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))?' serialize = [ '{major}.{minor}.{patch}.{release}{candidate}', diff --git a/sdv/__init__.py b/sdv/__init__.py index 4c29c84de..9f60022f7 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = 'DataCebo, Inc.' __email__ = 'info@sdv.dev' -__version__ = '1.12.2.dev0' +__version__ = '1.13.0.dev0' import sys From 08a464c10d82f054e00d5807f6574579d7204329 Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Mon, 13 May 2024 19:44:33 -0500 Subject: [PATCH 27/32] =?UTF-8?q?Bump=20version:=201.13.0.dev0=20=E2=86=92?= =?UTF-8?q?=201.13.0.dev1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- sdv/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ee7525c96..780c5c790 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,7 @@ namespaces = false version = {attr = 'sdv.__version__'} [tool.bumpversion] -current_version = "1.13.0.dev0" +current_version = "1.13.0.dev1" parse = '(?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))?' serialize = [ '{major}.{minor}.{patch}.{release}{candidate}', diff --git a/sdv/__init__.py b/sdv/__init__.py index 9f60022f7..21509b8b9 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = 'DataCebo, Inc.' 
__email__ = 'info@sdv.dev' -__version__ = '1.13.0.dev0' +__version__ = '1.13.0.dev1' import sys From 2d2b030d02a134d9648f03ac1c3f176e64a7a861 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Wed, 15 May 2024 10:09:02 -0700 Subject: [PATCH 28/32] Fix pandas DtypeWarning in download_demo (#2006) --- sdv/datasets/demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index f668ce403..43c35f4ea 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -96,7 +96,7 @@ def _get_data(modality, output_folder_name, in_memory_directory): for filename, file_ in in_memory_directory.items(): if filename.endswith('.csv'): table_name = Path(filename).stem - data[table_name] = pd.read_csv(io.StringIO(file_.decode())) + data[table_name] = pd.read_csv(io.StringIO(file_.decode()), low_memory=False) if modality != 'multi_table': data = data.popitem()[1] From e2c3514ed20d9d7aff17b4fe284b36fdf66abf3e Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Wed, 15 May 2024 14:39:08 -0400 Subject: [PATCH 29/32] Update logs to write to CSV (#2005) --- sdv/logging/__init__.py | 4 +- sdv/logging/logger.py | 31 ++++++- sdv/logging/sdv_logger_config.yml | 12 ++- sdv/logging/utils.py | 17 +++- sdv/multi_table/base.py | 125 +++++++++++---------------- sdv/single_table/base.py | 118 ++++++++++--------------- tests/unit/logging/test_logger.py | 64 +++++++++++++- tests/unit/logging/test_utils.py | 36 +++++++- tests/unit/multi_table/test_base.py | 92 ++++++++++---------- tests/unit/single_table/test_base.py | 92 ++++++++++---------- 10 files changed, 340 insertions(+), 251 deletions(-) diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index c15348231..2c5d10e88 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -1,10 +1,12 @@ """Module for configuring loggers within the SDV library.""" from sdv.logging.logger import get_sdv_logger -from sdv.logging.utils import disable_single_table_logger, get_sdv_logger_config +from sdv.logging.utils import ( + disable_single_table_logger, get_sdv_logger_config, load_logfile_dataframe) __all__ = ( 'disable_single_table_logger', 'get_sdv_logger', 'get_sdv_logger_config', + 'load_logfile_dataframe' ) diff --git a/sdv/logging/logger.py b/sdv/logging/logger.py index 52a178638..2710566ae 100644 --- a/sdv/logging/logger.py +++ b/sdv/logging/logger.py @@ -1,11 +1,35 @@ """SDV Logger.""" - +import csv import logging from functools import lru_cache +from io import StringIO from sdv.logging.utils import get_sdv_logger_config +class CSVFormatter(logging.Formatter): + """Logging formatter to convert to CSV.""" + + def __init__(self): + super().__init__() + self.output = StringIO() + headers = [ + 'LEVEL', 'EVENT', 'TIMESTAMP', 'SYNTHESIZER CLASS NAME', 'SYNTHESIZER ID', + 'TOTAL NUMBER OF TABLES', 'TOTAL NUMBER OF ROWS', 'TOTAL NUMBER OF COLUMNS' + ] + self.writer = csv.DictWriter(self.output, headers) + + def format(self, record): # noqa: A003 + """Format the record and write to CSV.""" + row = record.msg + row['LEVEL'] = record.levelname + self.writer.writerow(row) + data = self.output.getvalue() + self.output.truncate(0) + self.output.seek(0) + return data.strip() + + @lru_cache() def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration. 
@@ -38,7 +62,10 @@ def get_sdv_logger(logger_name): formatter = None config = logger_conf.get('loggers').get(logger_name) log_level = getattr(logging, config.get('level', 'INFO')) - if config.get('format'): + if config.get('formatter'): + if config.get('formatter') == 'sdv.logging.logger.CSVFormatter': + formatter = CSVFormatter() + elif config.get('format'): formatter = logging.Formatter(config.get('format')) logger.setLevel(log_level) diff --git a/sdv/logging/sdv_logger_config.yml b/sdv/logging/sdv_logger_config.yml index 4b01b0c65..1d00fd1be 100644 --- a/sdv/logging/sdv_logger_config.yml +++ b/sdv/logging/sdv_logger_config.yml @@ -4,24 +4,28 @@ loggers: SingleTableSynthesizer: level: INFO propagate: false + formatter: sdv.logging.logger.CSVFormatter handlers: class: logging.FileHandler - filename: sdv_logs.log + filename: sdv_logs.csv MultiTableSynthesizer: level: INFO propagate: false + formatter: sdv.logging.logger.CSVFormatter handlers: class: logging.FileHandler - filename: sdv_logs.log + filename: sdv_logs.csv MultiTableMetadata: level: INFO propagate: false + formatter: sdv.logging.logger.CSVFormatter handlers: class: logging.FileHandler - filename: sdv_logs.log + filename: sdv_logs.csv SingleTableMetadata: level: INFO propagate: false + formatter: sdv.logging.logger.CSVFormatter handlers: class: logging.FileHandler - filename: sdv_logs.log + filename: sdv_logs.csv diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 471870649..e6c86e3ea 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -5,6 +5,7 @@ import shutil from pathlib import Path +import pandas as pd import platformdirs import yaml @@ -25,7 +26,7 @@ def get_sdv_logger_config(): for logger in logger_conf.get('loggers', {}).values(): handler = logger.get('handlers', {}) - if handler.get('filename') == 'sdv_logs.log': + if handler.get('filename') == 'sdv_logs.csv': handler['filename'] = store_path / handler['filename'] return logger_conf @@ -49,3 +50,17 @@ def disable_single_table_logger(): finally: for handler in handlers: single_table_logger.addHandler(handler) + + +def load_logfile_dataframe(logfile): + """Load the SDV logfile as a pandas DataFrame with correct column headers. + + Args: + logfile (str): + Path to the SDV log CSV file. + """ + column_names = [ + 'LEVEL', 'EVENT', 'TIMESTAMP', 'SYNTHESIZER CLASS NAME', 'SYNTHESIZER ID', + 'TOTAL NUMBER OF TABLES', 'TOTAL NUMBER OF ROWS', 'TOTAL NUMBER OF COLUMNS' + ] + return pd.read_csv(logfile, names=column_names) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 779366cb1..ed08891fb 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -119,15 +119,12 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - SYNTHESIZER_LOGGER.info( - '\nInstance:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - self._synthesizer_id - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Instance', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id + }) def set_address_columns(self, table_name, column_names, anonymization_level='full'): """Set the address multi-column transformer. 
@@ -403,22 +400,16 @@ def fit_processed_data(self, processed_data): total_rows += len(table) total_columns += len(table.columns) - SYNTHESIZER_LOGGER.info( - '\nFit processed data:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: %s\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(processed_data), - total_rows, - total_columns, - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Fit processed data', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + 'TOTAL NUMBER OF TABLES': len(processed_data), + 'TOTAL NUMBER OF ROWS': total_rows, + 'TOTAL NUMBER OF COLUMNS': total_columns + }) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) with disable_single_table_logger(): augmented_data = self._augment_tables(processed_data) @@ -443,22 +434,16 @@ def fit(self, data): total_rows += len(table) total_columns += len(table.columns) - SYNTHESIZER_LOGGER.info( - '\nFit:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' - ' Total number of tables: %s\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(data), - total_rows, - total_columns, - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Fit', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + 'TOTAL NUMBER OF TABLES': len(data), + 'TOTAL NUMBER OF ROWS': total_rows, + 'TOTAL NUMBER OF COLUMNS': total_columns + }) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() @@ -511,22 +496,16 @@ def sample(self, scale=1.0): if table in table_columns: sampled_data[table].columns = table_columns[table] - SYNTHESIZER_LOGGER.info( - '\nSample:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the sample size:\n' - ' Total number of tables: %s\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(sampled_data), - total_rows, - total_columns, - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Sample', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + 'TOTAL NUMBER OF TABLES': len(sampled_data), + 'TOTAL NUMBER OF ROWS': total_rows, + 'TOTAL NUMBER OF COLUMNS': total_columns + }) + return sampled_data def get_learned_distributions(self, table_name): @@ -692,15 +671,13 @@ def save(self, filepath): Path where the instance will be serialized. 
""" synthesizer_id = getattr(self, '_synthesizer_id', None) - SYNTHESIZER_LOGGER.info( - '\nSave:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - synthesizer_id - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Save', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': synthesizer_id, + }) + with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -724,13 +701,11 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - SYNTHESIZER_LOGGER.info( - '\nLoad:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - synthesizer.__class__.__name__, - synthesizer._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Load', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': synthesizer.__class__.__name__, + 'SYNTHESIZER ID': synthesizer._synthesizer_id, + }) + return synthesizer diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 3b3b93122..de474d1c7 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -112,15 +112,12 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - SYNTHESIZER_LOGGER.info( - '\nInstance:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - self._synthesizer_id - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Instance', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + }) def set_address_columns(self, column_names, anonymization_level='full'): """Set the address multi-column transformer.""" @@ -420,21 +417,15 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - SYNTHESIZER_LOGGER.info( - '\nFit processed data:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(processed_data), - len(processed_data.columns), - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Fit processed data', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + 'TOTAL NUMBER OF TABLES': 1, + 'TOTAL NUMBER OF ROWS': len(processed_data), + 'TOTAL NUMBER OF COLUMNS': len(processed_data.columns) + }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: @@ -452,21 +443,15 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. 
""" - SYNTHESIZER_LOGGER.info( - '\nFit:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(data), - len(data.columns), - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Fit', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + 'TOTAL NUMBER OF TABLES': 1, + 'TOTAL NUMBER OF ROWS': len(data), + 'TOTAL NUMBER OF COLUMNS': len(data.columns) + }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() @@ -484,15 +469,12 @@ def save(self, filepath): Path where the synthesizer instance will be serialized. """ synthesizer_id = getattr(self, '_synthesizer_id', None) - SYNTHESIZER_LOGGER.info( - '\nSave:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - synthesizer_id - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Save', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': synthesizer_id, + }) with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -517,15 +499,12 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - SYNTHESIZER_LOGGER.info( - '\nLoad:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - synthesizer.__class__.__name__, - synthesizer._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Load', + 'TIMESTAMP': datetime.datetime.now(), + 'SYNTHESIZER CLASS NAME': synthesizer.__class__.__name__, + 'SYNTHESIZER ID': synthesizer._synthesizer_id, + }) return synthesizer @@ -913,21 +892,16 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file if not original_columns.empty: sampled_data.columns = self._original_columns - SYNTHESIZER_LOGGER.info( - '\nSample:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - sample_timestamp, - self.__class__.__name__, - len(sampled_data), - len(sampled_data.columns), - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info({ + 'EVENT': 'Sample', + 'TIMESTAMP': sample_timestamp, + 'SYNTHESIZER CLASS NAME': self.__class__.__name__, + 'SYNTHESIZER ID': self._synthesizer_id, + 'TOTAL NUMBER OF TABLES': 1, + 'TOTAL NUMBER OF ROWS': len(sampled_data), + 'TOTAL NUMBER OF COLUMNS': len(sampled_data.columns) + + }) return sampled_data diff --git a/tests/unit/logging/test_logger.py b/tests/unit/logging/test_logger.py index 4770d39cd..1834cbf5e 100644 --- a/tests/unit/logging/test_logger.py +++ b/tests/unit/logging/test_logger.py @@ -2,7 +2,31 @@ import logging from unittest.mock import Mock, patch -from sdv.logging.logger import get_sdv_logger +from sdv.logging.logger import CSVFormatter, get_sdv_logger + + +class TestCSVFormatter: + + def test_format(self): + """Test CSV formatter correctly formats the log entry.""" + # Setup + instance = CSVFormatter() + instance.writer = Mock() + instance.output = Mock() + record = Mock() + record.msg = { + 'EVENT': 'Instance', + 
'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'GaussianCopulaSynthesizer', + 'SYNTHESIZER ID': 'GaussainCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + } + record.levelname = 'INFO' + + # Run + instance.format(record) + + # Assert + instance.writer.writerow.assert_called_once_with({'LEVEL': 'INFO', **record.msg}) @patch('sdv.logging.logger.logging.StreamHandler') @@ -32,3 +56,41 @@ def test_get_sdv_logger(mock_get_sdv_logger_config, mock_getlogger, mock_streamh # Assert mock_logger_instance.setLevel.assert_called_once_with(logging.DEBUG) mock_logger_instance.addHandler.assert_called_once() + + +@patch('sdv.logging.logger.CSVFormatter') +@patch('sdv.logging.logger.logging.FileHandler') +@patch('sdv.logging.logger.logging.getLogger') +@patch('sdv.logging.logger.get_sdv_logger_config') +def test_get_sdv_logger_csv(mock_get_sdv_logger_config, mock_getlogger, + mock_filehandler, mock_csvformatter): + # Setup + mock_logger_conf = { + 'log_registry': 'local', + 'loggers': { + 'test_logger_csv': { + 'level': 'DEBUG', + 'formatter': 'sdv.logging.logger.CSVFormatter', + 'handlers': { + 'filename': 'logfile.csv', + 'class': 'logging.FileHandler' + } + } + } + } + mock_get_sdv_logger_config.return_value = mock_logger_conf + mock_logger_instance = Mock() + mock_logger_instance.handlers = [] + mock_getlogger.return_value = mock_logger_instance + mock_filehandler_instance = Mock() + mock_filehandler.return_value = mock_filehandler_instance + + # Run + get_sdv_logger('test_logger_csv') + + # Assert + mock_logger_instance.setLevel.assert_called_once_with(logging.DEBUG) + mock_logger_instance.addHandler.assert_called_once_with(mock_filehandler_instance) + mock_filehandler.assert_called_once_with('logfile.csv') + mock_filehandler_instance.setLevel.assert_called_once_with(logging.DEBUG) + mock_filehandler_instance.setFormatter.assert_called_once_with(mock_csvformatter.return_value) diff --git a/tests/unit/logging/test_utils.py b/tests/unit/logging/test_utils.py index b585cb880..1bcf74788 100644 --- a/tests/unit/logging/test_utils.py +++ b/tests/unit/logging/test_utils.py @@ -1,7 +1,12 @@ """Test ``SDV`` logging utilities.""" +from io import StringIO from unittest.mock import Mock, mock_open, patch -from sdv.logging.utils import disable_single_table_logger, get_sdv_logger_config +import numpy as np +import pandas as pd + +from sdv.logging.utils import ( + disable_single_table_logger, get_sdv_logger_config, load_logfile_dataframe) def test_get_sdv_logger_config(): @@ -54,3 +59,32 @@ def test_disable_single_table_logger(mock_getlogger): # Assert assert len(mock_logger.handlers) == 1 + + +def test_load_logfile_dataframe(): + """Test loading the CSV logfile into a DataFrame""" + # Setup + logfile = StringIO( + 'INFO,Instance,2024-05-14 11:29:00.649735,GaussianCopulaSynthesizer,' + 'GaussianCopulaSynthesizer_1.12.1_5387a6e9f4d,,,\n' + 'INFO,Fit,2024-05-14 11:29:00.649735,GaussianCopulaSynthesizer,' + 'GaussianCopulaSynthesizer_1.12.1_5387a6e9f4d,1,500,9\n' + 'INFO,Sample,2024-05-14 11:29:00.649735,GaussianCopulaSynthesizer,' + 'GaussianCopulaSynthesizer_1.12.1_5387a6e9f4d,1,500,6\n' + ) + + # Run + log_dataframe = load_logfile_dataframe(logfile) + + # Assert + expected_log = pd.DataFrame({ + 'LEVEL': ['INFO'] * 3, + 'EVENT': ['Instance', 'Fit', 'Sample'], + 'TIMESTAMP': ['2024-05-14 11:29:00.649735'] * 3, + 'SYNTHESIZER CLASS NAME': ['GaussianCopulaSynthesizer'] * 3, + 'SYNTHESIZER ID': ['GaussianCopulaSynthesizer_1.12.1_5387a6e9f4d'] * 3, + 'TOTAL NUMBER OF TABLES': [np.nan, 
1, 1], + 'TOTAL NUMBER OF ROWS': [np.nan, 500, 500], + 'TOTAL NUMBER OF COLUMNS': [np.nan, 9, 6] + }) + pd.testing.assert_frame_equal(log_dataframe, expected_log) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index ffcd63148..c4fe83d6e 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -133,11 +133,12 @@ def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_i mock_check_metadata_updated.assert_called_once() mock_generate_synthesizer_id.assert_called_once_with(instance) assert instance._synthesizer_id == synthesizer_id - assert caplog.messages[0] == ( - '\nInstance:\n Timestamp: 2024-04-19 16:20:10.037183\n Synthesizer class name: ' - 'BaseMultiTableSynthesizer\n Synthesizer id: ' - 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Instance', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer', + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + }) def test__init__column_relationship_warning(self): """Test that a warning is raised only once when the metadata has column relationships.""" @@ -927,16 +928,15 @@ def test_fit_processed_data(self, mock_datetime, caplog): instance._augment_tables.assert_called_once_with(processed_data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted - assert caplog.messages[0] == ( - '\nFit processed data:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 2\n' - ' Total number of rows: 6\n' - ' Total number of columns: 4\n' - ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Fit processed data', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + 'TOTAL NUMBER OF TABLES': 2, + 'TOTAL NUMBER OF ROWS': 6, + 'TOTAL NUMBER OF COLUMNS': 4 + }) def test_fit_processed_data_empty_table(self): """Test attributes are properly set when data is empty and that _fit is not called.""" @@ -1012,16 +1012,15 @@ def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() - assert caplog.messages[0] == ( - '\nFit:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 2\n' - ' Total number of rows: 6\n' - ' Total number of columns: 4\n' - ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Fit', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + 'TOTAL NUMBER OF TABLES': 2, + 'TOTAL NUMBER OF ROWS': 6, + 'TOTAL NUMBER OF COLUMNS': 4 + }) def test_fit_raises_version_error(self): """Test that fit will raise a ``VersionError`` if the current version is bigger.""" @@ -1148,16 +1147,15 @@ def test_sample(self, mock_datetime, caplog): # Assert 
instance._sample.assert_called_once_with(scale=1.5) - assert caplog.messages[0] == ( - '\nSample:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: BaseMultiTableSynthesizer\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 2\n' - ' Total number of rows: 6\n' - ' Total number of columns: 4\n' - ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Sample', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer', + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + 'TOTAL NUMBER OF TABLES': 2, + 'TOTAL NUMBER OF ROWS': 6, + 'TOTAL NUMBER OF COLUMNS': 4 + }) def test_get_learned_distributions_raises_an_unfitted_error(self): """Test that ``get_learned_distributions`` raises an error when model is not fitted.""" @@ -1563,12 +1561,12 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): # Assert cloudpickle_mock.dump.assert_called_once_with(synthesizer, ANY) - assert caplog.messages[0] == ( - '\nSave:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Save', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + }) @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.generate_synthesizer_id') @@ -1599,9 +1597,9 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_synthesizer_version.assert_called_once_with(synthesizer_mock) assert loaded_instance._synthesizer_id == synthesizer_id mock_generate_synthesizer_id.assert_called_once_with(synthesizer_mock) - assert caplog.messages[0] == ( - '\nLoad:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Load', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + }) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index dbaf35018..3074d8506 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -94,11 +94,12 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor, metadata.validate.assert_called_once_with() mock_check_metadata_updated.assert_called_once() mock_generate_synthesizer_id.assert_called_once_with(instance) - assert caplog.messages[0] == ( - '\nInstance:\n Timestamp: 2024-04-19 16:20:10.037183\n Synthesizer class name: ' - 'BaseSingleTableSynthesizer\n Synthesizer id: ' - 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Instance', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'BaseSingleTableSynthesizer', + 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + }) @patch('sdv.single_table.base.DataProcessor') def test___init__custom(self, mock_data_processor): @@ -398,16 +399,15 @@ def test_fit_processed_data(self, mock_datetime, caplog): # Assert 
instance._fit.assert_called_once_with(processed_data) - assert caplog.messages[0] == ( - '\nFit processed data:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 1\n' - ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Fit processed data', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + 'TOTAL NUMBER OF TABLES': 1, + 'TOTAL NUMBER OF ROWS': 3, + 'TOTAL NUMBER OF COLUMNS': 1 + }) def test_fit_processed_data_raises_version_error(self): """Test that ``fit`` raises ``VersionError`` @@ -461,16 +461,15 @@ def test_fit(self, mock_datetime, caplog): instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() - assert caplog.messages[0] == ( - '\nFit:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 2\n' - ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Fit', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + 'TOTAL NUMBER OF TABLES': 1, + 'TOTAL NUMBER OF ROWS': 3, + 'TOTAL NUMBER OF COLUMNS': 2 + }) def test_fit_raises_version_error(self): """Test that ``fit`` raises ``VersionError`` @@ -1476,16 +1475,15 @@ def test_sample(self, mock_datetime, caplog): show_progress_bar=True ) pd.testing.assert_frame_equal(result, pd.DataFrame({'col': [1, 2, 3]})) - assert caplog.messages[0] == ( - '\nSample:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 1\n' - ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Sample', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + 'TOTAL NUMBER OF TABLES': 1, + 'TOTAL NUMBER OF ROWS': 3, + 'TOTAL NUMBER OF COLUMNS': 1 + }) def test__validate_conditions_unseen_columns(self): """Test that conditions are within the ``data_processor`` fields.""" @@ -1855,12 +1853,12 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): # Assert cloudpickle_mock.dump.assert_called_once_with(synthesizer, ANY) - assert caplog.messages[0] == ( - '\nSave:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Save', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + }) @patch('sdv.single_table.base.datetime') 
@patch('sdv.single_table.base.generate_synthesizer_id') @@ -1891,12 +1889,12 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war assert loaded_instance._synthesizer_id == synthesizer_id mock_check_synthesizer_version.assert_called_once_with(synthesizer_mock) mock_generate_synthesizer_id.assert_called_once_with(synthesizer_mock) - assert caplog.messages[0] == ( - '\nLoad:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: Mock\n' - ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - ) + assert caplog.messages[0] == str({ + 'EVENT': 'Load', + 'TIMESTAMP': '2024-04-19 16:20:10.037183', + 'SYNTHESIZER CLASS NAME': 'Mock', + 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + }) def test_load_custom_constraint_classes(self): """Test that ``load_custom_constraint_classes`` calls the ``DataProcessor``'s method.""" From 11e0ae17e73f34f04c6b3f4b356978f92558130c Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Wed, 15 May 2024 14:23:08 -0500 Subject: [PATCH 30/32] 1.13.0 Release Notes (#2002) --- HISTORY.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 1fce1b386..7e376891e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,40 @@ # Release Notes +## 1.13.0 - 2024-05-15 + +This release adds a utility function called `get_random_subset` that helps users get a subset of their multi-table data so that modeling can be done quicker. Given a dictionary of table names mapped to DataFrames, metadata, a main table and a desired number of rows to use for the main table, it will subsample the data in a way that maintains referential integrity. + +This release also adds two new local file handlers: the `CSVHandler` and the `ExcelHandler`. This enables users to easily load from and save synthetic data to these files types. These handlers return data and metadata in the multi-table format, so we also added the function `get_table_metadata` to get a `SingleTableMetadata` object from a `MultiTableMetadata` object. + +Finally, this release fixes some bugs that prevented synthesizers from working with data that had numerical column names. 
+ +### New Features + +* Add `get_random_subset` poc utility function - Issue [#1877](https://github.com/sdv-dev/SDV/issues/1877) by @R-Palazzo +* Add usage logging - Issue [#1903](https://github.com/sdv-dev/SDV/issues/1903) by @pvk-developer +* Move function `drop_unknown_references` from `poc` to be directly under `utils` - Issue [#1947](https://github.com/sdv-dev/SDV/issues/1947) by @R-Palazzo +* Add CSVHandler - Issue [#1949](https://github.com/sdv-dev/SDV/issues/1949) by @pvk-developer +* Add ExcelHandler - Issue [#1950](https://github.com/sdv-dev/SDV/issues/1950) by @pvk-developer +* Add get_table_metadata function - Issue [#1951](https://github.com/sdv-dev/SDV/issues/1951) by @R-Palazzo +* Save usage log file as a csv - Issue [#1974](https://github.com/sdv-dev/SDV/issues/1974) by @frances-h +* Split out metadata creation from data import in the local files handlers - Issue [#1975](https://github.com/sdv-dev/SDV/issues/1975) by @pvk-developer +* Improve error message when trying to sample before fitting (single table) - Issue [#1978](https://github.com/sdv-dev/SDV/issues/1978) by @R-Palazzo + +### Bugs Fixed + +* Metadata detection crashes when the column names are integers (`AttributeError: 'int' object has no attribute 'lower'`) - Issue [#1933](https://github.com/sdv-dev/SDV/issues/1933) by @lajohn4747 +* Synthesizers crash when column names are integers (`TypeError: unsupported operand`) - Issue [#1935](https://github.com/sdv-dev/SDV/issues/1935) by @lajohn4747 +* Switch parameter order in drop_unknown_references - Issue [#1944](https://github.com/sdv-dev/SDV/issues/1944) by @R-Palazzo +* Fix pandas DtypeWarning in download_demo - Issue [#1980](https://github.com/sdv-dev/SDV/issues/1980) by @fealho + +### Maintenance + +* Only run unit and integration tests on oldest and latest python versions for macos - Issue [#1948](https://github.com/sdv-dev/SDV/issues/1948) by @frances-h + +### Internal + +* Update code to remove `FutureWarning` related to 'enforce_uniqueness' parameter - Issue [#1995](https://github.com/sdv-dev/SDV/issues/1995) by @pvk-developer + ## 1.12.1 - 2024-04-19 This release makes a number of changes to how id columns are generated. By default, id columns with a regex will now have their values scrambled in the output. Id columns without a regex that are numeric will be created randomly. If they're not numeric, they will have a random suffix. 
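
A minimal sketch of reading the CSV usage log described in the release notes above (illustrative only, not part of the patches; the logfile path below is an assumed local copy -- by default the log is written to the platform-specific SDV store directory resolved by get_sdv_logger_config):

# Sketch: load the usage log CSV into a pandas DataFrame with named columns.
from sdv.logging import load_logfile_dataframe

usage_log = load_logfile_dataframe('sdv_logs.csv')  # assumed path to a local copy of the log
print(usage_log[['EVENT', 'SYNTHESIZER CLASS NAME', 'TOTAL NUMBER OF ROWS']].head())
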
From e7c4626f3490087552ab9b01e6daa5b144658350 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Wed, 15 May 2024 12:40:29 -0700 Subject: [PATCH 31/32] Remove unexpected NaN values in sequence_index (#2003) --- sdv/multi_table/hma.py | 2 +- sdv/multi_table/utils.py | 4 +- sdv/sequential/par.py | 2 +- tests/integration/sequential/test_par.py | 57 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index a769c3e13..dd142b420 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -114,7 +114,7 @@ def _estimate_columns_traversal(cls, metadata, table_name, columns_per_table[table_name] += \ cls._get_num_extended_columns( metadata, child_name, table_name, columns_per_table, distributions - ) + ) visited.add(table_name) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index a44574157..3f6fd9526 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -47,10 +47,10 @@ def _get_n_order_descendants(relationships, parent_table, order): descendants = {} order_1_descendants = _get_relationships_for_parent(relationships, parent_table) descendants['order_1'] = [rel['child_table_name'] for rel in order_1_descendants] - for i in range(2, order+1): + for i in range(2, order + 1): descendants[f'order_{i}'] = [] prov_descendants = [] - for child_table in descendants[f'order_{i-1}']: + for child_table in descendants[f'order_{i - 1}']: order_i_descendants = _get_relationships_for_parent(relationships, child_table) prov_descendants.extend([rel['child_table_name'] for rel in order_i_descendants]) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 4c7a80a36..859085de0 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -178,7 +178,7 @@ def _transform_sequence_index(self, data): fill_value = min(sequence_index_sequence[self._sequence_index].dropna()) sequence_index_sequence = sequence_index_sequence.fillna(fill_value) - data[self._sequence_index] = sequence_index_sequence[self._sequence_index] + data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy() data = data.merge( sequence_index_context, left_on=self._sequence_key, diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 3b442e1b5..f46da7ba5 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -1,5 +1,6 @@ import datetime +import numpy as np import pandas as pd from deepecho import load_demo @@ -192,3 +193,59 @@ def test_sythesize_sequences(tmp_path): synthesizer.validate(loaded_sample) loaded_synthesizer.validate(synthetic_data) loaded_synthesizer.validate(loaded_sample) + + +def test_par_subset_of_data(): + """Test it when the data index is not continuous GH#1973.""" + # download data + data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019',) + + # modify the data by choosing a subset of it + data_subset = data.copy() + np.random.seed(1234) + symbols = data['Symbol'].unique() + + # only select a subset of data in each sequence + for i, symbol in enumerate(symbols): + symbol_mask = data_subset['Symbol'] == symbol + data_subset = data_subset.drop( + data_subset[symbol_mask].sample(frac=i / (2 * len(symbols))).index) + + # now run PAR + synthesizer = PARSynthesizer(metadata, epochs=5, verbose=True) + synthesizer.fit(data_subset) + synthetic_data = synthesizer.sample(num_sequences=5) + + # assert that the synthetic data doesn't contain NaN values in 
sequence index column + assert not pd.isna(synthetic_data['Date']).any() + + +def test_par_subset_of_data_simplified(): + """Test it when the data index is not continuous for a simple dataset GH#1973.""" + # Setup + data = pd.DataFrame({ + 'id': [1, 2, 3], + 'date': ['2020-01-01', '2020-01-02', '2020-01-03'], + }) + data.index = [0, 1, 5] + metadata = SingleTableMetadata.load_from_dict({ + 'sequence_index': 'date', + 'sequence_key': 'id', + 'columns': { + 'id': { + 'sdtype': 'id', + }, + 'date': { + 'sdtype': 'datetime', + }, + }, + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + }) + synthesizer = PARSynthesizer(metadata, epochs=0) + + # Run + synthesizer.fit(data) + synthetic_data = synthesizer.sample(num_sequences=50) + + # Assert + assert not pd.isna(synthetic_data['date']).any() From ee9a8bb78d5aca75f05b53577042620dd06b619e Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Wed, 15 May 2024 14:46:52 -0500 Subject: [PATCH 32/32] Updating release notes --- HISTORY.md | 1 + 1 file changed, 1 insertion(+) diff --git a/HISTORY.md b/HISTORY.md index 7e376891e..a08be403d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -25,6 +25,7 @@ Finally, this release fixes some bugs that prevented synthesizers from working w * Metadata detection crashes when the column names are integers (`AttributeError: 'int' object has no attribute 'lower'`) - Issue [#1933](https://github.com/sdv-dev/SDV/issues/1933) by @lajohn4747 * Synthesizers crash when column names are integers (`TypeError: unsupported operand`) - Issue [#1935](https://github.com/sdv-dev/SDV/issues/1935) by @lajohn4747 * Switch parameter order in drop_unknown_references - Issue [#1944](https://github.com/sdv-dev/SDV/issues/1944) by @R-Palazzo +* Unexpected NaN values in sequence_index when dataframe isn't reset - Issue [#1973](https://github.com/sdv-dev/SDV/issues/1973) by @fealho * Fix pandas DtypeWarning in download_demo - Issue [#1980](https://github.com/sdv-dev/SDV/issues/1980) by @fealho ### Maintenance
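The failure mode behind issue #1973 above is ordinary pandas index alignment: a Series produced with a freshly reset index, when assigned back into a DataFrame that kept a non-continuous index, is matched by label rather than by position, so unmatched rows come back as NaN. That is why the `par.py` hunk above assigns `.to_numpy()` instead of the Series itself. A standalone sketch of the pitfall, in plain pandas rather than SDV code:

```python
import pandas as pd

# A frame whose index is no longer a clean 0..n-1 range, e.g. after dropping rows.
data = pd.DataFrame({'value': [1.0, 2.0, 3.0]}, index=[0, 1, 5])

# A derived column that came back with a reset index [0, 1, 2].
transformed = pd.Series([10.0, 20.0, 30.0])

# Label-aligned assignment: index 5 has no counterpart, so that row becomes NaN.
data['aligned'] = transformed
print(data['aligned'].tolist())      # [10.0, 20.0, nan]

# Assigning the raw values matches by position and keeps every row populated.
data['positional'] = transformed.to_numpy()
print(data['positional'].tolist())   # [10.0, 20.0, 30.0]
```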