From 517564866fdf6868b2d8bcc5b7193f633334fae5 Mon Sep 17 00:00:00 2001 From: Wei-Tse Hsu Date: Mon, 22 Apr 2024 17:17:50 +0800 Subject: [PATCH] Further pushed the code coverage --- ensemble_md/tests/test_analyze_free_energy.py | 42 ++++++++++++- ensemble_md/tests/test_analyze_matrix.py | 16 +++++ ensemble_md/tests/test_analyze_traj.py | 59 +++++++++++++++--- ensemble_md/tests/test_gmx_parser.py | 14 +++++ ensemble_md/tests/test_replica_exchange_EE.py | 60 +++++++++---------- ensemble_md/utils/utils.py | 1 - 6 files changed, 149 insertions(+), 43 deletions(-) diff --git a/ensemble_md/tests/test_analyze_free_energy.py b/ensemble_md/tests/test_analyze_free_energy.py index 30f1b9fc..e674d7ac 100644 --- a/ensemble_md/tests/test_analyze_free_energy.py +++ b/ensemble_md/tests/test_analyze_free_energy.py @@ -24,12 +24,12 @@ @patch('ensemble_md.analysis.analyze_free_energy.extract_dHdl') @patch('ensemble_md.analysis.analyze_free_energy.detect_equilibration') @patch('ensemble_md.analysis.analyze_free_energy.subsample_correlated_data') -def test_preprocess_data(mock_corr, mock_equil, mock_extract_dHdl, mock_extract_u_nk, mock_subsampling, mock_alchemlyb, capfd): # noqa: E501 +def test_preprocess_data(mock_corr, mock_equil, mock_extract_dhdl, mock_extract_u_nk, mock_subsampling, mock_alchemlyb, capfd): # noqa: E501 mock_data, mock_data_series = MagicMock(), MagicMock() mock_alchemlyb.concat.return_value = mock_data mock_subsampling.u_nk2series.return_value = mock_data_series mock_subsampling._prepare_input.return_value = (mock_data, mock_data_series) - mock_equil.return_value = (10, 5, 50) # t, g, Neff_max + mock_equil.return_value = (10, 5, 18) # t, g, Neff_max mock_data_series.__len__.return_value = 100 # For one of the print statements # Set slicing to return different mock objects based on input @@ -53,6 +53,7 @@ def generic_list_slicing(key): mock_data_series.__getitem__.side_effect = slicing_side_effect # so that we can use mock_data_series[t:] mock_data_series_equil = mock_data_series[10:] # Mock the equilibrated data series, given t=10 + # Case 1: data_type = u_nk files = [[f'ensemble_md/tests/data/dhdl/simulation_example/sim_{i}/iteration_{j}/dhdl.xvg' for j in range(3)] for i in range(4)] # noqa: E501 results = analyze_free_energy.preprocess_data(files, 300, 'u_nk') @@ -74,13 +75,48 @@ def generic_list_slicing(key): assert ' Adopted spacing: 1' in out assert ' 10.0% of the u_nk data was in the equilibrium region and therfore discarded.' in out # noqa: E501 assert ' Statistical inefficiency of u_nk: 5.0' in out - assert ' Number of effective samples: 50' in out + assert ' Number of effective samples: 18' in out assert mock_corr.call_args_list[i] == call(mock_data_series_equil, g=5) assert len(results[0]) == 4 assert results[1] == [10, 10, 10, 10] assert results[2] == [5, 5, 5, 5] + # Case 2: data_type = dHdl + mock_alchemlyb.concat.reset_mock() + mock_subsampling._prepare_input.reset_mock() + mock_subsampling.slicing.reset_mock() + mock_equil.reset_mock() + + mock_subsampling.dhdl2series.return_value = mock_data_series + mock_subsampling._prepare_input.return_value = (mock_data, mock_data_series) + mock_data_series.__len__.return_value = 200 + mock_data_series.values.__len__.return_value = 200 + + results = analyze_free_energy.preprocess_data(files, 300, 'dhdl', t=10, g=5) + out, err = capfd.readouterr() + + for i in range(4): + for j in range(3): + assert mock_extract_dhdl.call_args_list[i * 3 + j] == call(files[i][j], T=300) + assert mock_subsampling._prepare_input.call_args_list[i] == call(mock_data, mock_data_series, drop_duplicates=True, sort=True) # noqa: E501 + assert mock_subsampling.slicing.call_args_list[2 * i] == call(mock_data, step=1) + assert mock_subsampling.slicing.call_args_list[2 * i + 1] == call(mock_data_series, step=1) + assert 'Subsampling and decorrelating the concatenated dhdl data ...' in out + assert ' Adopted spacing: 1' in out + assert ' 5.0% of the dhdl data was in the equilibrium region and therfore discarded.' in out # noqa: E501 + assert ' Statistical inefficiency of dhdl: 5.0' in out + assert ' Number of effective samples: 38' in out + assert mock_corr.call_args_list[i] == call(mock_data_series_equil, g=5) + + assert len(results[0]) == 4 + assert results[1] == [] + assert results[2] == [] + + # Case 3: Invalid data_type + with pytest.raises(ValueError, match="Invalid data_type. Expected 'u_nk' or 'dhdl'."): + analyze_free_energy.preprocess_data(files, 300, 'xyz') + @pytest.mark.parametrize("method, expected_estimator", [ ("MBAR", "MBAR estimator"), diff --git a/ensemble_md/tests/test_analyze_matrix.py b/ensemble_md/tests/test_analyze_matrix.py index f7523ab1..7ab9bd84 100644 --- a/ensemble_md/tests/test_analyze_matrix.py +++ b/ensemble_md/tests/test_analyze_matrix.py @@ -70,6 +70,10 @@ def test_calc_transmtx(): assert B3 is None assert C3 is None + # Case 4: Invalid simulation type + with pytest.raises(ValueError, match='Invalid simulation type test.'): + analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/EXE.log'), simulation_type='test') + def test_calc_equil_prob(capfd): # Case 1: Right stochastic @@ -87,6 +91,18 @@ def test_calc_equil_prob(capfd): assert 'The input transition matrix is neither right nor left stochastic' in out +def test_calc_t_relax(): + # Case 1: spectral_gap_err is specified + results = analyze_matrix.calc_t_relax(0.5, 0.1, 0.1) + assert results[0] == 0.2 + assert results[1] == 0.1 * 0.1 / 0.5 ** 2 + + # Case 2: spectral_gap_err is not specified + results = analyze_matrix.calc_t_relax(0.5, 0.1) + assert results[0] == 0.2 + assert results[1] is None + + def test_calc_spectral_gap(capfd): # Case 1 (sanity check): doublly stochastic mtx = np.array([[0.5, 0.5], [0.5, 0.5]]) diff --git a/ensemble_md/tests/test_analyze_traj.py b/ensemble_md/tests/test_analyze_traj.py index 64d8f5e2..e3135707 100644 --- a/ensemble_md/tests/test_analyze_traj.py +++ b/ensemble_md/tests/test_analyze_traj.py @@ -153,7 +153,38 @@ def test_stitch_time_series_for_sim(): assert os.path.exists('state_trajs_for_sim.npy') os.remove('state_trajs_for_sim.npy') - # Test 2: Test for discontinuous time series + # Test 2: The case where dhdl is False + # Here we again use dhdl.xvg files but use dhdl=False with col_idx=1, which corresponds to the state index + trajs = analyze_traj.stitch_time_series_for_sim(files, dhdl=False, col_idx=1) + + trajs[0] == [ + 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, + 1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 1, 0, 1, 1, + 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 0, 1 + ] + + trajs[1] == [ + 1, 1, 2, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, + 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 2, 3, 3, 3, 2, 2, + 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 0, 1 + ] + + trajs[2] == [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 2, 3, 4, 3, 3, 2, + 3, 3, 2, 2, 2, 3, 4, 3, 4, 4, 5, 5, 5, 5, 4, 3, 4, 3, 3, 4, 4 + ] + + trajs[3] == [ + 3, 3, 3, 3, 3, 3, 3, 5, 4, 4, 5, 4, 4, 5, 4, 5, 5, 5, 4, 5, + 4, 4, 5, 4, 5, 5, 4, 5, 5, 5, 4, 5, 5, 4, 5, 4, 5, 4, 5, 5, + 6, 6, 6, 5, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 6, 7, 6, 7 + ] + + assert os.path.exists('state_trajs_for_sim.npy') + os.remove('state_trajs_for_sim.npy') + + # Test 3: Test for discontinuous time series # Here, for sim_2, we exclude the last 5 lines for the dhdl.xvg file in iteration_1 to create a gap save_and_exclude(f'{folder}/sim_2/iteration_1/dhdl.xvg', 5) os.rename(f'{folder}/sim_2/iteration_1/dhdl.xvg', f'{folder}/sim_2/iteration_1/dhdl_temp.xvg') @@ -624,16 +655,26 @@ def test_plot_transit_time(mock_plt): assert mock_plt.ylabel.call_args_list[0] == call('Average transit time from states 0 to k (step)') assert mock_plt.ylabel.call_args_list[1] == call('Average transit time from states k to 0 (step)') assert mock_plt.ylabel.call_args_list[2] == call('Average round-trip time (step)') + assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(3)] == [ + './t_0k.png', + './t_k0.png', + './t_roundtrip.png', + ] # Case 2: dt = 0.2 ps, fig_prefix = 'test', here we just test the return values mock_plt.reset_mock() - t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2) + t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2, fig_prefix='test') t_1_, t_2_, t_3_ = [[1.0, 1.4], [0.8, 0.8]], [[0.8, 0.6], [1.2]], [[1.8, 2.0], [2.0]] for i in range(2): np.testing.assert_array_almost_equal(t_1[i], t_1_[i]) np.testing.assert_array_almost_equal(t_2[i], t_2_[i]) np.testing.assert_array_almost_equal(t_3[i], t_3_[i]) assert u == 'ps' + assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(3)] == [ + './test_t_0k.png', + './test_t_k0.png', + './test_t_roundtrip.png', + ] # Case 3: dt = 200 ps, long trajs mock_plt.reset_mock() @@ -661,7 +702,7 @@ def test_plot_transit_time(mock_plt): # Case 5: More than 100 round trips so that a histogram is plotted mock_plt.reset_mock() trajs = np.array([[0, 1, 2, 3, 2] * 20000, [0, 1, 3, 2, 1] * 20000]) - t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N) + t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, fig_prefix='test') assert t_1 == [[3] * 20000, [2] * 20000] assert t_2 == [[2] * 19999, [3] * 19999] @@ -715,12 +756,12 @@ def test_plot_transit_time(mock_plt): ] assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(6)] == [ - './t_0k.png', - './hist_t_0k.png', - './t_k0.png', - './hist_t_k0.png', - './t_roundtrip.png', - './hist_t_roundtrip.png' + './test_t_0k.png', + './test_hist_t_0k.png', + './test_t_k0.png', + './test_hist_t_k0.png', + './test_t_roundtrip.png', + './test_hist_t_roundtrip.png' ] diff --git a/ensemble_md/tests/test_gmx_parser.py b/ensemble_md/tests/test_gmx_parser.py index 2ce72239..1437e4ee 100644 --- a/ensemble_md/tests/test_gmx_parser.py +++ b/ensemble_md/tests/test_gmx_parser.py @@ -81,6 +81,20 @@ def test_write(self): mdp = gmx_parser.MDP("ensemble_md/tests/data/expanded.mdp") mdp.write('test_1.mdp', skipempty=False) mdp.write('test_2.mdp', skipempty=True) + + assert os.path.isfile('test_1.mdp') + assert os.path.isfile('test_2.mdp') + + mdp = gmx_parser.MDP('test_1.mdp') + mdp.write(skipempty=True) # This should overwrite the file + + # Check if the files are the same + with open('test_1.mdp', 'r') as f: + lines_1 = f.readlines() + with open('test_2.mdp', 'r') as f: + lines_2 = f.readlines() + assert lines_1 == lines_2 + os.remove('test_1.mdp') os.remove('test_2.mdp') diff --git a/ensemble_md/tests/test_replica_exchange_EE.py b/ensemble_md/tests/test_replica_exchange_EE.py index d2148b69..2b0e9150 100644 --- a/ensemble_md/tests/test_replica_exchange_EE.py +++ b/ensemble_md/tests/test_replica_exchange_EE.py @@ -232,17 +232,17 @@ def test_set_params_warnings(self, params_dict): params_dict['mdp'] = 'ensemble_md/tests/data/expanded_test.mdp' params_dict['N_cutoff'] = 1000 - REXEE_1 = get_REXEE_instance(params_dict) + REXEE = get_REXEE_instance(params_dict) warning_1 = 'Warning: The weight correction/weight combination method is specified but will not be used since the weights are fixed.' # noqa: E501 warning_2 = 'Warning: We recommend setting lmc_seed as -1 so the random seed is different for each iteration.' warning_3 = 'Warning: We recommend setting gen_seed as -1 so the random seed is different for each iteration.' warning_4 = 'Warning: We recommend generating new velocities for each iteration to avoid potential issues with the detailed balance.' # noqa: E501 - assert warning_1 in REXEE_1.warnings - assert warning_2 in REXEE_1.warnings - assert warning_3 in REXEE_1.warnings - assert warning_4 in REXEE_1.warnings + assert warning_1 in REXEE.warnings + assert warning_2 in REXEE.warnings + assert warning_3 in REXEE.warnings + assert warning_4 in REXEE.warnings os.remove(os.path.join(input_path, "expanded_test.mdp")) @@ -426,7 +426,7 @@ def test_print_params(self, capfd, params_dict): # capfd is a fixture in pytest for testing STDOUT REXEE = get_REXEE_instance(params_dict) REXEE.print_params() - out_1, err = capfd.readouterr() + out, err = capfd.readouterr() L = "" L += "Important parameters of REXEE\n=============================\n" L += f"Python version: {sys.version}\n" @@ -451,11 +451,11 @@ def test_print_params(self, capfd, params_dict): L += "Alchemical ranges of each replica in REXEE:\n - Replica 0: States [0, 1, 2, 3, 4, 5]\n" L += " - Replica 1: States [1, 2, 3, 4, 5, 6]\n - Replica 2: States [2, 3, 4, 5, 6, 7]\n" L += " - Replica 3: States [3, 4, 5, 6, 7, 8]\n" - assert out_1 == L + assert out == L REXEE.reformatted_mdp = True # Just to test the case where REXEE.reformatted_mdp is True REXEE.print_params(params_analysis=True) - out_2, err = capfd.readouterr() + out, err = capfd.readouterr() L += "\nWhether to build Markov state models and perform relevant analysis: False\n" L += "Whether to perform free energy calculations: False\n" L += "The step to used in subsampling the DHDL data in free energy calculations, if any: 1\n" @@ -464,7 +464,7 @@ def test_print_params(self, capfd, params_dict): L += "The number of bootstrap iterations in the boostrapping method, if used: 50\n" L += "The random seed to use in bootstrapping, if used: None\n" L += "Note that the input MDP file has been reformatted by replacing hypens with underscores. The original mdp file has been renamed as *backup.mdp.\n" # noqa: E501 - assert out_2 == L + assert out == L REXEE.gro = ['ensemble_md/tests/data/sys.gro', 'ensemble_md/tests/data/sys.gro'] # noqa: E501 REXEE.top = ['ensemble_md/tests/data/sys.top', 'ensemble_md/tests/data/sys.top'] @@ -642,20 +642,20 @@ def test_identify_swappable_pairs(self, params_dict): states = [4, 2, 2, 7] # This would lead to the swappables: [(0, 1), (0, 2), (1, 2)] in the standard case # Case 1: Any case that is not neighboring swap has the same definition for the swappable pairs - swappables_1 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring') - assert swappables_1 == [(0, 1), (0, 2), (1, 2)] + swappables = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring') + assert swappables == [(0, 1), (0, 2), (1, 2)] # Case 2: Neighboring exchange REXEE.proposal = 'neighboring' - swappables_2 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring') - assert swappables_2 == [(0, 1), (1, 2)] + swappables = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring') + assert swappables == [(0, 1), (1, 2)] # Case 3: Non-neighboring exchange, with add_swappables REXEE.proposal = 'exhaustive' REXEE.add_swappables = [[3, 7], [4, 7]] states = [4, 3, 2, 7] # Without add_swappables, the swappables would be [(0, 1), (0, 2), (1, 2)] - swappables_3 = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring', REXEE.add_swappables) # noqa: E501 - assert swappables_3 == [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3)] + swappables = REXEE.identify_swappable_pairs(states, REXEE.state_ranges, REXEE.proposal == 'neighboring', REXEE.add_swappables) # noqa: E501 + assert swappables == [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3)] def test_propose_swap(self, params_dict): random.seed(0) @@ -777,18 +777,18 @@ def test_calc_prob_acc(self, capfd, params_dict): # Test 1 swap = (0, 1) - prob_acc_1 = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts) + prob_acc = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts) out, err = capfd.readouterr() # dU = (-9.1366697 + 11.0623788)/2.4777098766670016 ~ 0.7772 kT, so p_acc = 0.45968522728859024 - assert prob_acc_1 == pytest.approx(0.45968522728859024) + assert prob_acc == pytest.approx(0.45968522728859024) assert 'U^i_n - U^i_m = -3.69 kT, U^j_m - U^j_n = 4.46 kT, Total dU: 0.78 kT' in out # Test 2 swap = (0, 2) - prob_acc_2 = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts) + prob_acc = REXEE.calc_prob_acc(swap, dhdl_files, states, shifts) out, err = capfd.readouterr() # dU = (-9.1366697 + 4.9963939)/2.4777098766670016 ~ -1.6710 kT, so p_acc = 1 - assert prob_acc_2 == 1 + assert prob_acc == 1 def test_accept_or_reject(self, params_dict): REXEE = get_REXEE_instance(params_dict) @@ -809,10 +809,10 @@ def test_weight_correction(self, params_dict): # Case 1: Perform weight correction (N_cutoff reached) REXEE.N_cutoff = 5000 REXEE.verbose = False # just to increase code coverage - weights_1 = [[0, 10.304, 20.073, 29.364]] - counts_1 = [[31415, 45701, 55457, 59557]] - weights_1 = REXEE.weight_correction(weights_1, counts_1) - assert np.allclose(weights_1, [ + weights = [[0, 10.304, 20.073, 29.364]] + counts = [[31415, 45701, 55457, 59557]] + weights = REXEE.weight_correction(weights, counts) + assert np.allclose(weights, [ [ 0, 10.304 + np.log(31415 / 45701), @@ -823,10 +823,10 @@ def test_weight_correction(self, params_dict): # Case 2: Perform weight correction (N_cutoff not reached by both N_k and N_{k-1}) REXEE.verbose = True - weights_2 = [[0, 10.304, 20.073, 29.364]] - counts_2 = [[3141, 4570, 5545, 5955]] - weights_2 = REXEE.weight_correction(weights_2, counts_2) - assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]]) + weights = [[0, 10.304, 20.073, 29.364]] + counts = [[3141, 4570, 5545, 5955]] + weights = REXEE.weight_correction(weights, counts) + assert np.allclose(weights, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]]) def test_combine_weights(self, params_dict): """ @@ -841,12 +841,12 @@ def test_combine_weights(self, params_dict): # Test 1: simple means weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] - w_1, g_vec_1 = REXEE.combine_weights(weights) - assert np.allclose(w_1, [ + w, g_vec = REXEE.combine_weights(weights) + assert np.allclose(w, [ [0, 2.1, 3.9, 3.5], [0, 1.8, 1.4, 2.75], [0, -0.4, 0.95, 1.95]]) - assert np.allclose(list(g_vec_1), [0, 2.1, 3.9, 3.5, 4.85, 5.85]) + assert np.allclose(list(g_vec), [0, 2.1, 3.9, 3.5, 4.85, 5.85]) # Test 2: weighted means weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] diff --git a/ensemble_md/utils/utils.py b/ensemble_md/utils/utils.py index f072c8ae..e74f5579 100644 --- a/ensemble_md/utils/utils.py +++ b/ensemble_md/utils/utils.py @@ -157,7 +157,6 @@ def _convert_to_numeric(s): return s except (ValueError, AttributeError): pass - raise ValueError(f"Failed to convert {s} to a numeric value.") def _get_subplot_dimension(n_panels):