diff --git a/ensemble_md/tests/test_gmx_parser.py b/ensemble_md/tests/test_gmx_parser.py
index 0caef7ee..1a929645 100644
--- a/ensemble_md/tests/test_gmx_parser.py
+++ b/ensemble_md/tests/test_gmx_parser.py
@@ -99,8 +99,9 @@ def test_write(self):
         os.remove('test_1.mdp')
         os.remove('test_2.mdp')

+
 def test_compare_MDPs():
-    mdp_list = ['ensemble_md/tests/data/mdp/compare_1.mdp', 'ensemble_md/tests/data/mdp/compare_2.mdp', 'ensemble_md/tests/data/mdp/compare_3.mdp']
+    mdp_list = ['ensemble_md/tests/data/mdp/compare_1.mdp', 'ensemble_md/tests/data/mdp/compare_2.mdp', 'ensemble_md/tests/data/mdp/compare_3.mdp']  # noqa: E501
     result_1 = gmx_parser.compare_MDPs(mdp_list[:2], print_diff=True)
     result_2 = gmx_parser.compare_MDPs(mdp_list[1:], print_diff=True)
     dict_1 = {}  # the first two are the same but just in different formats
@@ -108,7 +109,7 @@ def test_compare_MDPs():
         'nstdhdl': [100, 10],
         'wl_oneovert': [None, 'yes'],
         'weight_equil_wl_delta': [None, 0.001],
-        'init_lambda_weights': [[0.0, 57.88597, 112.71883, 163.84425, 210.48097, 253.80261, 294.79849, 333.90408, 370.82669, 406.02515, 438.53116, 468.53751, 496.24649, 521.58417, 544.57404, 565.26697, 583.7337, 599.60651, 613.43958, 624.70471, 633.95947, 638.29785, 642.44977, 646.33551, 649.91626, 651.54779, 652.93359, 654.13263, 654.94073, 655.13086, 655.07239, 654.66443, 653.68683, 652.32123, 650.72308, 649.2381, 647.94586, 646.599, 645.52063, 643.99133], None],
+        'init_lambda_weights': [[0.0, 57.88597, 112.71883, 163.84425, 210.48097, 253.80261, 294.79849, 333.90408, 370.82669, 406.02515, 438.53116, 468.53751, 496.24649, 521.58417, 544.57404, 565.26697, 583.7337, 599.60651, 613.43958, 624.70471, 633.95947, 638.29785, 642.44977, 646.33551, 649.91626, 651.54779, 652.93359, 654.13263, 654.94073, 655.13086, 655.07239, 654.66443, 653.68683, 652.32123, 650.72308, 649.2381, 647.94586, 646.599, 645.52063, 643.99133], None],  # noqa: E501
         'wl_ratio': [None, 0.7],
         'lmc_weights_equil': [None, 'wl_delta'],
         'lmc_stats': ['no', 'wang_landau'],
diff --git a/ensemble_md/tests/test_utils.py b/ensemble_md/tests/test_utils.py
index cdb13c56..495fc68d 100644
--- a/ensemble_md/tests/test_utils.py
+++ b/ensemble_md/tests/test_utils.py
@@ -10,10 +10,15 @@
 """
 Unit tests for the module utils.py.
""" +import os import sys +import shutil +import pytest import tempfile +import subprocess import numpy as np from ensemble_md.utils import utils +from unittest.mock import patch, MagicMock def test_logger(): @@ -39,6 +44,37 @@ def test_logger(): sys.stdout = sys.__stdout__ +def test_run_gmx_cmd_success(): + # Mock the subprocess.run return value for a successful execution + mock_successful_return = MagicMock() + mock_successful_return.returncode = 0 + mock_successful_return.stdout = "Simulation complete" + mock_successful_return.stderr = None + + with patch('subprocess.run', return_value=mock_successful_return) as mock_run: + return_code, stdout, stderr = utils.run_gmx_cmd(['gmx', 'mdrun', '-deffnm', 'sys']) + + mock_run.assert_called_once_with(['gmx', 'mdrun', '-deffnm', 'sys'], capture_output=True, text=True, input=None, check=True) # noqa: E501 + assert return_code == 0 + assert stdout == "Simulation complete" + assert stderr is None + + +def test_run_gmx_cmd_failure(): + # Mock the subprocess.run to raise a CalledProcessError for a failed execution + mock_failed_return = MagicMock() + mock_failed_return.returncode = 1 + mock_failed_return.stderr = "Error encountered" + + with patch('subprocess.run') as mock_run: + mock_run.side_effect = [subprocess.CalledProcessError(mock_failed_return.returncode, 'cmd', stderr=mock_failed_return.stderr)] # noqa: E501 + return_code, stdout, stderr = utils.run_gmx_cmd(['gmx', 'mdrun', '-deffnm', 'sys']) + + assert return_code == 1 + assert stdout is None + assert stderr == "Error encountered" + + def test_format_time(): assert utils.format_time(0) == "0.0 second(s)" assert utils.format_time(1) == "1.0 second(s)" @@ -96,3 +132,64 @@ def test_weighted_mean(): mean, err = utils.weighted_mean(vals, errs) assert np.isclose(mean, 2.9997333688841485) assert np.isclose(err, 0.0577311783020254) + + # 3. 
+    vals = [1, 2, 3, 4]
+    errs = [0, 0.1, 0.1, 0.1]
+    mean, err = utils.weighted_mean(vals, errs)
+    assert mean == 2.5
+    assert err is None
+
+
+def test_calc_rmse():
+    # Test 1
+    data = [1, 2, 3, 4, 5]
+    ref = [2, 4, 6, 8, 10]
+    expected_rmse = np.sqrt(np.mean((np.array(data) - np.array(ref)) ** 2))
+    assert utils.calc_rmse(data, ref) == expected_rmse
+
+    # Test 2
+    ref = [1, 2, 3, 4, 5]
+    expected_rmse = 0
+    assert utils.calc_rmse(data, ref) == expected_rmse
+
+    # Test 3
+    data = [1, 2, 3]
+    ref = [1, 2]
+    with pytest.raises(ValueError):
+        utils.calc_rmse(data, ref)
+
+
+def test_get_time_metrics():
+    log = 'ensemble_md/tests/data/log/EXE.log'
+    t_metrics = {
+        'performance': 23.267,
+        't_wall': 3.721,
+        't_core': 29.713
+    }
+    assert utils.get_time_metrics(log) == t_metrics
+
+
+def test_analyze_REXEE_time():
+    # Set up directories and files
+    dirs = [f'ensemble_md/tests/data/log/sim_{i}/iteration_{j}' for i in range(2) for j in range(2)]
+    files = [f'ensemble_md/tests/data/log/EXE_{i}.log' for i in range(4)]
+    for i in range(4):
+        os.makedirs(dirs[i])
+        shutil.copy(files[i], os.path.join(dirs[i], 'EXE.log'))
+
+    # Test analyze_REXEE_time
+    # Case 1: Wrong paths
+    with pytest.raises(FileNotFoundError, match="No sim/iteration directories found."):
+        t_1, t_2, t_3 = utils.analyze_REXEE_time()  # This will try to find files from [natsort.natsorted(glob.glob(f'sim_*/iteration_{i}/*log')) for i in range(n_iter)]  # noqa: E501
+
+    # Case 2: Correct paths
+    log_files = [[f'ensemble_md/tests/data/log/sim_{i}/iteration_{j}/EXE.log' for i in range(2)] for j in range(2)]
+    t_1, t_2, t_3 = utils.analyze_REXEE_time(log_files=log_files)
+    assert t_1 == 2.125
+    assert np.isclose(t_2, 0.175)
+    assert t_3 == [[1.067, 0.94], [1.01, 1.058]]
+
+    # Clean up
+    for i in range(2):
+        shutil.rmtree(f'ensemble_md/tests/data/log/sim_{i}')
diff --git a/ensemble_md/utils/utils.py b/ensemble_md/utils/utils.py
index c595c403..e84878ff 100644
--- a/ensemble_md/utils/utils.py
+++ b/ensemble_md/utils/utils.py
@@ -284,19 +284,23 @@ def get_time_metrics(log):
         if 'Time: ' in l:
             t_metrics['t_core'] = float(l.split()[1])  # s
             t_metrics['t_wall'] = float(l.split()[2])  # s
-            break

     return t_metrics


-def analyze_REXEE_time(log_files=None):
+def analyze_REXEE_time(n_iter=None, log_files=None):
     """
     Perform simple data analysis on the wall times and performances of all iterations of an REXEE simulation.

     Parameters
     ----------
+    n_iter : None or int
+        The number of iterations in the REXEE simulation. If None, the function will try to find the number of
+        iterations by counting the number of directories named "iteration_*" in the simulation directory
+        (i.e., :code:`sim_0`) in the current working directory or where the log files are located.
     log_files : None or list
-        A list of sorted file names of all log files.
+        A list of lists of log files with the shape of (n_iter, n_replicas). If None, the function will try to find
+        the log files by searching the current working directory.

     Returns
     -------
@@ -308,10 +312,18 @@ def analyze_REXEE_time(log_files=None):
     t_wall_tot : float
         The total wall time GROMACS spent to finish all iterations for the REXEE simulation.
     t_sync : float
         The total time spent in synchronizing all replicas, which is the sum of the differences between the longest time and the time of finishing each mdrun command.
     t_wall_list : list
         The list of wall times of finishing each mdrun command.
""" - n_iter = len(glob.glob('sim_0/iteration_*')) + if n_iter is None: + if log_files is None: + n_iter = len(glob.glob('sim_0/iteration_*')) + else: + n_iter = len(log_files) + if log_files is None: log_files = [natsort.natsorted(glob.glob(f'sim_*/iteration_{i}/*log')) for i in range(n_iter)] + if len(log_files) == 0: + raise FileNotFoundError("No sim/iteration directories found.") + t_wall_list = [] t_wall_tot, t_sync = 0, 0 for i in range(n_iter):