Skip to content

Commit

Permalink
Move out util function
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jun 26, 2024
1 parent ee577ec commit 46b8db0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from tests.utils import get_error_type

from sdmetrics.demos import load_demo
from sdmetrics.reports.single_table._properties.column_pair_trends import ColumnPairTrends
Expand Down Expand Up @@ -76,12 +77,6 @@ def test_get_score_warnings(self, recwarn):

real_data['second_perc'].iloc[2] = 'a'

def get_error_type(error):
if error is not None:
colon_index = error.find(':')
return error[:colon_index]
return None

# Run
column_pair_trends = ColumnPairTrends()

Expand Down
13 changes: 1 addition & 12 deletions tests/integration/reports/single_table/test_quality_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sdmetrics.demos import load_demo
from sdmetrics.reports.single_table import QualityReport
from tests.utils import get_error_type


class TestQualityReport:
Expand Down Expand Up @@ -250,12 +251,6 @@ def test_report_end_to_end_with_errors(self):

real_data['second_perc'].iloc[2] = 'a'

def get_error_type(error):
if error is not None:
colon_index = error.find(':')
return error[:colon_index]
return None

report = QualityReport()

# Run
Expand Down Expand Up @@ -346,12 +341,6 @@ def test_report_with_column_nan(self):
metadata['columns']['nan_column'] = {'sdtype': 'numerical'}
column_names.append('nan_column')

def get_error_type(error):
if error is not None:
colon_index = error.find(':')
return error[:colon_index]
return None

report = QualityReport()

# Run
Expand Down
7 changes: 7 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ def __eq__(self, other):
"""Assert equality by expanding the iterator."""
assert all(x == y for x, y in zip(self.iterator, other))
return True


def get_error_type(error):
if error is not None:
colon_index = error.find(':')
return error[:colon_index]
return None

0 comments on commit 46b8db0

Please sign in to comment.