Skip to content

Commit

Permalink
Merge pull request #69 from wehs7661/ckpt-restart
Browse files Browse the repository at this point in the history
Improve Checkpointing
  • Loading branch information
wehs7661 authored Oct 31, 2024
2 parents 7d5d699 + ba5d975 commit 4ffcb5c
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ Here is the help message of :code:`run_REXEE`:
whole-range alchemical weights. This file is a necessary input
if one wants to update the file when extending a weight-
updating simulation. (Default: g_vecs.npy)
-e EQUIL, --equil EQUIL
The file path of the NPY file containing the equilibration times for all simulations
when completing a variable weight REXEE simulation. (Default: equil.npy)
-o OUTPUT, --output OUTPUT
The file path of the output file for logging how replicas
interact with each other. (Default: run_REXEE_log.txt)
Expand Down Expand Up @@ -139,6 +142,7 @@ two files (generated by the existing simulation) as necessary checkpoints:

* One NPY file containing the replica-space trajectories of different configurations, as specified in the input YAML file.
* One NPY file containing the time series of the whole-range alchemical weights, as specified in the input YAML file. This is only needed for extending a weight-updating REXEE simulation.
* One NPY file containing the equilibration times for all simulations or -1 if the simulation is not yet equilibrated. This is only needed if extending a weight-updating REXEE simulation.

In the CLI :code:`run_REXEE`, the class :class:`.ReplicaExchangeEE` is instantiated with the given YAML file, where
the user needs to specify how the replicas should be set up or interact with each
Expand Down
22 changes: 18 additions & 4 deletions ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def initialize(args):
help='The file path of the NPY file containing the time series of the whole-range\
alchemical weights. This file is a necessary input if one wants to update the \
file when extending a weight-updating simulation. (Default: g_vecs.npy)')
parser.add_argument('-e',
'--equil',
type=str,
default='equil.npy',
help='The file path of the NPY file containing the equilibration times for all simulations \
when completing a variable weight REXEE simulation. (Default: equil.npy)')
parser.add_argument('-o',
'--output',
type=str,
Expand Down Expand Up @@ -124,11 +130,15 @@ def main():
shutil.rmtree(f'{REXEE.working_dir}/sim_{i}/iteration_{j}')

# Read g_vecs.npy and rep_trajs.npy so that new data can be appended, if any.
# Read equil.npy if running variable weight REXEE simulations
# Note that these two arrays are created in rank 0 and should always be operated in rank 0,
# or broadcasting is required.
REXEE.rep_trajs = [list(i) for i in ckpt_data]
if os.path.isfile(args.g_vecs) is True:
REXEE.g_vecs = [list(i) for i in np.load(args.g_vecs)]
if REXEE.fixed_weights is not True and os.path.isfile(args.equil) is True:
REXEE.equil = np.load(args.equil)
print(REXEE.equil)
else:
start_idx = None

Expand Down Expand Up @@ -322,16 +332,20 @@ def main():
if (i + 1) % REXEE.n_ckpt == 0:
if len(REXEE.g_vecs) != 0:
# Save g_vec as a function of time if weight combination was used.
np.save('g_vecs.npy', REXEE.g_vecs)
np.save(args.g_vecs, REXEE.g_vecs)

print('\n----- Saving .npy files to checkpoint the simulation ---')
np.save('rep_trajs.npy', REXEE.rep_trajs)
np.save(args.ckpt, REXEE.rep_trajs)
if REXEE.fixed_weights is not True:
np.save(args.equil, REXEE.equil)

# Save the npy files at the end of the simulation anyway.
if rank == 0:
if len(REXEE.g_vecs) != 0: # The length will be 0 only if there is no weight combination.
np.save('g_vecs.npy', REXEE.g_vecs)
np.save('rep_trajs.npy', REXEE.rep_trajs)
np.save(args.g_vecs, REXEE.g_vecs)
np.save(args.ckpt, REXEE.rep_trajs)
if REXEE.fixed_weights is not True:
np.save(args.equil, REXEE.equil)

# Step 5: Write a summary for the simulation ensemble
if rank == 0:
Expand Down
11 changes: 10 additions & 1 deletion ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,7 +1340,16 @@ def combine_weights(self, weights, weights_err=None, print_values=True):
dg_vec.append(utils.weighted_mean(dg_list, dg_err_list)[0])

dg_vec.insert(0, 0)
g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))])
nan_loc = [i for i, x in enumerate(dg_vec) if np.isnan(x)]
if len(nan_loc) != 0:
g_vec = np.zeros(len(dg_vec))
for i in range(1, len(dg_vec)):
if i in nan_loc:
continue
else:
g_vec[i] = g_vec[i-1] + dg_vec[i]
else:
g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))])

# (3) Determine the vector of alchemical weights for each replica
weights_modified = np.zeros_like(weights)
Expand Down
14 changes: 14 additions & 0 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,20 @@ def test_combine_weights(self, params_dict):
[0, -0.40723412, 0.95296164, 1.95296164]])
assert np.allclose(list(g_vec_2), [0, 2.1, 3.861407249466951, 3.4541731330165306, 4.814368891580968, 5.814368891580968]) # noqa: E501

# Test 3: MT-REXEE
REXEE.n_tot = 27
REXEE.n_sub = 9
REXEE.s = 9
REXEE.n_sim = 3
REXEE.state_ranges = [[0, 1, 2, 3, 4, 5, 6, 7, 8],
[9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26]]
weights = [[0.0, 0.03813, 0.08744, 0.58124, 1.42376, 1.86745, 3.24031, 3.09609, 3.5702],
[0.0, 0.26761, 0.12656, 1.28994, 2.58161, 3.40522, 5.14597, 5.1271, 4.71232],
[0.0, -0.24036, -0.21415, 0.90212, 1.88597, 3.5445, 4.62079, 4.48149, 4.84671]]
w_3, g_vec_3 = REXEE.combine_weights(weights)
assert np.allclose(list(g_vec_3), [0.0, 0.03813, 0.08744, 0.58124, 1.42376, 1.86745, 3.24031, 3.09609, 3.5702, 0.0, 0.26761, 0.12656, 1.28994, 2.58161, 3.40522, 5.14597, 5.1271, 4.71232, 0.0, -0.24036, -0.21415, 0.90212, 1.88597, 3.5445, 4.62079, 4.48149, 4.84671]) # noqa: E501

def test_histogram_correction(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
REXEE.n_tot = 6
Expand Down

0 comments on commit 4ffcb5c

Please sign in to comment.