Skip to content

Commit

Permalink
Some intermediate work for test_plot_g_vecs
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Apr 3, 2024
1 parent 4610175 commit 460f3e8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
6 changes: 3 additions & 3 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
t_k0 = list(np.array(t_k0) * dt) # units: ps
t_roundtrip = list(np.array(t_roundtrip) * dt) # units: ps
if len(t_0k) + len(t_k0) + len(t_roundtrip) > 0: # i.e. not all are empty
if np.max([t_0k, t_k0, t_roundtrip]) > t_max:

if np.max(list(chain.from_iterable([t_0k, t_k0, t_roundtrip]))) > t_max:
t_max = np.max([t_0k, t_k0, t_roundtrip])

if t_max >= 10000:
Expand Down Expand Up @@ -745,8 +746,7 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
for i in range(len(t_list)): # t_list[i] is the list for trajectory i
plt.plot(np.arange(len(t_list[i])) + 1, t_list[i], label=f'Trajectory {i}', marker=marker)

flattened_t_list = list(chain.from_iterable(t_list))
if np.max(flattened_t_list) >= 10000:
if np.max(list(chain.from_iterable(t_list))) >= 10000:
plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
plt.xlabel('Event index')
plt.ylabel(f'{y_labels[t]}')
Expand Down
34 changes: 32 additions & 2 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,38 @@ def test_plot_transit_time(mock_plt):
mock_plt.savefig.assert_not_called()


def test_plot_g_vecs():
pass
@patch('ensemble_md.analysis.analyze_traj.plt')
def test_plot_g_vecs(mock_plt):
cmap = mock_plt.cm.ocean
mock_ax = MagicMock()
mock_plt.gca.return_value = mock_ax

# Case 1: Short g_vecs with refs and with plot_rmse = True
g_vecs = np.array([[0, 10, 20, 30], [0, 8, 18, 28]])
refs = np.array([0, 8, 18, 28])
refs_err = np.array([0.1, 0.1, 0.1, 0.1])

analyze_traj.plot_g_vecs(g_vecs, refs, refs_err, plot_rmse=True)

mock_plt.figure.assert_called()
mock_plt.plot.assert_called()
mock_plt.xlabel.assert_called_with('Iteration index')
# mock_plt.ylabel.assert_called_any('Alchemical weight (kT)')
mock_plt.xlim.assert_called()
mock_plt.grid.assert_called()
mock_plt.legend.assert_called_with(loc='center left', bbox_to_anchor=(1, 0.2))

assert mock_plt.figure.call_count == 2
assert mock_plt.plot.call_count == 4
assert mock_plt.axhline.call_count == 3
assert mock_plt.fill_between.call_count == 3
assert mock_plt.grid.call_count == 2

assert mock_plt.ylabel.call_args_list[0][0] == ('Alchemical weight (kT)',)
assert mock_plt.ylabel.call_args_list[1][0] == ('RMSE in the alchemical weights (kT)',)


# Case 2: Long g_vecs


def test_get_swaps():
Expand Down

0 comments on commit 460f3e8

Please sign in to comment.