From e7c4626f3490087552ab9b01e6daa5b144658350 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Wed, 15 May 2024 12:40:29 -0700 Subject: [PATCH] Remove unexpected NaN values in sequence_index (#2003) --- sdv/multi_table/hma.py | 2 +- sdv/multi_table/utils.py | 4 +- sdv/sequential/par.py | 2 +- tests/integration/sequential/test_par.py | 57 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index a769c3e13..dd142b420 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -114,7 +114,7 @@ def _estimate_columns_traversal(cls, metadata, table_name, columns_per_table[table_name] += \ cls._get_num_extended_columns( metadata, child_name, table_name, columns_per_table, distributions - ) + ) visited.add(table_name) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index a44574157..3f6fd9526 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -47,10 +47,10 @@ def _get_n_order_descendants(relationships, parent_table, order): descendants = {} order_1_descendants = _get_relationships_for_parent(relationships, parent_table) descendants['order_1'] = [rel['child_table_name'] for rel in order_1_descendants] - for i in range(2, order+1): + for i in range(2, order + 1): descendants[f'order_{i}'] = [] prov_descendants = [] - for child_table in descendants[f'order_{i-1}']: + for child_table in descendants[f'order_{i - 1}']: order_i_descendants = _get_relationships_for_parent(relationships, child_table) prov_descendants.extend([rel['child_table_name'] for rel in order_i_descendants]) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 4c7a80a36..859085de0 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -178,7 +178,7 @@ def _transform_sequence_index(self, data): fill_value = min(sequence_index_sequence[self._sequence_index].dropna()) sequence_index_sequence = sequence_index_sequence.fillna(fill_value) - data[self._sequence_index] = sequence_index_sequence[self._sequence_index] + data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy() data = data.merge( sequence_index_context, left_on=self._sequence_key, diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 3b442e1b5..f46da7ba5 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -1,5 +1,6 @@ import datetime +import numpy as np import pandas as pd from deepecho import load_demo @@ -192,3 +193,59 @@ def test_sythesize_sequences(tmp_path): synthesizer.validate(loaded_sample) loaded_synthesizer.validate(synthetic_data) loaded_synthesizer.validate(loaded_sample) + + +def test_par_subset_of_data(): + """Test it when the data index is not continuous GH#1973.""" + # download data + data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019',) + + # modify the data by choosing a subset of it + data_subset = data.copy() + np.random.seed(1234) + symbols = data['Symbol'].unique() + + # only select a subset of data in each sequence + for i, symbol in enumerate(symbols): + symbol_mask = data_subset['Symbol'] == symbol + data_subset = data_subset.drop( + data_subset[symbol_mask].sample(frac=i / (2 * len(symbols))).index) + + # now run PAR + synthesizer = PARSynthesizer(metadata, epochs=5, verbose=True) + synthesizer.fit(data_subset) + synthetic_data = synthesizer.sample(num_sequences=5) + + # assert that the synthetic data doesn't contain NaN values in sequence index column + assert not pd.isna(synthetic_data['Date']).any() + + +def test_par_subset_of_data_simplified(): + """Test it when the data index is not continuous for a simple dataset GH#1973.""" + # Setup + data = pd.DataFrame({ + 'id': [1, 2, 3], + 'date': ['2020-01-01', '2020-01-02', '2020-01-03'], + }) + data.index = [0, 1, 5] + metadata = SingleTableMetadata.load_from_dict({ + 'sequence_index': 'date', + 'sequence_key': 'id', + 'columns': { + 'id': { + 'sdtype': 'id', + }, + 'date': { + 'sdtype': 'datetime', + }, + }, + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + }) + synthesizer = PARSynthesizer(metadata, epochs=0) + + # Run + synthesizer.fit(data) + synthetic_data = synthesizer.sample(num_sequences=50) + + # Assert + assert not pd.isna(synthetic_data['date']).any()