Skip to content

Commit

Permalink
Refined some previously written tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Apr 6, 2024
1 parent f683e57 commit 7d6320a
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 18 deletions.
6 changes: 3 additions & 3 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
plt.grid()
plt.legend()
if fig_prefix is None:
plt.savefig(f'{folder}/{fig_names[t]}')
plt.savefig(f'{folder}/{fig_names[t]}', dpi=600)
else:
plt.savefig(f'{folder}/{fig_prefix}_{fig_names[t]}', dpi=600)

Expand Down Expand Up @@ -800,9 +800,9 @@ def plot_g_vecs(g_vecs, refs=None, refs_err=None, plot_rmse=True):
"""
# n_iter, n_state = g_vecs.shape[0], g_vecs.shape[1]
g_vecs = np.transpose(g_vecs)
n_sim = len(g_vecs)
n_states = len(g_vecs)
cmap = plt.cm.ocean # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc.
colors = [cmap(i) for i in np.arange(n_sim) / n_sim]
colors = [cmap(i) for i in np.arange(n_states) / n_states]
plt.figure()
for i in range(1, len(g_vecs)):
if len(g_vecs[0]) < 100:
Expand Down
11 changes: 10 additions & 1 deletion ensemble_md/analysis/synthesize_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,19 @@ def synthesize_traj(trans_mtx, n_frames=100000, method='transmtx', start=0, seed
equil_prob = analyze_traj.calc_equil_prob(trans_mtx)
syn_traj = np.random.choice(N, size=n_frames, p=equil_prob.reshape(N))
elif method == 'transmtx':
check_row = sum([np.isclose(np.sum(trans_mtx[i]), 1, atol=1e-8) for i in range(len(trans_mtx))])
check_col = sum([np.isclose(np.sum(trans_mtx[:, i]), 1, atol=1e-8) for i in range(len(trans_mtx))])
if check_row == N:
mtx = trans_mtx
elif check_col == N:
mtx = trans_mtx.T
else:
raise ValueError('The input matrix is not normalized')

syn_traj = np.zeros(n_frames, dtype=int)
syn_traj[0] = start
for i in range(1, n_frames):
syn_traj[i] = np.random.choice(N, p=trans_mtx[syn_traj[i-1]])
syn_traj[i] = np.random.choice(N, p=mtx[syn_traj[i-1]])
else:
raise ValueError(f'Invalid method: {method}. The method must be either "transmtx" or "equil_prob".')

Expand Down
8 changes: 4 additions & 4 deletions ensemble_md/tests/test_analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,24 @@ def test_calc_equil_prob(capfd):
def test_calc_spectral_gap(capfd):
# Case 1 (sanity check): doublly stochastic
mtx = np.array([[0.5, 0.5], [0.5, 0.5]])
s, vals = analyze_matrix.calc_spectral_gap(mtx)
s, err, vals = analyze_matrix.calc_spectral_gap(mtx, n_bootstrap=5)
assert vals[0] == 1
assert np.isclose(s, 1)

# Case 2: Right stochastic
mtx = np.array([[0.8, 0.2], [0.3, 0.7]])
s, vals = analyze_matrix.calc_spectral_gap(mtx)
s, err, vals = analyze_matrix.calc_spectral_gap(mtx, n_bootstrap=5)
assert vals[0] == 1
assert s == 0.5

# Case 3: Left stochastic
s, vals = analyze_matrix.calc_spectral_gap(mtx.T)
s, err, vals = analyze_matrix.calc_spectral_gap(mtx.T, n_bootstrap=5)
assert vals[0] == 1
assert s == 0.5

# Case 4: Neither left or right stochastic
mtx = np.random.rand(3, 3)
s = analyze_matrix.calc_spectral_gap(mtx) # the output should be None
s = analyze_matrix.calc_spectral_gap(mtx, n_bootstrap=5) # the output should be None
out, err = capfd.readouterr()
assert s is None
assert 'The input transition matrix is neither right nor left stochastic' in out
Expand Down
139 changes: 129 additions & 10 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,29 @@ def test_convert_npy2xvg():
assert content[2] == '0.0 4\n'
assert content[3] == '0.2 6\n'

os.remove('traj_0.xvg')
os.remove('traj_1.xvg')

trajs = np.array([[0.0, 0.1, 0.2, 0.3], [0.4, 0.5, 0.6, 0.7]])
analyze_traj.convert_npy2xvg(trajs, dt, subsampling)

assert os.path.exists('traj_0.xvg')
assert os.path.exists('traj_1.xvg')

with open('traj_0.xvg', 'r') as f:
content = f.readlines()
assert content[0] == '# This file was created by ensemble_md\n'
assert content[1] == '# Time (ps) v.s. CV\n'
assert content[2] == '0.0 0.000000\n'
assert content[3] == '0.2 0.200000\n'

with open('traj_1.xvg', 'r') as f:
content = f.readlines()
assert content[0] == '# This file was created by ensemble_md\n'
assert content[1] == '# Time (ps) v.s. CV\n'
assert content[2] == '0.0 0.400000\n'
assert content[3] == '0.2 0.600000\n'

os.remove('traj_0.xvg')
os.remove('traj_1.xvg')
os.chdir('../../../')
Expand Down Expand Up @@ -484,12 +507,16 @@ def test_plot_transit_time(mock_plt):
np.testing.assert_array_equal(mock_plt.plot.call_args_list[3][0], [[1], [6]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[4][0], [[1, 2], [9, 10]])
np.testing.assert_array_equal(mock_plt.plot.call_args_list[5][0], [[1], [10]])
assert mock_plt.plot.call_args_list[0][1] == {'label': 'Trajectory 0', 'marker': 'o'}
assert mock_plt.plot.call_args_list[1][1] == {'label': 'Trajectory 1', 'marker': 'o'}
assert mock_plt.plot.call_args_list[2][1] == {'label': 'Trajectory 0', 'marker': 'o'}
assert mock_plt.plot.call_args_list[3][1] == {'label': 'Trajectory 1', 'marker': 'o'}
assert mock_plt.plot.call_args_list[4][1] == {'label': 'Trajectory 0', 'marker': 'o'}
assert mock_plt.plot.call_args_list[5][1] == {'label': 'Trajectory 1', 'marker': 'o'}

assert [mock_plt.plot.call_args_list[i][1] for i in range(6)] == [
{'label': 'Trajectory 0', 'marker': 'o'},
{'label': 'Trajectory 1', 'marker': 'o'},
{'label': 'Trajectory 0', 'marker': 'o'},
{'label': 'Trajectory 1', 'marker': 'o'},
{'label': 'Trajectory 0', 'marker': 'o'},
{'label': 'Trajectory 1', 'marker': 'o'}
]

assert mock_plt.ylabel.call_args_list[0][0] == ('Average transit time from states 0 to k (step)',)
assert mock_plt.ylabel.call_args_list[1][0] == ('Average transit time from states k to 0 (step)',)
assert mock_plt.ylabel.call_args_list[2][0] == ('Average round-trip time (step)',)
Expand Down Expand Up @@ -527,12 +554,78 @@ def test_plot_transit_time(mock_plt):
mock_plt.figure.assert_not_called()
mock_plt.savefig.assert_not_called()

# Case 5: More than 100 round trips so that a histogram is plotted
mock_plt.reset_mock()
trajs = np.array([[0, 1, 2, 3, 2] * 20000, [0, 1, 3, 2, 1] * 20000])
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N)

assert t_1 == [[3] * 20000, [2] * 20000]
assert t_2 == [[2] * 19999, [3] * 19999]
assert t_3 == [[5] * 19999, [5] * 19999]
assert u == 'step'

mock_plt.hist.assert_called()
mock_plt.ticklabel_format.assert_called_with(style='sci', axis='y', scilimits=(0, 0))

assert mock_plt.figure.call_count == 6
assert mock_plt.hist.call_count == 6
assert mock_plt.xlabel.call_count == 6
assert mock_plt.ylabel.call_count == 6
assert mock_plt.ticklabel_format.call_count == 6
assert mock_plt.grid.call_count == 6
assert mock_plt.legend.call_count == 6
assert mock_plt.savefig.call_count == 6

assert mock_plt.hist.call_args_list[0][0][0] == [3] * 20000
assert mock_plt.hist.call_args_list[1][0][0] == [2] * 20000
assert mock_plt.hist.call_args_list[2][0][0] == [2] * 19999
assert mock_plt.hist.call_args_list[3][0][0] == [3] * 19999
assert mock_plt.hist.call_args_list[4][0][0] == [5] * 19999
assert mock_plt.hist.call_args_list[5][0][0] == [5] * 19999

assert [mock_plt.hist.call_args_list[i][1] for i in range(6)] == [
{'bins': 1000, 'label': 'Trajectory 0'},
{'bins': 1000, 'label': 'Trajectory 1'},
{'bins': 999, 'label': 'Trajectory 0'},
{'bins': 999, 'label': 'Trajectory 1'},
{'bins': 999, 'label': 'Trajectory 0'},
{'bins': 999, 'label': 'Trajectory 1'}
]

assert [mock_plt.xlabel.call_args_list[i][0][0] for i in range(6)] == [
'Event index',
'Average transit time from states 0 to k (step)',
'Event index',
'Average transit time from states k to 0 (step)',
'Event index',
'Average round-trip time (step)'
]

assert [mock_plt.ylabel.call_args_list[i][0][0] for i in range(6)] == [
'Average transit time from states 0 to k (step)',
'Event count',
'Average transit time from states k to 0 (step)',
'Event count',
'Average round-trip time (step)',
'Event count'
]

assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(6)] == [
'./t_0k.png',
'./hist_t_0k.png',
'./t_k0.png',
'./hist_t_k0.png',
'./t_roundtrip.png',
'./hist_t_roundtrip.png'
]


@patch('ensemble_md.analysis.analyze_traj.plt')
def test_plot_g_vecs(mock_plt):
# cmap = mock_plt.cm.ocean
cmap = mock_plt.cm.ocean
mock_ax = MagicMock()
mock_plt.gca.return_value = mock_ax
colors = [cmap(i) for i in np.arange(4) / 4]

# Case 1: Short g_vecs with refs and with plot_rmse = True
g_vecs = np.array([[0, 10, 20, 30], [0, 8, 18, 28]])
Expand All @@ -544,21 +637,47 @@ def test_plot_g_vecs(mock_plt):
mock_plt.figure.assert_called()
mock_plt.plot.assert_called()
mock_plt.xlabel.assert_called_with('Iteration index')
# mock_plt.ylabel.assert_called_any('Alchemical weight (kT)')
mock_plt.xlim.assert_called()
mock_plt.grid.assert_called()
mock_plt.legend.assert_called_with(loc='center left', bbox_to_anchor=(1, 0.2))
mock_plt.legend.assert_called_once_with(loc='center left', bbox_to_anchor=(1, 0.2))
mock_plt.xlabel.assert_called_with('Iteration index')

assert mock_plt.figure.call_count == 2
assert mock_plt.plot.call_count == 4
assert mock_plt.axhline.call_count == 3
assert mock_plt.fill_between.call_count == 3
assert mock_plt.grid.call_count == 2
assert mock_plt.xlabel.call_count == 2
assert mock_plt.ylabel.call_count == 2

assert [mock_plt.plot.call_args_list[i][0][0] for i in range(4)] == [range(2)] * 4
np.testing.assert_array_equal(mock_plt.plot.call_args_list[0][0][1], np.array([10, 8]))
np.testing.assert_array_equal(mock_plt.plot.call_args_list[1][0][1], np.array([20, 18]))
np.testing.assert_array_equal(mock_plt.plot.call_args_list[2][0][1], np.array([30, 28]))
np.testing.assert_array_equal(mock_plt.plot.call_args_list[3][0][1], np.array([np.sqrt(3), 0])) # RMSE as a function the iteration index # noqa: E501

assert mock_plt.plot.call_args_list[0][1] == {'label': 'State 1', 'marker': 'o', 'c': colors[0], 'linewidth': 0.8, 'markersize': 2} # noqa: E501
assert mock_plt.plot.call_args_list[1][1] == {'label': 'State 2', 'marker': 'o', 'c': colors[1], 'linewidth': 0.8, 'markersize': 2} # noqa: E501
assert mock_plt.plot.call_args_list[2][1] == {'label': 'State 3', 'marker': 'o', 'c': colors[2], 'linewidth': 0.8, 'markersize': 2} # noqa: E501

assert mock_plt.ylabel.call_args_list[0][0] == ('Alchemical weight (kT)',)
assert mock_plt.ylabel.call_args_list[1][0] == ('RMSE in the alchemical weights (kT)',)

# Case 2: Long g_vecs
# Case 2: Long g_vecs, here we just check the only different line
mock_plt.reset_mock()
g_vecs = np.array([[0, 10, 20, 30]] * 200)
analyze_traj.plot_g_vecs(g_vecs)

assert mock_plt.plot.call_count == 3
assert [mock_plt.plot.call_args_list[i][0][0] for i in range(3)] == [range(200)] * 3

np.testing.assert_array_equal(mock_plt.plot.call_args_list[0][0][1], np.array([10] * 200))
np.testing.assert_array_equal(mock_plt.plot.call_args_list[1][0][1], np.array([20] * 200))
np.testing.assert_array_equal(mock_plt.plot.call_args_list[2][0][1], np.array([30] * 200))

assert mock_plt.plot.call_args_list[0][1] == {'label': 'State 1', 'c': colors[0], 'linewidth': 0.8}
assert mock_plt.plot.call_args_list[1][1] == {'label': 'State 2', 'c': colors[1], 'linewidth': 0.8}
assert mock_plt.plot.call_args_list[2][1] == {'label': 'State 3', 'c': colors[2], 'linewidth': 0.8}


def test_get_swaps():
Expand Down

0 comments on commit 7d6320a

Please sign in to comment.