Skip to content

Commit

Permalink
Added a test for calc_hist_rmse and plot_transit_time
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Mar 31, 2024
1 parent fcb8787 commit 4610175
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 12 deletions.
20 changes: 12 additions & 8 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
"""
The :obj:`.analyze_traj` module provides methods for analyzing trajectories in REXEE.
"""
import copy
import numpy as np
import matplotlib.pyplot as plt
from itertools import chain
from matplotlib.ticker import MaxNLocator

from alchemlyb.parsing.gmx import _get_headers as get_headers
Expand Down Expand Up @@ -511,7 +513,6 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
y_max = 0
for i in range(n_configs):
max_count = np.max(bottom + hist_data[i])
print(max_count)
if max_count > y_max:
y_max = max_count
plt.bar(
Expand Down Expand Up @@ -678,12 +679,13 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
last_visited = k

# Here we figure out the round-trip time from t_0k and t_k0.
if len(t_0k) != len(t_k0): # then it must be len(t_0k) = len(t_k0) + 1 or len(t_k0) = len(t_0k) + 1, so we drop the last element of the larger list # noqa: E501
if len(t_0k) > len(t_k0):
t_0k.pop()
t_0k_, t_k0_ = copy.deepcopy(t_0k), copy.deepcopy(t_k0)
if len(t_0k_) != len(t_k0_): # then it must be len(t_0k) = len(t_k0) + 1 or len(t_k0) = len(t_0k) + 1, so we drop the last element of the larger list # noqa: E501
if len(t_0k_) > len(t_k0_):
t_0k_.pop()
else:
t_k0.pop()
t_roundtrip = list(np.array(t_0k) + np.array(t_k0))
t_k0_.pop()
t_roundtrip = list(np.array(t_0k_) + np.array(t_k0_))

if end_0_found is True and end_k_found is True:
if dt is not None:
Expand Down Expand Up @@ -711,7 +713,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
t_roundtrip_avg.append(np.mean(t_roundtrip))

if len(t_0k) + len(t_k0) + len(t_roundtrip) > 0: # i.e. not all are empty
if sci is False and np.max([t_0k, t_k0, t_roundtrip]) >= 10000:
flattened_list = list(chain.from_iterable([t_0k, t_k0, t_roundtrip]))
if sci is False and np.max(flattened_list) >= 10000:
sci = True
else:
t_0k_list.append([])
Expand Down Expand Up @@ -742,7 +745,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
for i in range(len(t_list)): # t_list[i] is the list for trajectory i
plt.plot(np.arange(len(t_list[i])) + 1, t_list[i], label=f'Trajectory {i}', marker=marker)

if max(max((t_list))) >= 10000:
flattened_t_list = list(chain.from_iterable(t_list))
if np.max(flattened_t_list) >= 10000:
plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
plt.xlabel('Event index')
plt.ylabel(f'{y_labels[t]}')
Expand Down
89 changes: 85 additions & 4 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,93 @@ def test_plot_state_hist(mock_plt):
os.remove('hist_data.npy')


def test_calculate_hist_rmse():
pass
def test_calc_hist_rmse():
# Case 1: Exactly flat histogram with some states acceessible by 2 replicas
hist_data = [[15, 15, 30, 30, 15, 15], [15, 15, 30, 30, 15, 15]]
state_ranges = [[0, 1, 2, 3], [2, 3, 4, 5]]
assert analyze_traj.calc_hist_rmse(hist_data, state_ranges) == 0

# Case 2: Exactly flat histogram with some states acceessible by 3 replicas
hist_data = [[10, 20, 30, 30, 20, 10], [10, 20, 30, 30, 20, 10]]
state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
assert analyze_traj.calc_hist_rmse(hist_data, state_ranges) == 0

def plot_transit_time():
pass

@patch('ensemble_md.analysis.analyze_traj.plt')
def test_plot_transit_time(mock_plt):
N = 4
trajs = [
[0, 1, 2, 0, 2, 3, 2, 2, 1, 0, 1, 1, 2, 0, 1, 2, 3, 2, 1, 0],
[1, 2, 1, 0, 1, 2, 2, 3, 2, 3, 3, 2, 1, 0, 1, 2, 2, 3, 2, 1]
]

# Case 1: Default settings
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N)
assert t_1 == [[5, 7], [4, 4]]
assert t_2 == [[4, 3], [6]]
assert t_3 == [[9, 10], [10]]
assert u == 'step'

mock_plt.figure.assert_called()
mock_plt.plot.assert_called()
mock_plt.xlabel.assert_called_with('Event index')
mock_plt.ylabel.assert_called()

assert mock_plt.figure.call_count == 3
assert mock_plt.plot.call_count == 6
assert mock_plt.xlabel.call_count == 3
assert mock_plt.ylabel.call_count == 3
assert mock_plt.grid.call_count == 3
assert mock_plt.legend.call_count == 3

np.testing.assert_array_equal(mock_plt.plot.call_args_list[0][0], [[1, 2], [5, 7]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[1][0], [[1, 2], [4, 4]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[2][0], [[1, 2], [4, 3]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[3][0], [[1], [6]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[4][0], [[1, 2], [9, 10]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[5][0], [[1], [10]])
assert mock_plt.plot.call_args_list[0][1] == {'label': 'Trajectory 0', 'marker': 'o'}
assert mock_plt.plot.call_args_list[1][1] == {'label': 'Trajectory 1', 'marker': 'o'}
assert mock_plt.plot.call_args_list[2][1] == {'label': 'Trajectory 0', 'marker': 'o'}
assert mock_plt.plot.call_args_list[3][1] == {'label': 'Trajectory 1', 'marker': 'o'}
assert mock_plt.plot.call_args_list[4][1] == {'label': 'Trajectory 0', 'marker': 'o'}
assert mock_plt.plot.call_args_list[5][1] == {'label': 'Trajectory 1', 'marker': 'o'}
assert mock_plt.ylabel.call_args_list[0][0] == ('Average transit time from states 0 to k (step)',)
assert mock_plt.ylabel.call_args_list[1][0] == ('Average transit time from states k to 0 (step)',)
assert mock_plt.ylabel.call_args_list[2][0] == ('Average round-trip time (step)',)

# Case 2: dt = 0.2 ps, fig_prefix = 'test', here we just test the return values
mock_plt.reset_mock()
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2)
t_1_, t_2_, t_3_ = [[1.0, 1.4], [0.8, 0.8]], [[0.8, 0.6], [1.2]], [[1.8, 2.0], [2.0]]
for i in range(2):
np.testing.assert_array_almost_equal(t_1[i], t_1_[i])
np.testing.assert_array_almost_equal(t_2[i], t_2_[i])
np.testing.assert_array_almost_equal(t_3[i], t_3_[i])
assert u == 'ps'

# Case 3: dt = 200 ps, long trajs
mock_plt.reset_mock()
trajs = np.ones((2, 2000000), dtype=int)
trajs[0][0], trajs[0][1000000], trajs[0][1999999] = 0, 3, 0
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=200)
assert t_1 == [[200000.0], []]
assert t_2 == [[199999.8], []]
assert t_3 == [[399999.8], []]
assert u == 'ns'
mock_plt.ticklabel_format.assert_called_with(style='sci', axis='y', scilimits=(0, 0))
assert mock_plt.ticklabel_format.call_count == 3

# Case 4: Poor sampling
mock_plt.reset_mock()
trajs = [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1]]
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N)
assert t_1 == [[], []]
assert t_2 == [[], []]
assert t_3 == [[], []]
assert u == 'step'
mock_plt.figure.assert_not_called()
mock_plt.savefig.assert_not_called()


def test_plot_g_vecs():
Expand Down

0 comments on commit 4610175

Please sign in to comment.