Add unit tests for utils.py and tweak utils.py
wehs7661 committed Mar 26, 2024
1 parent 7e2fb7d commit 194235c
Showing 3 changed files with 116 additions and 6 deletions.
5 changes: 3 additions & 2 deletions ensemble_md/tests/test_gmx_parser.py
@@ -99,16 +99,17 @@ def test_write(self):
os.remove('test_1.mdp')
os.remove('test_2.mdp')


def test_compare_MDPs():
mdp_list = ['ensemble_md/tests/data/mdp/compare_1.mdp', 'ensemble_md/tests/data/mdp/compare_2.mdp', 'ensemble_md/tests/data/mdp/compare_3.mdp']
mdp_list = ['ensemble_md/tests/data/mdp/compare_1.mdp', 'ensemble_md/tests/data/mdp/compare_2.mdp', 'ensemble_md/tests/data/mdp/compare_3.mdp'] # noqa: E501
result_1 = gmx_parser.compare_MDPs(mdp_list[:2], print_diff=True)
result_2 = gmx_parser.compare_MDPs(mdp_list[1:], print_diff=True)
dict_1 = {} # the first two are the same but just in different formats
dict_2 = {
'nstdhdl': [100, 10],
'wl_oneovert': [None, 'yes'],
'weight_equil_wl_delta': [None, 0.001],
'init_lambda_weights': [[0.0, 57.88597, 112.71883, 163.84425, 210.48097, 253.80261, 294.79849, 333.90408, 370.82669, 406.02515, 438.53116, 468.53751, 496.24649, 521.58417, 544.57404, 565.26697, 583.7337, 599.60651, 613.43958, 624.70471, 633.95947, 638.29785, 642.44977, 646.33551, 649.91626, 651.54779, 652.93359, 654.13263, 654.94073, 655.13086, 655.07239, 654.66443, 653.68683, 652.32123, 650.72308, 649.2381, 647.94586, 646.599, 645.52063, 643.99133], None],
'init_lambda_weights': [[0.0, 57.88597, 112.71883, 163.84425, 210.48097, 253.80261, 294.79849, 333.90408, 370.82669, 406.02515, 438.53116, 468.53751, 496.24649, 521.58417, 544.57404, 565.26697, 583.7337, 599.60651, 613.43958, 624.70471, 633.95947, 638.29785, 642.44977, 646.33551, 649.91626, 651.54779, 652.93359, 654.13263, 654.94073, 655.13086, 655.07239, 654.66443, 653.68683, 652.32123, 650.72308, 649.2381, 647.94586, 646.599, 645.52063, 643.99133], None], # noqa: E501
'wl_ratio': [None, 0.7],
'lmc_weights_equil': [None, 'wl_delta'],
'lmc_stats': ['no', 'wang_landau'],
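(Aside: the assertions in this test imply that gmx_parser.compare_MDPs returns a dictionary keyed by MDP parameter name, mapping each differing parameter to its value in each of the compared files; two files that are identical up to formatting, as with dict_1 above, yield an empty dictionary.)
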
97 changes: 97 additions & 0 deletions ensemble_md/tests/test_utils.py
@@ -10,10 +10,15 @@
"""
Unit tests for the module utils.py.
"""
import os
import sys
import shutil
import pytest
import tempfile
import subprocess
import numpy as np
from ensemble_md.utils import utils
from unittest.mock import patch, MagicMock


def test_logger():
@@ -39,6 +44,37 @@ def test_logger():
sys.stdout = sys.__stdout__


def test_run_gmx_cmd_success():
# Mock the subprocess.run return value for a successful execution
mock_successful_return = MagicMock()
mock_successful_return.returncode = 0
mock_successful_return.stdout = "Simulation complete"
mock_successful_return.stderr = None

with patch('subprocess.run', return_value=mock_successful_return) as mock_run:
return_code, stdout, stderr = utils.run_gmx_cmd(['gmx', 'mdrun', '-deffnm', 'sys'])

mock_run.assert_called_once_with(['gmx', 'mdrun', '-deffnm', 'sys'], capture_output=True, text=True, input=None, check=True) # noqa: E501
assert return_code == 0
assert stdout == "Simulation complete"
assert stderr is None


def test_run_gmx_cmd_failure():
# Mock the subprocess.run to raise a CalledProcessError for a failed execution
mock_failed_return = MagicMock()
mock_failed_return.returncode = 1
mock_failed_return.stderr = "Error encountered"

with patch('subprocess.run') as mock_run:
mock_run.side_effect = [subprocess.CalledProcessError(mock_failed_return.returncode, 'cmd', stderr=mock_failed_return.stderr)] # noqa: E501
return_code, stdout, stderr = utils.run_gmx_cmd(['gmx', 'mdrun', '-deffnm', 'sys'])

assert return_code == 1
assert stdout is None
assert stderr == "Error encountered"

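Taken together, these two tests pin down the contract of utils.run_gmx_cmd: it returns a (return_code, stdout, stderr) tuple, and a subprocess.CalledProcessError is converted into a (nonzero code, None, stderr) result rather than propagated. A minimal sketch consistent with the tests (an assumption, not necessarily the actual ensemble_md implementation; the prompt_input parameter is hypothetical):

import subprocess

def run_gmx_cmd_sketch(arguments, prompt_input=None):
    # check=True makes subprocess.run raise CalledProcessError on a nonzero
    # exit code, mirroring the call asserted in test_run_gmx_cmd_success.
    try:
        result = subprocess.run(arguments, capture_output=True, text=True, input=prompt_input, check=True)
        return result.returncode, result.stdout, result.stderr
    except subprocess.CalledProcessError as e:
        # Matches test_run_gmx_cmd_failure: nonzero return code, no stdout, captured stderr.
        return e.returncode, None, e.stderr
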

def test_format_time():
assert utils.format_time(0) == "0.0 second(s)"
assert utils.format_time(1) == "1.0 second(s)"
@@ -96,3 +132,64 @@ def test_weighted_mean():
mean, err = utils.weighted_mean(vals, errs)
assert np.isclose(mean, 2.9997333688841485)
assert np.isclose(err, 0.0577311783020254)

# 3. 0 in errs
vals = [1, 2, 3, 4]
errs = [0, 0.1, 0.1, 0.1]
mean, err = utils.weighted_mean(vals, errs)
assert mean == 2.5
assert err is None

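The three cases above are consistent with standard inverse-variance weighting, with a fallback to the unweighted mean when any uncertainty is zero (a zero error would give an infinite weight). A sketch of that logic, assuming this is what utils.weighted_mean implements:

import numpy as np

def weighted_mean_sketch(vals, errs):
    if any(e == 0 for e in errs):
        # A zero uncertainty makes 1/err**2 undefined; fall back to the plain
        # mean with no propagated error, matching Case 3 (mean == 2.5, err is None).
        return float(np.mean(vals)), None
    w = 1 / np.asarray(errs) ** 2                    # inverse-variance weights
    mean = np.sum(w * np.asarray(vals)) / np.sum(w)  # weighted mean
    err = np.sqrt(1 / np.sum(w))                     # uncertainty of the weighted mean
    return mean, err
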

def test_calc_rmse():
# Test 1
data = [1, 2, 3, 4, 5]
ref = [2, 4, 6, 8, 10]
expected_rmse = np.sqrt(np.mean((np.array(data) - np.array(ref)) ** 2))
assert utils.calc_rmse(data, ref) == expected_rmse

# Test 2
ref = [1, 2, 3, 4, 5]
expected_rmse = 0
assert utils.calc_rmse(data, ref) == expected_rmse

# Test 3
data = [1, 2, 3]
ref = [1, 2]
with pytest.raises(ValueError):
utils.calc_rmse(data, ref)

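As a quick sanity check on Test 1: the residuals data - ref are (-1, -2, -3, -4, -5), their squares are (1, 4, 9, 16, 25), the mean square is 55/5 = 11, so the expected RMSE is sqrt(11) ≈ 3.3166.
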

def test_get_time_metrics():
log = 'ensemble_md/tests/data/log/EXE.log'
t_metrics = {
'performance': 23.267,
't_wall': 3.721,
't_core': 29.713
}
assert utils.get_time_metrics(log) == t_metrics

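For context, get_time_metrics parses the accounting block that GROMACS prints at the end of a log file; with the values asserted above, the relevant lines would look roughly like the following (illustrative only; exact spacing varies between GROMACS versions):

               Core t (s)   Wall t (s)        (%)
       Time:       29.713        3.721      798.5
                 (ns/day)    (hour/ns)
Performance:       23.267        1.031
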

def test_analyze_REXEE_time():
# Set up directories and files
dirs = [f'ensemble_md/tests/data/log/sim_{i}/iteration_{j}' for i in range(2) for j in range(2)]
files = [f'ensemble_md/tests/data/log/EXE_{i}.log' for i in range(4)]
for i in range(4):
os.makedirs(dirs[i])
shutil.copy(files[i], os.path.join(dirs[i], 'EXE.log'))

# Test analyze_REXEE_time
# Case 1: Wrong paths
with pytest.raises(FileNotFoundError, match="No sim/iteration directories found."):
t_1, t_2, t_3 = utils.analyze_REXEE_time() # This will try to find files from [natsort.natsorted(glob.glob(f'sim_*/iteration_{i}/*log')) for i in range(n_iter)] # noqa: E501

# Case 2: Correct paths
log_files = [[f'ensemble_md/tests/data/log/sim_{i}/iteration_{j}/EXE.log' for i in range(2)] for j in range(2)]
t_1, t_2, t_3 = utils.analyze_REXEE_time(log_files=log_files)
assert t_1 == 2.125
assert np.isclose(t_2, 0.175)
assert t_3 == [[1.067, 0.94], [1.01, 1.058]]

# Clean up
for i in range(2):
shutil.rmtree(f'ensemble_md/tests/data/log/sim_{i}')
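
The directory layout built here matches what the fallback glob in analyze_REXEE_time, sim_*/iteration_{i}/*log, expects: one log per replica per iteration, e.g. ensemble_md/tests/data/log/sim_0/iteration_0/EXE.log through sim_1/iteration_1/EXE.log. The expected values also follow directly from t_3, the per-iteration wall times of each replica: the total wall time t_1 sums the slowest replica in each iteration (1.067 + 1.058 = 2.125), and the synchronization cost t_2 sums the waiting time of the faster replicas ((1.067 - 0.94) + (1.058 - 1.01) = 0.175), consistent with all replicas having to finish an iteration before the next one starts.
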
20 changes: 16 additions & 4 deletions ensemble_md/utils/utils.py
@@ -284,19 +284,23 @@ def get_time_metrics(log):
if 'Time: ' in l:
t_metrics['t_core'] = float(l.split()[1]) # s
t_metrics['t_wall'] = float(l.split()[2]) # s
break

return t_metrics


def analyze_REXEE_time(log_files=None):
def analyze_REXEE_time(n_iter=None, log_files=None):
"""
Perform simple data analysis on the wall times and performances of all iterations of an REXEE simulation.
Parameters
----------
n_iter : None or int
The number of iterations in the REXEE simulation. If None, the function infers it from the length of
:code:`log_files` when that is provided, or otherwise by counting the directories named
:code:`iteration_*` under the simulation directory :code:`sim_0` in the current working directory.
log_files : None or list
A list of sorted file names of all log files.
A list of lists of log files with shape :code:`(n_iter, n_replicas)`. If None, the function will try to
find the log files by searching the current working directory.
Returns
-------
@@ -308,10 +312,18 @@ def analyze_REXEE_time(log_files=None):
t_wall_list : list
The list of wall times for finishing each mdrun command.
"""
n_iter = len(glob.glob('sim_0/iteration_*'))
if n_iter is None:
if log_files is None:
n_iter = len(glob.glob('sim_0/iteration_*'))
else:
n_iter = len(log_files)

if log_files is None:
log_files = [natsort.natsorted(glob.glob(f'sim_*/iteration_{i}/*log')) for i in range(n_iter)]

if len(log_files) == 0:
raise FileNotFoundError("No sim/iteration directories found.")

t_wall_list = []
t_wall_tot, t_sync = 0, 0
for i in range(n_iter):
