Skip to content

Commit

Permalink
Some attempts of using assert_plt_calls to test multiple calls
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Mar 28, 2024
1 parent 6d10c6d commit b557919
Showing 1 changed file with 85 additions and 10 deletions.
95 changes: 85 additions & 10 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,39 @@
"""
import os
import numpy as np
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from ensemble_md.analysis import analyze_traj

current_path = os.path.dirname(os.path.abspath(__file__))
input_path = os.path.join(current_path, "data")


def assert_plt_calls(mock_plt, call_specs):
"""
Assert calls to matplotlib functions with specified parameters.
Parameters
----------
mock_plt : :code:`MagicMock` object
Mock object of :code:`matplotlib.pyplot`.
call_specs : list
A list of lists that contains the following four elements:
- The name of the matplotlib function (as :code:`str`) that was called.
- The assert method (as :code:`str`), e.g., :code:`assert_called_once_with`.
- The positional arguments (as :code:`tuple`) passed to the matplotlib function.
- The keyword arguments (as :code:`dict`) passed to the matplotlib function.
"""
for call_spec in call_specs:
plt_func = call_spec[0]
assert_method = call_spec[1]
plt_args = call_spec[2]
plt_kwargs = call_spec[3]

mock_func = getattr(mock_plt, plt_func)
assert_func = getattr(mock_func, assert_method)
assert_func(*plt_args, **plt_kwargs) # call the assertion method


def test_extract_state_traj():
traj, t = analyze_traj.extract_state_traj(os.path.join(input_path, 'dhdl/dhdl_0.xvg'))
state_list = [
Expand Down Expand Up @@ -115,13 +141,28 @@ def test_plot_rep_trajs(mock_plt):
y_input_3 = np.array([2, 0, 1, 0, 2])

# Verify that the expected matplotlib functions were called
mock_plt.figure.assert_called_once()
mock_plt.plot.assert_called()
mock_plt.xlabel.assert_called_with('MC moves')
mock_plt.ylabel.assert_called_with('Replica')
mock_plt.grid.assert_called_once()
mock_plt.legend.assert_called_once()
mock_plt.savefig.assert_called_once_with(fig_name, dpi=600)
# mock_plt.figure.assert_called_once()
# mock_plt.plot.assert_called()
# mock_plt.xlabel.assert_called_with('MC moves')
# mock_plt.ylabel.assert_called_with('Replica')
# mock_plt.grid.assert_called_once()
# mock_plt.legend.assert_called_once()
# mock_plt.savefig.assert_called_once_with(fig_name, dpi=600)

# Using assert_plt_calls, the lines above can be written as below
call_specs = [
['figure', 'assert_called_once', (), {}],
['plot', 'assert_called', (), {}],
['xlabel', 'assert_called_with', ('MC moves',), {}],
['ylabel', 'assert_called_with', ('Replica',), {}],
['grid', 'assert_called_once', (), {}],
['legend', 'assert_called_once', (), {}],
['savefig', 'assert_called_once_with', (fig_name,), {'dpi': 600}]
]
assert_plt_calls(mock_plt, call_specs)



assert mock_plt.plot.call_count == len(trajs)

# mock_plt.plot.assert_any_call(x_input, y_input_1, color=colors[0], label='Trajectory 0')
Expand Down Expand Up @@ -196,8 +237,42 @@ def test_plot_rep_trajs(mock_plt):
assert mock_plt.plot.call_args_list[2][1] == {'color': colors[2], 'label': 'Trajectory 2'}


def test_plot_state_trajs():
pass
@patch('ensemble_md.analysis.analyze_traj.plt')
def test_plot_state_trajs(mock_plt):
state_ranges = [[0, 1, 2, 3], [2, 3, 4, 5]]
fig_name = 'ensemble_md/tests/data/test.png'
cmap = mock_plt.cm.ocean
n_sim = len(state_ranges)
colors = [cmap(i) for i in np.arange(n_sim) / n_sim]

# Mock the return value of plt.subplots to return a tuple of two mock objects
# We need this because plot_state_trajs calls _, ax = plt.subplots(...). When we mock
# matplolib.pyplot using mock_plt, plt.subplots will be replaced by mock_plt.subplots
# and will return a mock object, not the tuple of figure and axes objects that the real plt.subplots returns.
# This would in turn lead to an ValueError. To avoid this, we need to mock the return values of plt.subplots.
mock_figure = MagicMock()
mock_axes = MagicMock()
mock_plt.subplots.return_value = (mock_figure, mock_axes)

# Case 1: Short trajs without dt and stride
trajs = np.array([[0, 1, 0, 2, 3, 4, 3, 4, 5, 4], [2, 3, 4, 5, 4, 3, 2, 1, 0, 1]], dtype=int)

analyze_traj.plot_state_trajs(trajs, state_ranges, fig_name)

x_input = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

mock_plt.subplots.assert_called_once_with(nrows=1, ncols=2, figsize=(5, 2.5))
mock_plt.subplot.assert_called()
mock_plt.plot.assert_called()
mock_plt.fill_between.assert_called()
mock_plt.xlabel.assert_called_with('MC moves')
mock_plt.ylabel.assert_called_with('State')
mock_plt.grid.assert_called()

assert mock_plt.subplot.call_count == len(state_ranges)
assert mock_plt.plot.call_count == len(state_ranges)
assert mock_plt.grid.call_count == len(state_ranges)
assert mock_plt.fill_between.call_count == len(state_ranges) ** 2


def test_plot_state_hist():
Expand Down

0 comments on commit b557919

Please sign in to comment.