Skip to content

Commit

Permalink
Print scores after every progress bar (#540)
Browse files (browse the repository at this point in the history)
  • Loading branch information
lajohn4747 authored Mar 26, 2024
1 parent 1844a99 commit 5fe4551
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
21 changes: 10 additions & 11 deletions sdmetrics/reports/base_report.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -125,15 +125,8 @@ def _print_results(self, verbose):
"""
if verbose:
sys.stdout.write(
- f'\nOverall Score: {round(self._overall_score * 100, 2)}%\n\n'
+ f'Overall Score (Average): {round(self._overall_score * 100, 2)}%\n\n'
)
- sys.stdout.write('Properties:\n')
-
- for property_name, property_instance in self._properties.items():
- property_score = round(property_instance._compute_average() * 100, 2)
- sys.stdout.write(
- f'- {property_name}: {property_score}%\n'
- )

def generate(self, real_data, synthetic_data, metadata, verbose=True):
"""Generate report.
Expand Down Expand Up @@ -173,15 +166,19 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
scores = []
progress_bar = None
if verbose:
- sys.stdout.write('Generating report ...\n')
+ sys.stdout.write('Generating report ...\n\n')

start_time = time.time()
for ind, (property_name, property_instance) in enumerate(self._properties.items()):
if verbose:
num_iterations = int(property_instance._get_num_iterations(metadata))
- progress_bar = tqdm.tqdm(total=num_iterations, file=sys.stdout)
+ progress_bar = tqdm.tqdm(
+ total=num_iterations,
+ file=sys.stdout,
+ bar_format='{desc}|{bar}{r_bar}|'
+ )
progress_bar.set_description(
- f'({ind + 1}/{len(self._properties)}) Evaluating {property_name}: '
+ f'({ind + 1}/{len(self._properties)}) Evaluating {property_name}'
)

score = self._properties[property_name].get_score(
Expand All @@ -190,6 +187,8 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
scores.append(score)
if verbose:
progress_bar.close()
+ sys.stdout.write(f'{property_name} Score: {round(score * 100, 2)}%\n\n')
+ sys.stdout.flush()

self._overall_score = np.nanmean(scores)
self.is_generated = True
Expand Down
20 changes: 14 additions & 6 deletions tests/unit/reports/test_base_report.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -277,11 +277,9 @@ def test_print_results_verbose_true(self, mock_write):

# Assert
calls = [
- call('\nOverall Score: 50.0%\n\n'),
- call('Properties:\n'),
- call('- Column Shapes: 60.0%\n'),
- call('- Column Pair Trends: 40.0%\n'),
+ call('Overall Score (Average): 50.0%\n\n')
]
+ assert mock_write.call_count == 1
mock_write.assert_has_calls(calls, any_order=True)

@patch('sys.stdout.write')
Expand Down Expand Up @@ -450,8 +448,9 @@ def test_generate_multi_table_details(self, version_mock, time_mock, datetime_mo
}
assert base_report.report_info == expected_info

+ @patch('sys.stdout.write')
@patch('tqdm.tqdm')
- def test_generate_verbose(self, mock_tqdm):
+ def test_generate_verbose(self, mock_tqdm, mock_write):
"""Test the ``generate`` method with verbose=True."""
# Setup
base_report = BaseReport()
Expand Down Expand Up @@ -493,7 +492,16 @@ def test_generate_verbose(self, mock_tqdm):
base_report.generate(real_data, synthetic_data, metadata, verbose=True)

# Assert
- calls = [call(total=4, file=sys.stdout), call(total=6, file=sys.stdout)]
+ write_calls = [
+ call('Property 1 Score: 100.0%\n\n'),
+ call('Property 2 Score: 100.0%\n\n'),
+ ]
+ mock_write.assert_has_calls(write_calls, any_order=True)
+
+ calls = [
+ call(total=4, bar_format='{desc}|{bar}{r_bar}|', file=sys.stdout),
+ call(total=6, bar_format='{desc}|{bar}{r_bar}|', file=sys.stdout)
+ ]
mock_tqdm.assert_has_calls(calls, any_order=True)
base_report._print_results.assert_called_once_with(True)

Expand Down

0 comments on commit 5fe4551

Please sign in to comment.