Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Checkpointing #69

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ Here is the help message of :code:`run_REXEE`:
whole-range alchemical weights. This file is a necessary input
if one wants to update the file when extending a weight-
updating simulation. (Default: g_vecs.npy)
-e EQUIL, --equil EQUIL
The file path of the NPY file containing the equilibration times for all simulations
when completing a variable weight REXEE simulation. (Default: equil.npy)
-o OUTPUT, --output OUTPUT
The file path of the output file for logging how replicas
interact with each other. (Default: run_REXEE_log.txt)
Expand Down Expand Up @@ -139,6 +142,7 @@ two files (generated by the existing simulation) as necessary checkpoints:

* One NPY file containing the replica-space trajectories of different configurations, as specified in the input YAML file.
* One NPY file containing the time series of the whole-range alchemical weights, as specified in the input YAML file. This is only needed for extending a weight-updating REXEE simulation.
* One NPY file containing the equilibration times for all simulations or -1 if the simulation is not yet equilibrated. This is only needed if extending a weight-updating REXEE simulation.

In the CLI :code:`run_REXEE`, the class :class:`.ReplicaExchangeEE` is instantiated with the given YAML file, where
the user needs to specify how the replicas should be set up or interact with each
Expand Down
22 changes: 18 additions & 4 deletions ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def initialize(args):
help='The file path of the NPY file containing the time series of the whole-range\
alchemical weights. This file is a necessary input if one wants to update the \
file when extending a weight-updating simulation. (Default: g_vecs.npy)')
parser.add_argument('-e',
'--equil',
type=str,
default='equil.npy',
help='The file path of the NPY file containing the equilibration times for all simulations \
when completing a variable weight REXEE simulation. (Default: equil.npy)')
parser.add_argument('-o',
'--output',
type=str,
Expand Down Expand Up @@ -124,11 +130,15 @@ def main():
shutil.rmtree(f'{REXEE.working_dir}/sim_{i}/iteration_{j}')

# Read g_vecs.npy and rep_trajs.npy so that new data can be appended, if any.
# Read equil.npy if running variable weight REXEE simulations
# Note that these two arrays are created in rank 0 and should always be operated in rank 0,
# or broadcasting is required.
REXEE.rep_trajs = [list(i) for i in ckpt_data]
if os.path.isfile(args.g_vecs) is True:
REXEE.g_vecs = [list(i) for i in np.load(args.g_vecs)]
if REXEE.fixed_weights is not True and os.path.isfile(args.equil) is True:
REXEE.equil = np.load(args.equil)
print(REXEE.equil)
else:
start_idx = None

Expand Down Expand Up @@ -322,16 +332,20 @@ def main():
if (i + 1) % REXEE.n_ckpt == 0:
if len(REXEE.g_vecs) != 0:
# Save g_vec as a function of time if weight combination was used.
np.save('g_vecs.npy', REXEE.g_vecs)
np.save(args.g_vecs, REXEE.g_vecs)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the input argument g_vecs was intended to be for specifying an input file like g_vecs.npy for extending the simulation (same for the rep_trajs), but I do think it makes sense to have them specify the output files to be saved as well. I'll tweak the help message of the argument parser.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm actually I think the current help message is fine.


print('\n----- Saving .npy files to checkpoint the simulation ---')
np.save('rep_trajs.npy', REXEE.rep_trajs)
np.save(args.ckpt, REXEE.rep_trajs)
if REXEE.fixed_weights is not True:
np.save(args.equil, REXEE.equil)

# Save the npy files at the end of the simulation anyway.
if rank == 0:
if len(REXEE.g_vecs) != 0: # The length will be 0 only if there is no weight combination.
np.save('g_vecs.npy', REXEE.g_vecs)
np.save('rep_trajs.npy', REXEE.rep_trajs)
np.save(args.g_vecs, REXEE.g_vecs)
np.save(args.ckpt, REXEE.rep_trajs)
if REXEE.fixed_weights is not True:
np.save(args.equil, REXEE.equil)

# Step 5: Write a summary for the simulation ensemble
if rank == 0:
Expand Down
11 changes: 10 additions & 1 deletion ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,7 +1340,16 @@ def combine_weights(self, weights, weights_err=None, print_values=True):
dg_vec.append(utils.weighted_mean(dg_list, dg_err_list)[0])

dg_vec.insert(0, 0)
g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))])
nan_loc = [i for i, x in enumerate(dg_vec) if np.isnan(x)]
if len(nan_loc) != 0:
g_vec = np.zeros(len(dg_vec))
for i in range(1, len(dg_vec)):
if i in nan_loc:
continue
else:
g_vec[i] = g_vec[i-1] + dg_vec[i]
else:
g_vec = np.array([sum(dg_vec[:(i + 1)]) for i in range(len(dg_vec))])

# (3) Determine the vector of alchemical weights for each replica
weights_modified = np.zeros_like(weights)
Expand Down
14 changes: 14 additions & 0 deletions ensemble_md/tests/test_replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,20 @@ def test_combine_weights(self, params_dict):
[0, -0.40723412, 0.95296164, 1.95296164]])
assert np.allclose(list(g_vec_2), [0, 2.1, 3.861407249466951, 3.4541731330165306, 4.814368891580968, 5.814368891580968]) # noqa: E501

# Test 3: MT-REXEE
REXEE.n_tot = 27
REXEE.n_sub = 9
REXEE.s = 9
REXEE.n_sim = 3
REXEE.state_ranges = [[0, 1, 2, 3, 4, 5, 6, 7, 8],
[9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26]]
weights = [[0.0, 0.03813, 0.08744, 0.58124, 1.42376, 1.86745, 3.24031, 3.09609, 3.5702],
[0.0, 0.26761, 0.12656, 1.28994, 2.58161, 3.40522, 5.14597, 5.1271, 4.71232],
[0.0, -0.24036, -0.21415, 0.90212, 1.88597, 3.5445, 4.62079, 4.48149, 4.84671]]
w_3, g_vec_3 = REXEE.combine_weights(weights)
assert np.allclose(list(g_vec_3), [0.0, 0.03813, 0.08744, 0.58124, 1.42376, 1.86745, 3.24031, 3.09609, 3.5702, 0.0, 0.26761, 0.12656, 1.28994, 2.58161, 3.40522, 5.14597, 5.1271, 4.71232, 0.0, -0.24036, -0.21415, 0.90212, 1.88597, 3.5445, 4.62079, 4.48149, 4.84671]) # noqa: E501

def test_histogram_correction(self, params_dict):
REXEE = get_REXEE_instance(params_dict)
REXEE.n_tot = 6
Expand Down
Loading