Skip to content

Commit

Permalink
Added a test for plot_state_hist
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Mar 30, 2024
1 parent fa205ac commit fcb8787
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
hist_data = []
lower_bound, upper_bound = -0.5, n_states - 0.5
for traj in trajs:
# bins for different traj in trajs should be the same
hist, bins = np.histogram(traj, bins=np.arange(lower_bound, upper_bound + 1, 1))
hist_data.append(hist)
if save_hist is True:
Expand All @@ -510,6 +511,7 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
y_max = 0
for i in range(n_configs):
max_count = np.max(bottom + hist_data[i])
print(max_count)
if max_count > y_max:
y_max = max_count
plt.bar(
Expand Down
121 changes: 119 additions & 2 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def test_stitch_time_series_for_sim():
# [0, 0, 3, 1, 4, 4, 5, 4, 5, 5, 4]
# ]

# Clean up
shutil.rmtree('ensemble_md/tests/data/stitch_test')


def test_stitch_trajs():
pass
Expand Down Expand Up @@ -318,8 +321,122 @@ def test_plot_state_trajs(mock_plt):
assert mock_plt.plot.call_args_list[1][1] == {'color': colors[1], 'linewidth': 0.01}


def test_plot_state_hist():
pass
@patch('ensemble_md.analysis.analyze_traj.plt')
def test_plot_state_hist(mock_plt):
fig_name = 'ensemble_md/tests/data/test.png'
state_ranges = [[0, 1, 2, 3], [2, 3, 4, 5]]
trajs = np.array([[0, 1, 0, 2, 3, 4, 3, 4, 5, 4], [2, 3, 4, 5, 4, 3, 2, 1, 0, 1]], dtype=int)
cmap = mock_plt.cm.ocean
mock_fig = MagicMock()
mock_plt.figure.return_value = mock_fig

n_configs = 2
colors = [cmap(i) for i in np.arange(n_configs) / n_configs]
hist_data = np.array([[2, 1, 1, 2, 3, 1], [1, 2, 2, 2, 2, 1]])

# Case 1: Default settings
analyze_traj.plot_state_hist(trajs, state_ranges, fig_name)

mock_plt.figure.assert_called_once_with(figsize=(6.4, 4.8))
mock_fig.add_subplot.assert_called_once_with(111)
mock_plt.xticks.assert_called_once_with(range(6))
mock_plt.xlim.assert_called_once_with([-0.5, 5.5])
mock_plt.ylim.assert_called_once_with([0, 5.25]) # y_max = (2 + 3) * 1.05
mock_plt.xlabel.assert_called_once_with('State index')
mock_plt.ylabel.assert_called_once_with('Count')
mock_plt.grid.assert_called_once()
mock_plt.legend.assert_called_once()
mock_plt.tight_layout.assert_called_once()
mock_plt.savefig.assert_called_once_with(fig_name, dpi=600)

assert mock_plt.bar.call_count == n_configs
assert mock_plt.fill_betweenx.call_count == n_configs
assert mock_plt.fill_betweenx.call_args_list[0] == (([0, 5.25],), {'x1': 3.5, 'x2': -1.0, 'color': colors[0], 'alpha': 0.1, 'zorder': 0}) # noqa: E501
assert mock_plt.fill_betweenx.call_args_list[1] == (([0, 5.25],), {'x1': 6.0, 'x2': 1.5, 'color': colors[1], 'alpha': 0.1, 'zorder': 0}) # noqa: E501
assert mock_plt.bar.call_args_list[0][0][0] == range(6)
np.testing.assert_array_equal(mock_plt.bar.call_args_list[0][0][1], hist_data[0])
assert mock_plt.bar.call_args_list[1][0][0] == range(6)
np.testing.assert_array_equal(mock_plt.bar.call_args_list[1][0][1], hist_data[1])
assert mock_plt.bar.call_args_list[0][1] == {
'align': 'center',
'width': 1,
'color': colors[0],
'edgecolor': 'black',
'label': 'Trajectory 0',
'alpha': 0.5,
'bottom': [0, 0, 0, 0, 0, 0]
}
assert mock_plt.bar.call_args_list[1][1] == {
'align': 'center',
'width': 1,
'color': colors[1],
'edgecolor': 'black',
'label': 'Trajectory 1',
'alpha': 0.5,
'bottom': [2, 1, 1, 2, 3, 1]
}

# Case 2: max(trajs[-1]) > 30, in which case we can just test the figsize
trajs_ = np.random.randint(low=29, high=50, size=(2, 200))
mock_plt.reset_mock()

analyze_traj.plot_state_hist(trajs_, state_ranges, fig_name)
mock_plt.figure.assert_called_once_with(figsize=(10, 4.8))

# Case 3: subplots=True
mock_plt.reset_mock()
mock_figure = MagicMock()
mock_axes = MagicMock()
mock_plt.subplots.return_value = (mock_figure, mock_axes)

analyze_traj.plot_state_hist(trajs, state_ranges, fig_name, subplots=True)

n_rows, n_cols = 1, 2
mock_plt.figure.assert_called_once_with(figsize=(6.4, 4.8))
mock_plt.subplots.assert_called_once_with(nrows=n_rows, ncols=n_cols, figsize=(8, 3))
mock_plt.xlabel.assert_called_with('State index')
mock_plt.ylabel.assert_called_with('Count')
mock_plt.tight_layout.assert_called_once()
mock_plt.savefig.assert_called_once_with(fig_name, dpi=600)

assert mock_plt.subplot.call_count == n_configs
assert mock_plt.subplot.call_args_list[0][0] == (n_rows, n_cols, 1)
assert mock_plt.subplot.call_args_list[1][0] == (n_rows, n_cols, 2)
assert mock_plt.bar.call_count == n_configs
assert mock_plt.xticks.call_count == n_configs
assert mock_plt.xlim.call_count == n_configs
assert mock_plt.xlabel.call_count == n_configs
assert mock_plt.ylabel.call_count == n_configs
assert mock_plt.title.call_count == n_configs
assert mock_plt.grid.call_count == n_configs

assert mock_plt.xticks.call_args_list[0][0] == ([0, 1, 2, 3],)
assert mock_plt.xticks.call_args_list[1][0] == ([2, 3, 4, 5],)
assert mock_plt.xticks.call_args_list[0][1] == {'fontsize': 8}
assert mock_plt.xticks.call_args_list[1][1] == {'fontsize': 8}
assert mock_plt.xlim.call_args_list[0][0] == ([-0.5, 3.5],)
assert mock_plt.xlim.call_args_list[1][0] == ([1.5, 5.5],)
assert mock_plt.title.call_args_list[0][0] == ('Trajectory 0',)
assert mock_plt.title.call_args_list[1][0] == ('Trajectory 1',)
assert mock_plt.bar.call_args_list[0][0][0] == [0, 1, 2, 3]
assert mock_plt.bar.call_args_list[1][0][0] == [2, 3, 4, 5]
np.testing.assert_array_equal(mock_plt.bar.call_args_list[0][0][1], hist_data[0][[0, 1, 2, 3]])
np.testing.assert_array_equal(mock_plt.bar.call_args_list[1][0][1], hist_data[1][[2, 3, 4, 5]])
assert mock_plt.bar.call_args_list[0][1] == {
'align': 'center',
'width': 1,
'edgecolor': 'black',
'alpha': 0.5,
}
assert mock_plt.bar.call_args_list[1][1] == {
'align': 'center',
'width': 1,
'edgecolor': 'black',
'alpha': 0.5,
}

# Clean up
os.remove('hist_data.npy')


def test_calculate_hist_rmse():
Expand Down

0 comments on commit fcb8787

Please sign in to comment.