diff --git a/tests/integration/reports/single_table/_properties/test_column_pair_trends.py b/tests/integration/reports/single_table/_properties/test_column_pair_trends.py index 5cce4445..ef6bd116 100644 --- a/tests/integration/reports/single_table/_properties/test_column_pair_trends.py +++ b/tests/integration/reports/single_table/_properties/test_column_pair_trends.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +from tests.utils import get_error_type from sdmetrics.demos import load_demo from sdmetrics.reports.single_table._properties.column_pair_trends import ColumnPairTrends @@ -76,12 +77,6 @@ def test_get_score_warnings(self, recwarn): real_data['second_perc'].iloc[2] = 'a' - def get_error_type(error): - if error is not None: - colon_index = error.find(':') - return error[:colon_index] - return None - # Run column_pair_trends = ColumnPairTrends() diff --git a/tests/integration/reports/single_table/test_quality_report.py b/tests/integration/reports/single_table/test_quality_report.py index ca5efb55..50891413 100644 --- a/tests/integration/reports/single_table/test_quality_report.py +++ b/tests/integration/reports/single_table/test_quality_report.py @@ -6,6 +6,7 @@ from sdmetrics.demos import load_demo from sdmetrics.reports.single_table import QualityReport +from tests.utils import get_error_type class TestQualityReport: @@ -250,12 +251,6 @@ def test_report_end_to_end_with_errors(self): real_data['second_perc'].iloc[2] = 'a' - def get_error_type(error): - if error is not None: - colon_index = error.find(':') - return error[:colon_index] - return None - report = QualityReport() # Run @@ -346,12 +341,6 @@ def test_report_with_column_nan(self): metadata['columns']['nan_column'] = {'sdtype': 'numerical'} column_names.append('nan_column') - def get_error_type(error): - if error is not None: - colon_index = error.find(':') - return error[:colon_index] - return None - report = QualityReport() # Run diff --git a/tests/utils.py b/tests/utils.py index fa583a27..a006ef43 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -40,3 +40,10 @@ def __eq__(self, other): """Assert equality by expanding the iterator.""" assert all(x == y for x, y in zip(self.iterator, other)) return True + + +def get_error_type(error): + if error is not None: + colon_index = error.find(':') + return error[:colon_index] + return None