diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index c2ce555c..3e448fd9 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -10,8 +10,10 @@ """ The :obj:`.analyze_traj` module provides methods for analyzing trajectories in REXEE. """ +import copy import numpy as np import matplotlib.pyplot as plt +from itertools import chain from matplotlib.ticker import MaxNLocator from alchemlyb.parsing.gmx import _get_headers as get_headers @@ -511,7 +513,6 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre y_max = 0 for i in range(n_configs): max_count = np.max(bottom + hist_data[i]) - print(max_count) if max_count > y_max: y_max = max_count plt.bar( @@ -678,12 +679,13 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): last_visited = k # Here we figure out the round-trip time from t_0k and t_k0. - if len(t_0k) != len(t_k0): # then it must be len(t_0k) = len(t_k0) + 1 or len(t_k0) = len(t_0k) + 1, so we drop the last element of the larger list # noqa: E501 - if len(t_0k) > len(t_k0): - t_0k.pop() + t_0k_, t_k0_ = copy.deepcopy(t_0k), copy.deepcopy(t_k0) + if len(t_0k_) != len(t_k0_): # then it must be len(t_0k) = len(t_k0) + 1 or len(t_k0) = len(t_0k) + 1, so we drop the last element of the larger list # noqa: E501 + if len(t_0k_) > len(t_k0_): + t_0k_.pop() else: - t_k0.pop() - t_roundtrip = list(np.array(t_0k) + np.array(t_k0)) + t_k0_.pop() + t_roundtrip = list(np.array(t_0k_) + np.array(t_k0_)) if end_0_found is True and end_k_found is True: if dt is not None: @@ -711,7 +713,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): t_roundtrip_avg.append(np.mean(t_roundtrip)) if len(t_0k) + len(t_k0) + len(t_roundtrip) > 0: # i.e. 
not all are empty - if sci is False and np.max([t_0k, t_k0, t_roundtrip]) >= 10000: + flattened_list = list(chain.from_iterable([t_0k, t_k0, t_roundtrip])) + if sci is False and np.max(flattened_list) >= 10000: sci = True else: t_0k_list.append([]) @@ -742,7 +745,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'): for i in range(len(t_list)): # t_list[i] is the list for trajectory i plt.plot(np.arange(len(t_list[i])) + 1, t_list[i], label=f'Trajectory {i}', marker=marker) - if max(max((t_list))) >= 10000: + flattened_t_list = list(chain.from_iterable(t_list)) + if np.max(flattened_t_list) >= 10000: plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) plt.xlabel('Event index') plt.ylabel(f'{y_labels[t]}') diff --git a/ensemble_md/tests/test_analyze_traj.py b/ensemble_md/tests/test_analyze_traj.py index 89f2aadb..5f615178 100644 --- a/ensemble_md/tests/test_analyze_traj.py +++ b/ensemble_md/tests/test_analyze_traj.py @@ -439,12 +439,93 @@ def test_plot_state_hist(mock_plt): os.remove('hist_data.npy') -def test_calculate_hist_rmse(): - pass +def test_calc_hist_rmse(): + # Case 1: Exactly flat histogram with some states accessible by 2 replicas + hist_data = [[15, 15, 30, 30, 15, 15], [15, 15, 30, 30, 15, 15]] + state_ranges = [[0, 1, 2, 3], [2, 3, 4, 5]] + assert analyze_traj.calc_hist_rmse(hist_data, state_ranges) == 0 + # Case 2: Exactly flat histogram with some states accessible by 3 replicas + hist_data = [[10, 20, 30, 30, 20, 10], [10, 20, 30, 30, 20, 10]] + state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]] + assert analyze_traj.calc_hist_rmse(hist_data, state_ranges) == 0 -def plot_transit_time(): - pass + +@patch('ensemble_md.analysis.analyze_traj.plt') +def test_plot_transit_time(mock_plt): + N = 4 + trajs = [ + [0, 1, 2, 0, 2, 3, 2, 2, 1, 0, 1, 1, 2, 0, 1, 2, 3, 2, 1, 0], + [1, 2, 1, 0, 1, 2, 2, 3, 2, 3, 3, 2, 1, 0, 1, 2, 2, 3, 2, 1] + ] + + # Case 1: Default settings + t_1, t_2, t_3, u = 
analyze_traj.plot_transit_time(trajs, N) + assert t_1 == [[5, 7], [4, 4]] + assert t_2 == [[4, 3], [6]] + assert t_3 == [[9, 10], [10]] + assert u == 'step' + + mock_plt.figure.assert_called() + mock_plt.plot.assert_called() + mock_plt.xlabel.assert_called_with('Event index') + mock_plt.ylabel.assert_called() + + assert mock_plt.figure.call_count == 3 + assert mock_plt.plot.call_count == 6 + assert mock_plt.xlabel.call_count == 3 + assert mock_plt.ylabel.call_count == 3 + assert mock_plt.grid.call_count == 3 + assert mock_plt.legend.call_count == 3 + + np.testing.assert_array_equal(mock_plt.plot.call_args_list[0][0], [[1, 2], [5, 7]]) + np.testing.assert_array_equal(mock_plt.plot.call_args_list[1][0], [[1, 2], [4, 4]]) + np.testing.assert_array_equal(mock_plt.plot.call_args_list[2][0], [[1, 2], [4, 3]]) + np.testing.assert_array_equal(mock_plt.plot.call_args_list[3][0], [[1], [6]]) + np.testing.assert_array_equal(mock_plt.plot.call_args_list[4][0], [[1, 2], [9, 10]]) + np.testing.assert_array_equal(mock_plt.plot.call_args_list[5][0], [[1], [10]]) + assert mock_plt.plot.call_args_list[0][1] == {'label': 'Trajectory 0', 'marker': 'o'} + assert mock_plt.plot.call_args_list[1][1] == {'label': 'Trajectory 1', 'marker': 'o'} + assert mock_plt.plot.call_args_list[2][1] == {'label': 'Trajectory 0', 'marker': 'o'} + assert mock_plt.plot.call_args_list[3][1] == {'label': 'Trajectory 1', 'marker': 'o'} + assert mock_plt.plot.call_args_list[4][1] == {'label': 'Trajectory 0', 'marker': 'o'} + assert mock_plt.plot.call_args_list[5][1] == {'label': 'Trajectory 1', 'marker': 'o'} + assert mock_plt.ylabel.call_args_list[0][0] == ('Average transit time from states 0 to k (step)',) + assert mock_plt.ylabel.call_args_list[1][0] == ('Average transit time from states k to 0 (step)',) + assert mock_plt.ylabel.call_args_list[2][0] == ('Average round-trip time (step)',) + + # Case 2: dt = 0.2 ps, here we just test the return values + mock_plt.reset_mock() + t_1, t_2, 
t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2) + t_1_, t_2_, t_3_ = [[1.0, 1.4], [0.8, 0.8]], [[0.8, 0.6], [1.2]], [[1.8, 2.0], [2.0]] + for i in range(2): + np.testing.assert_array_almost_equal(t_1[i], t_1_[i]) + np.testing.assert_array_almost_equal(t_2[i], t_2_[i]) + np.testing.assert_array_almost_equal(t_3[i], t_3_[i]) + assert u == 'ps' + + # Case 3: dt = 200 ps, long trajs + mock_plt.reset_mock() + trajs = np.ones((2, 2000000), dtype=int) + trajs[0][0], trajs[0][1000000], trajs[0][1999999] = 0, 3, 0 + t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=200) + assert t_1 == [[200000.0], []] + assert t_2 == [[199999.8], []] + assert t_3 == [[399999.8], []] + assert u == 'ns' + mock_plt.ticklabel_format.assert_called_with(style='sci', axis='y', scilimits=(0, 0)) + assert mock_plt.ticklabel_format.call_count == 3 + + # Case 4: Poor sampling + mock_plt.reset_mock() + trajs = [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1]] + t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N) + assert t_1 == [[], []] + assert t_2 == [[], []] + assert t_3 == [[], []] + assert u == 'step' + mock_plt.figure.assert_not_called() + mock_plt.savefig.assert_not_called() def test_plot_g_vecs():