Skip to content

Commit

Permalink
Further pushed the code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Apr 22, 2024
1 parent 84558ca commit 5175648
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 43 deletions.
42 changes: 39 additions & 3 deletions ensemble_md/tests/test_analyze_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
@patch('ensemble_md.analysis.analyze_free_energy.extract_dHdl')
@patch('ensemble_md.analysis.analyze_free_energy.detect_equilibration')
@patch('ensemble_md.analysis.analyze_free_energy.subsample_correlated_data')
def test_preprocess_data(mock_corr, mock_equil, mock_extract_dHdl, mock_extract_u_nk, mock_subsampling, mock_alchemlyb, capfd): # noqa: E501
def test_preprocess_data(mock_corr, mock_equil, mock_extract_dhdl, mock_extract_u_nk, mock_subsampling, mock_alchemlyb, capfd): # noqa: E501
mock_data, mock_data_series = MagicMock(), MagicMock()
mock_alchemlyb.concat.return_value = mock_data
mock_subsampling.u_nk2series.return_value = mock_data_series
mock_subsampling._prepare_input.return_value = (mock_data, mock_data_series)
mock_equil.return_value = (10, 5, 50) # t, g, Neff_max
mock_equil.return_value = (10, 5, 18) # t, g, Neff_max
mock_data_series.__len__.return_value = 100 # For one of the print statements

# Set slicing to return different mock objects based on input
Expand All @@ -53,6 +53,7 @@ def generic_list_slicing(key):
mock_data_series.__getitem__.side_effect = slicing_side_effect # so that we can use mock_data_series[t:]
mock_data_series_equil = mock_data_series[10:] # Mock the equilibrated data series, given t=10

# Case 1: data_type = u_nk
files = [[f'ensemble_md/tests/data/dhdl/simulation_example/sim_{i}/iteration_{j}/dhdl.xvg' for j in range(3)] for i in range(4)] # noqa: E501
results = analyze_free_energy.preprocess_data(files, 300, 'u_nk')

Expand All @@ -74,13 +75,48 @@ def generic_list_slicing(key):
assert ' Adopted spacing: 1' in out
assert ' 10.0% of the u_nk data was in the equilibrium region and therfore discarded.' in out # noqa: E501
assert ' Statistical inefficiency of u_nk: 5.0' in out
assert ' Number of effective samples: 50' in out
assert ' Number of effective samples: 18' in out
assert mock_corr.call_args_list[i] == call(mock_data_series_equil, g=5)

assert len(results[0]) == 4
assert results[1] == [10, 10, 10, 10]
assert results[2] == [5, 5, 5, 5]

# Case 2: data_type = dHdl
mock_alchemlyb.concat.reset_mock()
mock_subsampling._prepare_input.reset_mock()
mock_subsampling.slicing.reset_mock()
mock_equil.reset_mock()

mock_subsampling.dhdl2series.return_value = mock_data_series
mock_subsampling._prepare_input.return_value = (mock_data, mock_data_series)
mock_data_series.__len__.return_value = 200
mock_data_series.values.__len__.return_value = 200

results = analyze_free_energy.preprocess_data(files, 300, 'dhdl', t=10, g=5)
out, err = capfd.readouterr()

for i in range(4):
for j in range(3):
assert mock_extract_dhdl.call_args_list[i * 3 + j] == call(files[i][j], T=300)
assert mock_subsampling._prepare_input.call_args_list[i] == call(mock_data, mock_data_series, drop_duplicates=True, sort=True) # noqa: E501
assert mock_subsampling.slicing.call_args_list[2 * i] == call(mock_data, step=1)
assert mock_subsampling.slicing.call_args_list[2 * i + 1] == call(mock_data_series, step=1)
assert 'Subsampling and decorrelating the concatenated dhdl data ...' in out
assert ' Adopted spacing: 1' in out
assert ' 5.0% of the dhdl data was in the equilibrium region and therfore discarded.' in out # noqa: E501
assert ' Statistical inefficiency of dhdl: 5.0' in out
assert ' Number of effective samples: 38' in out
assert mock_corr.call_args_list[i] == call(mock_data_series_equil, g=5)

assert len(results[0]) == 4
assert results[1] == []
assert results[2] == []

# Case 3: Invalid data_type
with pytest.raises(ValueError, match="Invalid data_type. Expected 'u_nk' or 'dhdl'."):
analyze_free_energy.preprocess_data(files, 300, 'xyz')


@pytest.mark.parametrize("method, expected_estimator", [
("MBAR", "MBAR estimator"),
Expand Down
16 changes: 16 additions & 0 deletions ensemble_md/tests/test_analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def test_calc_transmtx():
assert B3 is None
assert C3 is None

# Case 4: Invalid simulation type
with pytest.raises(ValueError, match='Invalid simulation type test.'):
analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/EXE.log'), simulation_type='test')


def test_calc_equil_prob(capfd):
# Case 1: Right stochastic
Expand All @@ -87,6 +91,18 @@ def test_calc_equil_prob(capfd):
assert 'The input transition matrix is neither right nor left stochastic' in out


def test_calc_t_relax():
    """Check the values returned by ``analyze_matrix.calc_t_relax``.

    Two call forms are exercised: one that passes a third positional
    argument (the spectral-gap error, per the case labels below) and one
    that omits it, in which case the second returned item must be ``None``.
    """
    # Case 1: spectral_gap_err is specified
    results = analyze_matrix.calc_t_relax(0.5, 0.1, 0.1)
    # Floats are compared with a tolerance so the test does not depend on the
    # exact order of floating-point operations inside calc_t_relax.
    assert results[0] == pytest.approx(0.2)
    assert results[1] == pytest.approx(0.1 * 0.1 / 0.5 ** 2)

    # Case 2: spectral_gap_err is not specified
    results = analyze_matrix.calc_t_relax(0.5, 0.1)
    assert results[0] == pytest.approx(0.2)
    assert results[1] is None


def test_calc_spectral_gap(capfd):
# Case 1 (sanity check): doubly stochastic
mtx = np.array([[0.5, 0.5], [0.5, 0.5]])
Expand Down
59 changes: 50 additions & 9 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,38 @@ def test_stitch_time_series_for_sim():
assert os.path.exists('state_trajs_for_sim.npy')
os.remove('state_trajs_for_sim.npy')

# Test 2: Test for discontinuous time series
# Test 2: The case where dhdl is False
# Here we again use dhdl.xvg files but use dhdl=False with col_idx=1, which corresponds to the state index
trajs = analyze_traj.stitch_time_series_for_sim(files, dhdl=False, col_idx=1)

trajs[0] == [
0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 1, 0, 1, 1,
1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 0, 1
]

trajs[1] == [
1, 1, 2, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1,
2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 2, 3, 3, 3, 2, 2,
1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 0, 1
]

trajs[2] == [
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3,
3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 2, 3, 4, 3, 3, 2,
3, 3, 2, 2, 2, 3, 4, 3, 4, 4, 5, 5, 5, 5, 4, 3, 4, 3, 3, 4, 4
]

trajs[3] == [
3, 3, 3, 3, 3, 3, 3, 5, 4, 4, 5, 4, 4, 5, 4, 5, 5, 5, 4, 5,
4, 4, 5, 4, 5, 5, 4, 5, 5, 5, 4, 5, 5, 4, 5, 4, 5, 4, 5, 5,
6, 6, 6, 5, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 6, 7, 6, 7
]

assert os.path.exists('state_trajs_for_sim.npy')
os.remove('state_trajs_for_sim.npy')

# Test 3: Test for discontinuous time series
# Here, for sim_2, we exclude the last 5 lines for the dhdl.xvg file in iteration_1 to create a gap
save_and_exclude(f'{folder}/sim_2/iteration_1/dhdl.xvg', 5)
os.rename(f'{folder}/sim_2/iteration_1/dhdl.xvg', f'{folder}/sim_2/iteration_1/dhdl_temp.xvg')
Expand Down Expand Up @@ -624,16 +655,26 @@ def test_plot_transit_time(mock_plt):
assert mock_plt.ylabel.call_args_list[0] == call('Average transit time from states 0 to k (step)')
assert mock_plt.ylabel.call_args_list[1] == call('Average transit time from states k to 0 (step)')
assert mock_plt.ylabel.call_args_list[2] == call('Average round-trip time (step)')
assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(3)] == [
'./t_0k.png',
'./t_k0.png',
'./t_roundtrip.png',
]

# Case 2: dt = 0.2 ps, fig_prefix = 'test', here we just test the return values
mock_plt.reset_mock()
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2)
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2, fig_prefix='test')
t_1_, t_2_, t_3_ = [[1.0, 1.4], [0.8, 0.8]], [[0.8, 0.6], [1.2]], [[1.8, 2.0], [2.0]]
for i in range(2):
np.testing.assert_array_almost_equal(t_1[i], t_1_[i])
np.testing.assert_array_almost_equal(t_2[i], t_2_[i])
np.testing.assert_array_almost_equal(t_3[i], t_3_[i])
assert u == 'ps'
assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(3)] == [
'./test_t_0k.png',
'./test_t_k0.png',
'./test_t_roundtrip.png',
]

# Case 3: dt = 200 ps, long trajs
mock_plt.reset_mock()
Expand Down Expand Up @@ -661,7 +702,7 @@ def test_plot_transit_time(mock_plt):
# Case 5: More than 100 round trips so that a histogram is plotted
mock_plt.reset_mock()
trajs = np.array([[0, 1, 2, 3, 2] * 20000, [0, 1, 3, 2, 1] * 20000])
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N)
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, fig_prefix='test')

assert t_1 == [[3] * 20000, [2] * 20000]
assert t_2 == [[2] * 19999, [3] * 19999]
Expand Down Expand Up @@ -715,12 +756,12 @@ def test_plot_transit_time(mock_plt):
]

assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(6)] == [
'./t_0k.png',
'./hist_t_0k.png',
'./t_k0.png',
'./hist_t_k0.png',
'./t_roundtrip.png',
'./hist_t_roundtrip.png'
'./test_t_0k.png',
'./test_hist_t_0k.png',
'./test_t_k0.png',
'./test_hist_t_k0.png',
'./test_t_roundtrip.png',
'./test_hist_t_roundtrip.png'
]


Expand Down
14 changes: 14 additions & 0 deletions ensemble_md/tests/test_gmx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ def test_write(self):
mdp = gmx_parser.MDP("ensemble_md/tests/data/expanded.mdp")
mdp.write('test_1.mdp', skipempty=False)
mdp.write('test_2.mdp', skipempty=True)

assert os.path.isfile('test_1.mdp')
assert os.path.isfile('test_2.mdp')

mdp = gmx_parser.MDP('test_1.mdp')
mdp.write(skipempty=True) # This should overwrite the file

# Check if the files are the same
with open('test_1.mdp', 'r') as f:
lines_1 = f.readlines()
with open('test_2.mdp', 'r') as f:
lines_2 = f.readlines()
assert lines_1 == lines_2

os.remove('test_1.mdp')
os.remove('test_2.mdp')

Expand Down
60 changes: 30 additions & 30 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,17 @@ def test_set_params_warnings(self, params_dict):

params_dict['mdp'] = 'ensemble_md/tests/data/expanded_test.mdp'
params_dict['N_cutoff'] = 1000
REXEE_1 = get_REXEE_instance(params_dict)
REXEE = get_REXEE_instance(params_dict)

warning_1 = 'Warning: The weight correction/weight combination method is specified but will not be used since the weights are fixed.' # noqa: E501
warning_2 = 'Warning: We recommend setting lmc_seed as -1 so the random seed is different for each iteration.'
warning_3 = 'Warning: We recommend setting gen_seed as -1 so the random seed is different for each iteration.'
warning_4 = 'Warning: We recommend generating new velocities for each iteration to avoid potential issues with the detailed balance.' # noqa: E501

assert warning_1 in REXEE_1.warnings
assert warning_2 in REXEE_1.warnings
assert warning_3 in REXEE_1.warnings
assert warning_4 in REXEE_1.warnings
assert warning_1 in REXEE.warnings
assert warning_2 in REXEE.warnings
assert warning_3 in REXEE.warnings
assert warning_4 in REXEE.warnings

os.remove(os.path.join(input_path, "expanded_test.mdp"))

Expand Down Expand Up @@ -426,7 +426,7 @@ def test_print_params(self, capfd, params_dict):
# capfd is a fixture in pytest for testing STDOUT
REXEE = get_REXEE_instance(params_dict)
REXEE.print_params()
out_1, err = capfd.readouterr()
out, err = capfd.readouterr()
L = ""
L += "Important parameters of REXEE\n=============================\n"
L += f"Python version: {sys.version}\n"
Expand All @@ -451,11 +451,11 @@ def test_print_params(self, capfd, params_dict):
L += "Alchemical ranges of each replica in REXEE:\n - Replica 0: States [0, 1, 2, 3, 4, 5]\n"
L += " - Replica 1: States [1, 2, 3, 4, 5, 6]\n - Replica 2: States [2, 3, 4, 5, 6, 7]\n"
L += " - Replica 3: States [3, 4, 5, 6, 7, 8]\n"
assert out_1 == L
assert out == L

REXEE.reformatted_mdp = True # Just to test the case where REXEE.reformatted_mdp is True
REXEE.print_params(params_analysis=True)
out_2, err = capfd.readouterr()
out, err = capfd.readouterr()
L += "\nWhether to build Markov state models and perform relevant analysis: False\n"
L += "Whether to perform free energy calculations: False\n"
L += "The step to used in subsampling the DHDL data in free energy calculations, if any: 1\n"
Expand All @@ -464,7 +464,7 @@ def test_print_params(self, capfd, params_dict):
L += "The number of bootstrap iterations in the boostrapping method, if used: 50\n"
L += "The random seed to use in bootstrapping, if used: None\n"
L += "Note that the input MDP file has been reformatted by replacing hypens with underscores. The original mdp file has been renamed as *backup.mdp.\n" # noqa: E501
assert out_2 == L
assert out == L

REXEE.gro = ['ensemble_md/tests/data/sys.gro', 'ensemble_md/tests/data/sys.gro'] # noqa: E501
REXEE.top = ['ensemble_md/tests/data/sys.top', 'ensemble_md/tests/data/sys.top']
Expand Down Expand Up @@ -642,20 +642,20 @@ def test_identify_swappable_pairs(self, params_dict):
states = [4, 2, 2, 7] # This would lead to the swappables: [(0, 1), (0, 2), (1, 2)] in the standard case

# Case 1: Any case that is not neighboring swap has the same definition for the swappable pairs
swappables_1 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring')
assert swappables_1 == [(0, 1), (0, 2), (1, 2)]
swappables = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring')
assert swappables == [(0, 1), (0, 2), (1, 2)]

# Case 2: Neighboring exchange
REXEE.proposal = 'neighboring'
swappables_2 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring')
assert swappables_2 == [(0, 1), (1, 2)]
swappables = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring')
assert swappables == [(0, 1), (1, 2)]

# Case 3: Non-neighboring exchange, with add_swappables
REXEE.proposal = 'exhaustive'
REXEE.add_swappables = [[3, 7], [4, 7]]
states = [4, 3, 2, 7] # Without add_swappables, the swappables would be [(0, 1), (0, 2), (1, 2)]
swappables_3 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring', REXEE.add_swappables) # noqa: E501
assert swappables_3 == [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
swappables = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring', REXEE.add_swappables) # noqa: E501
assert swappables == [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3)]

def test_propose_swap(self, params_dict):
random.seed(0)
Expand Down Expand Up @@ -777,18 +777,18 @@ def test_calc_prob_acc(self, capfd, params_dict):

# Test 1
swap = (0, 1)
prob_acc_1 = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts)
prob_acc = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts)
out, err = capfd.readouterr()
# dU = (-9.1366697 + 11.0623788)/2.4777098766670016 ~ 0.7772 kT, so p_acc = 0.45968522728859024
assert prob_acc_1 == pytest.approx(0.45968522728859024)
assert prob_acc == pytest.approx(0.45968522728859024)
assert 'U^i_n - U^i_m = -3.69 kT, U^j_m - U^j_n = 4.46 kT, Total dU: 0.78 kT' in out

# Test 2
swap = (0, 2)
prob_acc_2 = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts)
prob_acc = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts)
out, err = capfd.readouterr()
# dU = (-9.1366697 + 4.9963939)/2.4777098766670016 ~ -1.6710 kT, so p_acc = 1
assert prob_acc_2 == 1
assert prob_acc == 1

def test_accept_or_reject(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
Expand All @@ -809,10 +809,10 @@ def test_weight_correction(self, params_dict):
# Case 1: Perform weight correction (N_cutoff reached)
REXEE.N_cutoff = 5000
REXEE.verbose = False # just to increase code coverage
weights_1 = [[0, 10.304, 20.073, 29.364]]
counts_1 = [[31415, 45701, 55457, 59557]]
weights_1 = REXEE.weight_correction(weights_1, counts_1)
assert np.allclose(weights_1, [
weights = [[0, 10.304, 20.073, 29.364]]
counts = [[31415, 45701, 55457, 59557]]
weights = REXEE.weight_correction(weights, counts)
assert np.allclose(weights, [
[
0,
10.304 + np.log(31415 / 45701),
Expand All @@ -823,10 +823,10 @@ def test_weight_correction(self, params_dict):

# Case 2: Perform weight correction (N_cutoff not reached by both N_k and N_{k-1})
REXEE.verbose = True
weights_2 = [[0, 10.304, 20.073, 29.364]]
counts_2 = [[3141, 4570, 5545, 5955]]
weights_2 = REXEE.weight_correction(weights_2, counts_2)
assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]])
weights = [[0, 10.304, 20.073, 29.364]]
counts = [[3141, 4570, 5545, 5955]]
weights = REXEE.weight_correction(weights, counts)
assert np.allclose(weights, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]])

def test_combine_weights(self, params_dict):
"""
Expand All @@ -841,12 +841,12 @@ def test_combine_weights(self, params_dict):

# Test 1: simple means
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]
w_1, g_vec_1 = REXEE.combine_weights(weights)
assert np.allclose(w_1, [
w, g_vec = REXEE.combine_weights(weights)
assert np.allclose(w, [
[0, 2.1, 3.9, 3.5],
[0, 1.8, 1.4, 2.75],
[0, -0.4, 0.95, 1.95]])
assert np.allclose(list(g_vec_1), [0, 2.1, 3.9, 3.5, 4.85, 5.85])
assert np.allclose(list(g_vec), [0, 2.1, 3.9, 3.5, 4.85, 5.85])

# Test 2: weighted means
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]
Expand Down
1 change: 0 additions & 1 deletion ensemble_md/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def _convert_to_numeric(s):
return s
except (ValueError, AttributeError):
pass
raise ValueError(f"Failed to convert {s} to a numeric value.")


def _get_subplot_dimension(n_panels):
Expand Down

0 comments on commit 5175648

Please sign in to comment.