diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index 3e448fd9..9c7de5c1 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -694,7 +694,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): t_k0 = list(np.array(t_k0) * dt) # units: ps t_roundtrip = list(np.array(t_roundtrip) * dt) # units: ps if len(t_0k) + len(t_k0) + len(t_roundtrip) > 0: # i.e. not all are empty - if np.max([t_0k, t_k0, t_roundtrip]) > t_max: + + if np.max(list(chain.from_iterable([t_0k, t_k0, t_roundtrip]))) > t_max: t_max = np.max([t_0k, t_k0, t_roundtrip]) if t_max >= 10000: @@ -745,8 +746,7 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): for i in range(len(t_list)): # t_list[i] is the list for trajectory i plt.plot(np.arange(len(t_list[i])) + 1, t_list[i], label=f'Trajectory {i}', marker=marker) - flattened_t_list = list(chain.from_iterable(t_list)) - if np.max(flattened_t_list) >= 10000: + if np.max(list(chain.from_iterable(t_list))) >= 10000: plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) plt.xlabel('Event index') plt.ylabel(f'{y_labels[t]}') diff --git a/ensemble_md/tests/test_analyze_traj.py b/ensemble_md/tests/test_analyze_traj.py index 5f615178..cd3f29dc 100644 --- a/ensemble_md/tests/test_analyze_traj.py +++ b/ensemble_md/tests/test_analyze_traj.py @@ -528,8 +528,38 @@ def test_plot_transit_time(mock_plt): mock_plt.savefig.assert_not_called() -def test_plot_g_vecs(): - pass +@patch('ensemble_md.analysis.analyze_traj.plt') +def test_plot_g_vecs(mock_plt): + cmap = mock_plt.cm.ocean + mock_ax = MagicMock() + mock_plt.gca.return_value = mock_ax + + # Case 1: Short g_vecs with refs and with plot_rmse = True + g_vecs = np.array([[0, 10, 20, 30], [0, 8, 18, 28]]) + refs = np.array([0, 8, 18, 28]) + refs_err = np.array([0.1, 0.1, 0.1, 0.1]) + + analyze_traj.plot_g_vecs(g_vecs, refs, refs_err, plot_rmse=True) + + mock_plt.figure.assert_called() + mock_plt.plot.assert_called() + mock_plt.xlabel.assert_called_with('Iteration index') + # mock_plt.ylabel.assert_called_any('Alchemical weight (kT)') + mock_plt.xlim.assert_called() + mock_plt.grid.assert_called() + mock_plt.legend.assert_called_with(loc='center left', bbox_to_anchor=(1, 0.2)) + + assert mock_plt.figure.call_count == 2 + assert mock_plt.plot.call_count == 4 + assert mock_plt.axhline.call_count == 3 + assert mock_plt.fill_between.call_count == 3 + assert mock_plt.grid.call_count == 2 + + assert mock_plt.ylabel.call_args_list[0][0] == ('Alchemical weight (kT)',) + assert mock_plt.ylabel.call_args_list[1][0] == ('RMSE in the alchemical weights (kT)',) + + + # Case 2: Long g_vecs def test_get_swaps():