diff --git a/HISTORY.md b/HISTORY.md index 89257a1b..e29bd7bb 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,13 @@ # History +## v0.9.2 - 2023-03-07 +This release fixes bugs in the `NewRowSynthesis` metric when too many columns were present. It also fixes bugs around datetime columns that are formatted as strings in both `get_column_pair_plot` and `get_column_plot`. + +### Bug Fixes +* Method get_column_pair_plot: Does not plot synthetic data if datetime column is formatted as a string - Issue [#310](https://github.com/sdv-dev/SDMetrics/issues/310) by @frances-h +* Method get_column_plot: ValueError if a datetime column is formatted as a string - Issue [#309](https://github.com/sdv-dev/SDMetrics/issues/309) by @frances-h +* Fix ValueError in the NewRowSynthesis metric (also impacts DiagnosticReport) - Issue [#307](https://github.com/sdv-dev/SDMetrics/issues/307) by @frances-h + ## v0.9.1 - 2023-02-17 This release fixes bugs in the existing metrics and reports. diff --git a/conda/meta.yaml b/conda/meta.yaml index 066326ce..21914b12 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -1,4 +1,4 @@ -{% set version = '0.9.1' %} +{% set version = '0.9.2.dev1' %} package: name: "{{ name|lower }}" diff --git a/sdmetrics/__init__.py b/sdmetrics/__init__.py index a3ff1f2a..36ae21ce 100644 --- a/sdmetrics/__init__.py +++ b/sdmetrics/__init__.py @@ -4,7 +4,7 @@ __author__ = 'MIT Data To AI Lab' __email__ = 'dailabmit@gmail.com' -__version__ = '0.9.1' +__version__ = '0.9.2.dev1' import pandas as pd diff --git a/sdmetrics/reports/utils.py b/sdmetrics/reports/utils.py index ddaa898f..947f7b0c 100644 --- a/sdmetrics/reports/utils.py +++ b/sdmetrics/reports/utils.py @@ -65,6 +65,49 @@ VALID_SDTYPES = ['numerical', 'categorical', 'boolean', 'datetime'] +def convert_to_datetime(column_data, datetime_format=None): + """Convert a column data to pandas datetime. + + Args: + column_data (pandas.Series): + The column data + datetime_format (str): + Optional string format of datetime. 
If ``None``, will attempt to infer the datetime + format from the column data. Defaults to ``None``. + + Returns: + pandas.Series: + The converted column data. + """ + if is_datetime(column_data): + return column_data + + if datetime_format is None: + datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy()) + + return pd.to_datetime(column_data, format=datetime_format) + + +def convert_datetime_columns(real_column, synthetic_column, col_metadata): + """Convert a real and a synthetic column to pandas datetime. + + Args: + real_column (pandas.Series): + The real column data + synthetic_column (pandas.Series): + The synthetic column data + col_metadata: + The metadata associated with the column + + Returns: + (pandas.Series, pandas.Series): + The converted real and synthetic column data. + """ + datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format') + return (convert_to_datetime(real_column, datetime_format), + convert_to_datetime(synthetic_column, datetime_format)) + + def make_discrete_column_plot(real_column, synthetic_column, sdtype): """Plot the real and synthetic data for a categorical or boolean column. 
@@ -239,9 +282,17 @@ def get_column_plot(real_data, synthetic_data, column_name, metadata): if column_name not in synthetic_data.columns: raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + column_meta = columns[column_name] sdtype = get_type_from_column_meta(columns[column_name]) - real_column = real_data[column_name] - synthetic_column = synthetic_data[column_name] + if sdtype == 'datetime': + real_column, synthetic_column = convert_datetime_columns( + real_data[column_name], + synthetic_data[column_name], + column_meta + ) + else: + real_column = real_data[column_name] + synthetic_column = synthetic_data[column_name] if sdtype in CONTINUOUS_SDTYPES: fig = make_continuous_column_plot(real_column, synthetic_column, sdtype) elif sdtype in DISCRETE_SDTYPES: @@ -252,24 +303,6 @@ def get_column_plot(real_data, synthetic_data, column_name, metadata): return fig -def convert_to_datetime(column_data): - """Convert a column data to pandas datetime. - - Args: - column_data (pandas.Series): - The column data - - Returns: - pandas.Series: - The converted column data. - """ - if is_datetime(column_data): - return column_data - - dt_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy()) - return pd.to_datetime(column_data, format=dt_format) - - def make_continuous_column_pair_plot(real_data, synthetic_data): """Make a column pair plot for continuous data. 
@@ -417,9 +450,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, metadata): raise ValueError(f"Column(s) `{'`, `'.join(invalid_columns)}` not found " 'in the synthetic table data.') + col_meta = (all_columns[column_names[0]], all_columns[column_names[1]]) sdtypes = ( - get_type_from_column_meta(all_columns[column_names[0]]), - get_type_from_column_meta(all_columns[column_names[1]]), + get_type_from_column_meta(col_meta[0]), + get_type_from_column_meta(col_meta[1]), ) real_data = real_data[column_names] synthetic_data = synthetic_data[column_names] @@ -432,11 +466,13 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, metadata): if all([t in DISCRETE_SDTYPES for t in sdtypes]): return make_discrete_column_pair_plot(real_data, synthetic_data) - if sdtypes[0] == 'datetime': - real_data.iloc[:, 0] = convert_to_datetime(real_data.iloc[:, 0]) - if sdtypes[1] == 'datetime': - real_data.iloc[:, 1] = convert_to_datetime(real_data.iloc[:, 1]) - + for i, sdtype in enumerate(sdtypes): + if sdtype == 'datetime': + real_data.iloc[:, i], synthetic_data.iloc[:, i] = convert_datetime_columns( + real_data.iloc[:, i], + synthetic_data.iloc[:, i], + col_meta[i] + ) if all([t in CONTINUOUS_SDTYPES for t in sdtypes]): return make_continuous_column_pair_plot(real_data, synthetic_data) else: diff --git a/sdmetrics/single_table/new_row_synthesis.py b/sdmetrics/single_table/new_row_synthesis.py index d4e55f19..c6d0b955 100644 --- a/sdmetrics/single_table/new_row_synthesis.py +++ b/sdmetrics/single_table/new_row_synthesis.py @@ -108,8 +108,11 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata=None, row_filter.append(field_filter) + engine = None + if len(row_filter) >= 32: # Limit set by NPY_MAXARGS + engine = 'python' try: - matches = real_data.query(' and '.join(row_filter)) + matches = real_data.query(' and '.join(row_filter), engine=engine) except TypeError: if len(real_data) > 10000: warnings.warn('Unable to optimize query. 
For better formance, set the ' diff --git a/setup.cfg b/setup.cfg index 898f4048..20964a1f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.9.1 +current_version = 0.9.2.dev1 commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))? diff --git a/setup.py b/setup.py index d564427e..7e25b80a 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/sdv-dev/SDMetrics', - version='0.9.1', + version='0.9.2.dev1', zip_safe=False, ) diff --git a/tests/unit/reports/test_utils.py b/tests/unit/reports/test_utils.py index 5de4bf7b..0d9a2368 100644 --- a/tests/unit/reports/test_utils.py +++ b/tests/unit/reports/test_utils.py @@ -223,6 +223,42 @@ def test_get_column_plot_discrete_col(make_plot_mock): assert out == make_plot_mock.return_value +@patch('sdmetrics.reports.utils.make_continuous_column_plot') +def test_get_column_plot_datetime_col(make_plot_mock): + """Test the ``get_column_plot`` method with a string datetime column.""" + # Setup + sdtype = 'datetime' + datetime_format = '%Y-%m-%d' + real_datetimes = [ + datetime(2020, 10, 1), + datetime(2020, 11, 1), + datetime(2020, 12, 1), + ] + real_data = pd.DataFrame({ + 'col1': [dt.strftime(datetime_format) for dt in real_datetimes] + }) + real_expected = pd.DataFrame({'col1': real_datetimes}) + synthetic_datetimes = [ + datetime(2021, 10, 1), + datetime(2021, 11, 1), + datetime(2021, 12, 3), + ] + synthetic_data = pd.DataFrame({ + 'col1': [dt.strftime(datetime_format) for dt in synthetic_datetimes] + }) + synthetic_expected = pd.DataFrame({'col1': synthetic_datetimes}) + metadata = {'fields': {'col1': {'type': sdtype, 'format': datetime_format}}} + + # Run + out = get_column_plot(real_data, synthetic_data, 'col1', metadata) + + # Assert + make_plot_mock.assert_called_once_with(SeriesMatcher(real_expected['col1']), + SeriesMatcher(synthetic_expected['col1']), + sdtype) + assert out == 
make_plot_mock.return_value + + def test_get_column_plot_invalid_sdtype(): """Test the ``get_column_plot`` method with an invalid sdtype. @@ -376,6 +412,23 @@ def test_convert_to_datetime_date_column(): pd.testing.assert_series_equal(out, expected) +def test_convert_to_datetime_str_format(): + """Test the ``convert_to_datetime`` method with a string column. + + Expect the string date column to be converted to a datetime column + using a datetime format inferred from the data. + """ + # Setup + column_data = pd.Series(['2020-01-02', '2021-01-02']) + + # Run + out = convert_to_datetime(column_data) + + # Assert + expected = pd.Series([datetime(2020, 1, 2), datetime(2021, 1, 2)]) + pd.testing.assert_series_equal(out, expected) + + @patch('sdmetrics.reports.utils.px') def test_make_continuous_column_pair_plot(px_mock): """Test the ``make_continuous_column_pair_plot`` method. @@ -641,6 +694,60 @@ def test_get_column_pair_plot_discrete_columns(make_plot_mock): assert out == make_plot_mock.return_value +@patch('sdmetrics.reports.utils.make_mixed_column_pair_plot') +def test_get_column_pair_plot_str_datetimes(make_plot_mock): + """Test the ``get_column_pair_plot`` method with string datetime columns. + + Expect that the string datetime columns are converted to datetimes. 
+ """ + # Setup + dt_format = '%Y-%m-%d' + real_datetimes = [ + datetime(2020, 10, 1), + datetime(2020, 11, 1), + datetime(2020, 12, 1), + ] + real_data = pd.DataFrame({ + 'col1': [1, 2, 3], + 'col2': [dt.strftime(dt_format) for dt in real_datetimes], + }) + real_expected = pd.DataFrame({ + 'col1': [1, 2, 3], + 'col2': real_datetimes, + }) + + synthetic_datetimes = [ + datetime(2021, 10, 1), + datetime(2021, 11, 1), + datetime(2021, 12, 3), + ] + synthetic_data = pd.DataFrame({ + 'col1': [2, 2, 3], + 'col2': [dt.strftime(dt_format) for dt in synthetic_datetimes], + }) + synthetic_expected = pd.DataFrame({ + 'col1': [2, 2, 3], + 'col2': synthetic_datetimes, + }) + columns = ['col1', 'col2'] + metadata = { + 'fields': { + 'col1': {'type': 'categorical'}, + 'col2': {'type': 'datetime', 'format': dt_format} + } + } + + # Run + out = get_column_pair_plot(real_data, synthetic_data, columns, metadata) + + # Assert + make_plot_mock.assert_called_once_with( + DataFrameMatcher(real_expected[columns]), + DataFrameMatcher(synthetic_expected[columns]), + ) + assert out == make_plot_mock.return_value + + def test_get_column_pair_plot_invalid_sdtype(): """Test the ``get_column_plot_pair`` method with an invalid sdtype. diff --git a/tests/unit/single_table/test_new_row_synthesis.py b/tests/unit/single_table/test_new_row_synthesis.py index 83e3d301..4a75e9bb 100644 --- a/tests/unit/single_table/test_new_row_synthesis.py +++ b/tests/unit/single_table/test_new_row_synthesis.py @@ -144,6 +144,32 @@ def test_compute_with_sample_size_too_large(self, warnings_mock): 'synthetic data rows (5). Proceeding without sampling.' ) + def test_compute_with_many_columns(self): + """Test the ``compute`` method with more than 32 columns. + + Expect that the new row synthesis is returned. 
+ """ + # Setup + num_cols = 32 + real_data = pd.DataFrame({ + f'col{i}': list(np.random.uniform(low=0, high=10, size=100)) for i in range(num_cols) + }) + synthetic_data = pd.DataFrame({ + f'col{i}': list(np.random.uniform(low=0, high=10, size=100)) for i in range(num_cols) + }) + metadata = { + 'fields': { + f'col{i}': {'type': 'numerical', 'subtype': 'float'} for i in range(num_cols) + }, + } + metric = NewRowSynthesis() + + # Run + score = metric.compute(real_data, synthetic_data, metadata) + + # Assert + assert score == 1 + @patch('sdmetrics.single_table.new_row_synthesis.SingleTableMetric.normalize') def test_normalize(self, normalize_mock): """Test the ``normalize`` method.