From 33e1e6e3a9999de63de54e8e97b304ebb3971e87 Mon Sep 17 00:00:00 2001 From: Wei-Tse Hsu Date: Tue, 28 Mar 2023 01:28:25 -0600 Subject: [PATCH] Implemented the method of exhaustive swaps --- docs/api/api_ensemble_EXE.rst | 4 +- docs/theory.rst | 39 ++++++-- ensemble_md/ensemble_EXE.py | 132 ++++++++++++++----------- ensemble_md/tests/test_ensemble_EXE.py | 65 ++++++++---- 4 files changed, 157 insertions(+), 83 deletions(-) diff --git a/docs/api/api_ensemble_EXE.rst b/docs/api/api_ensemble_EXE.rst index 04f96008..2dc5bdba 100644 --- a/docs/api/api_ensemble_EXE.rst +++ b/docs/api/api_ensemble_EXE.rst @@ -1,4 +1,6 @@ ensemble\_md.ensemble_EXE ========================= -.. automodule:: ensemble_md.ensemble_EXE \ No newline at end of file +.. automodule:: ensemble_md.ensemble_EXE + :members: + :undoc-members: \ No newline at end of file diff --git a/docs/theory.rst b/docs/theory.rst index b4c25f04..8d6ac23b 100644 --- a/docs/theory.rst +++ b/docs/theory.rst @@ -154,16 +154,25 @@ For each proposed swap, we calculate the acceptance ratio to decide whether the In greater detail, this scheme can be decomposed into the following steps: - **Step 1**: Identify the list of swappable pairs. - - **Step 2**: Randomly draw a pair from the list of swappable pairs. - - **Step 3**: Update the list of swappable pairs by removing pair(s) involving replicas drawn in Step 2. - - **Step 4**: Repeat Step 2 and 3 until the list of swappable pairs is empty. - - **Step 5**: For each of the pairs drawn in Step 2, calculate the accpetance ratio (using the specified acceptance scheme) to decide whether the coordinates - of the pair of replicas should be swapped. + - **Step 2**: Randomly draw a pair from the list of swappable pairs. + - **Step 3**: Calculate the acceptance ratio for the drawn pair to decide whether the swap should be accepted. + Then, perform or reject the swap. 
+ - **Step 4**: Update the list of swappable pairs by removing pair(s) that involve any replica in the drawn pair in Step 2. + - **Step 5**: Repeat Steps 2 to 4 until the list of swappable pairs is empty. -Note that in this method, - - - No replicas should be involved in more than one proposed swap. - - Given :math:`N` alchemical intermediate states in total, one can at most perform :math:`\lfloor N \rfloor` swaps. +Note that + + - In this method, no replicas should be involved in more than one proposed swap. + - Given :math:`N_{\text{rep}}` replicas in total, one can at most perform :math:`\lfloor N_{\text{rep}}/2 \rfloor` swaps with this method, since each swap involves two distinct replicas. + - While this method can lead to multiple attempted swaps, these swaps are entirely independent of each other, which is + different from the method of multiple swaps introduced below. + - Importantly, whether the swap in Step 3 is accepted or rejected does not influence the update of the list in Step 4 at all. + This is different from the method of multiple swaps introduced in the next section, where the updated list of swappable pairs depends on + the acceptance/rejection of the current attempted swap. + - Since all swaps are independent, instead of calculating an acceptance ratio and performing the swap separately for each drawn pair (as done in Step 3 in the procedure above), one + can choose to calculate the acceptance ratios for all drawn pairs first and perform all swaps at the same time at the end. + We chose to implement the former in :obj:`.get_swapping_pattern` since this is more consistent with the protocol of the other proposal schemes + and hence easier to code. .. _doc_multiple_swaps: @@ -390,6 +399,18 @@ values from the log file because alchemical weights from in the log files corres sampling different alchemical ranges would have different references. Therefore, only values such as :math:`g^i_n-g^i_m` and :math:`g^j_m-g^j_n` make sense, even if they are as interesting as :math:`g^i_n-g^j_n` and :math:`g^j_m-g^i_m`. +2.4. 
How is swapping performed? +------------------------------- +As implied in :ref:`doc_basic_idea`, in an EEXE simulation, we could either choose to swap configurations +(via swapping GRO files) or replicas (via swapping MDP files). In this package, we chose the former when +implementing the EEXE algorithm. Specifically, in the CLI :code:`run_EEXE`, the function :obj:`.get_swapping_pattern` +is called once for each iteration and returns a list :code:`swap_pattern` that informs :code:`run_EEXE` how +the GRO files should be swapped. (To better understand the list :code:`swap_pattern`, see the docstring of +the function :obj:`.get_swapping_pattern`.) Internally, the function :obj:`.get_swapping_pattern` not only swaps elements of +the list :code:`swap_pattern` when an attempted move is accepted, but also swaps elements in the lists that contain +state shifts, weights, paths to the DHDL files, state ranges, and the attribute :code:`configs`, but not the elements +in the list of states. Check the source code of :obj:`.get_swapping_pattern` if you want to understand the details. + .. _doc_w_schemes: 3. Weight combination diff --git a/ensemble_md/ensemble_EXE.py b/ensemble_md/ensemble_EXE.py index 0399f69d..e2160572 100644 --- a/ensemble_md/ensemble_EXE.py +++ b/ensemble_md/ensemble_EXE.py @@ -329,6 +329,7 @@ def print_params(self, params_analysis=False): print(f'Simulation inputs: {self.gro}, {self.top}, {self.mdp}') print(f"Verbose log file: {self.verbose}") print(f"Whether the replicas run in parallel: {self.parallel}") + print(f"Proposal scheme: {self.proposal}") print(f"Acceptance scheme for swapping simulations: {self.acceptance}") print(f"Scheme for combining weights: {self.w_scheme}") print(f"Histogram cutoff: {self.N_cutoff}") @@ -610,18 +611,18 @@ def propose_swap(self, swappables): def get_swapping_pattern(self, dhdl_files, states, weights): """ - A list (:code:`swap_pattern`) that represents how the configurations should be swapped in the next iteration. 
- The indices of the output list correspond to the simulation/replica indices, and the values represent the - configuration indices in the corresponding simulation/replica. For example, if the swapping pattern is - :code:`[0, 2, 1, 3]`, it means that in the next iteration, replicas 0, 1, 2, 3 should sample - configurations 0, 2, 1, 3, respectively, where configurations 0, 1, 2, 3 here are defined as whatever + Generates a list (:code:`swap_pattern`) that represents how the configurations should be swapped in the + next iteration. The indices of the output list correspond to the simulation/replica indices, and the + values represent the configuration indices in the corresponding simulation/replica. For example, if the + swapping pattern is :code:`[0, 2, 1, 3]`, it means that in the next iteration, replicas 0, 1, 2, 3 should + sample configurations 0, 2, 1, 3, respectively, where configurations 0, 1, 2, 3 here are defined as whatever configurations are in replicas 0, 1, 2, 3 in the CURRENT iteration (not iteration 0), respectively. - In an EEXE simulation, where each iteration requires one call of this function, :code:`swap_pattern` is always - initialized as :code:`[0, 1, 2, 3, ...]` and gets updated once every attempted swap. On the other hand, the - attribute :code:`configs`, which is only initialized once at the very beginning of the EEXE simulation - (iteration 0), gets updated once every iteration (which could include multiple attempted swaps when - the parameter :code:`n_ex` is larger than 1 (and :code:`proposal` is :code:`multiple`).) + Notably, when this function is called (e.g. once every iteration in an EEXE simulation), the output + list :code:`swap_pattern` is always initialized as :code:`[0, 1, 2, 3, ...]` and gets updated once every + attempted swap. 
This is different from the attribute :code:`configs`, which is only initialized at the + very beginning of the entire EEXE simulation (iteration 0), though :code:`configs` also gets updated with + :code:`swap_pattern` once every attempted swap in this function. Parameters ---------- @@ -641,10 +642,14 @@ def get_swapping_pattern(self, dhdl_files, states, weights): swap_pattern : list A list that represents how the replicas should be swapped. """ - if self.proposal != 'multiple': - n_ex = 1 - else: # multiple swaps + if self.proposal == 'exhaustive': + n_ex = int(np.floor(self.n_sim / 2)) # This is the maximum, not necessarily the number that will always be reached. # noqa + n_ex_exhaustive = 0 # The actual number of swaps atttempted. + else: + n_ex = 1 # single swap or neighboring swap + else: + # multiple swaps if self.n_ex == 'N^3': n_ex = self.n_tot ** 3 else: @@ -664,52 +669,69 @@ def get_swapping_pattern(self, dhdl_files, states, weights): print('n_ex is set back to 1 since there is only 1 swappable pair.') n_ex = 1 - self.n_swap_attempts += n_ex - print(f"Swappable pairs: {swappables}") for i in range(n_ex): - swap = self.propose_swap(swappables) - print(f'\nProposed swap: {swap}') - if swap == []: - print('No swap is proposed because there is no swappable pair at all.') - break # no need to re-identify swappable pairs and draw new samples + # Update the list of swappable pairs starting from the 2nd attempted swap for the exhaustive swap method. + if self.proposal == 'exhaustive' and i >= 1: + # Note that this should be done regardless of the acceptance/rejection of the previously drawn pairs. + # Also note that at this point, swap is still the last attempted swap. + swappables = [i for i in swappables if set(i).intersection(set(swap)) == set()] # noqa: F821 + print(f'\nRemaining swappable pairs: {swappables}') + + if len(swappables) == 0 and self.proposal == 'exhaustive': + # This should only happen when the method of exhaustive swaps is used. 
+ print(f'{n_ex_exhaustive} swap(s) have been attempted to exhaustively explore all possible swaps.') + if i == 0: + self.n_swap_attempts += 1 + self.n_rejected += 1 + break else: - if self.verbose is True: - print(f'A swap ({i + 1}/{n_ex}) is proposed between the configurations of Simulation {swap[0]} (state {states[swap[0]]}) and Simulation {swap[1]} (state {states[swap[1]]}) ...') # noqa: E501 - - # Calculate the acceptance ratio and decide whether to accept the swap. - prob_acc = self.calc_prob_acc(swap, dhdl_files, states, shifts, weights) - swap_bool = self.accept_or_reject(prob_acc) - - # Theoretically, in an EEXE simulation, we could either choose to swap configurations (via - # swapping GRO files) or replicas (via swapping MDP files). In ensemble_md package, we chose the - # former when implementing the EEXE algorithm. Specifically, in the CLI `run_EEXE`, `swap_pattern` - # is used to swap the GRO files. Therefore, when an attempted swap is accetped and `swap_pattern` - # is updated, we also need to update the variables `shifts`, `weights`, `dhdl_files`, `state_ranges`, - # `self.configs` but not anything else. Otherwise, incorrect results will be produced. To better - # understand this, one can refer to our unit test for get_swapping_pattern and calc_prob_acc, set - # checkpoints and examine why the variables should/should not be updated. - - if swap_bool is True: - # The assignments need to be done at the same time in just one line. 
- # states[swap[0]], states[swap[1]] = states[swap[1]], states[swap[0]] - shifts[swap[0]], shifts[swap[1]] = shifts[swap[1]], shifts[swap[0]] - weights[swap[0]], weights[swap[1]] = weights[swap[1]], weights[swap[0]] - dhdl_files[swap[0]], dhdl_files[swap[1]] = dhdl_files[swap[1]], dhdl_files[swap[0]] - swap_pattern[swap[0]], swap_pattern[swap[1]] = swap_pattern[swap[1]], swap_pattern[swap[0]] - state_ranges[swap[0]], state_ranges[swap[1]] = state_ranges[swap[1]], state_ranges[swap[0]] - self.configs[swap[0]], self.configs[swap[1]] = self.configs[swap[1]], self.configs[swap[0]] - - if n_ex > 1: # must be multiple swaps - # After state_ranges have been updated, we re-identify the swappable pairs. - # Notably, states_copy (instead of states) should be used. (They could be different.) - swappables = self.identify_swappable_pairs(states_copy, state_ranges) - print(f" New swappable pairs: {swappables}") + self.n_swap_attempts += 1 + if self.proposal == 'exhaustive': + n_ex_exhaustive += 1 + + swap = self.propose_swap(swappables) + print(f'\nProposed swap: {swap}') + if swap == []: + print('No swap is proposed because there is no swappable pair at all.') + break # no need to re-identify swappable pairs and draw new samples else: - # In this case, there is no need to update the swappables - pass - - print(f' Current list of configurations: {self.configs}') + if self.verbose is True and self.proposal != 'exhaustive': + print(f'A swap ({i + 1}/{n_ex}) is proposed between the configurations of Simulation {swap[0]} (state {states[swap[0]]}) and Simulation {swap[1]} (state {states[swap[1]]}) ...') # noqa: E501 + + # Calculate the acceptance ratio and decide whether to accept the swap. + prob_acc = self.calc_prob_acc(swap, dhdl_files, states, shifts, weights) + swap_bool = self.accept_or_reject(prob_acc) + + # Theoretically, in an EEXE simulation, we could either choose to swap configurations (via + # swapping GRO files) or replicas (via swapping MDP files). 
In ensemble_md package, we chose the + # former when implementing the EEXE algorithm. Specifically, in the CLI `run_EEXE`, `swap_pattern` + # is used to swap the GRO files. Therefore, when an attempted swap is accetped and `swap_pattern` + # is updated, we also need to update the variables `shifts`, `weights`, `dhdl_files`, + # `state_ranges`, `self.configs` but not anything else. Otherwise, incorrect results will be + # produced. To better understand this, one can refer to our unit test for get_swapping_pattern + # and calc_prob_acc, set checkpoints and examine why the variables should/should not be updated. + + if swap_bool is True: + # The assignments need to be done at the same time in just one line. + # states[swap[0]], states[swap[1]] = states[swap[1]], states[swap[0]] + shifts[swap[0]], shifts[swap[1]] = shifts[swap[1]], shifts[swap[0]] + weights[swap[0]], weights[swap[1]] = weights[swap[1]], weights[swap[0]] + dhdl_files[swap[0]], dhdl_files[swap[1]] = dhdl_files[swap[1]], dhdl_files[swap[0]] + swap_pattern[swap[0]], swap_pattern[swap[1]] = swap_pattern[swap[1]], swap_pattern[swap[0]] + state_ranges[swap[0]], state_ranges[swap[1]] = state_ranges[swap[1]], state_ranges[swap[0]] + self.configs[swap[0]], self.configs[swap[1]] = self.configs[swap[1]], self.configs[swap[0]] + + if n_ex > 1 and self.proposal == 'multiple': # must be multiple swaps + # After state_ranges have been updated, we re-identify the swappable pairs. + # Notably, states_copy (instead of states) should be used. (They could be different.) 
+ swappables = self.identify_swappable_pairs(states_copy, state_ranges) + print(f" New swappable pairs: {swappables}") + else: + # In this case, there is no need to update the swappables + pass + + print(f' Current list of configurations: {self.configs}') if self.verbose is False: print(f'\n{n_ex} swap(s) have been proposed.') diff --git a/ensemble_md/tests/test_ensemble_EXE.py b/ensemble_md/tests/test_ensemble_EXE.py index 16af8360..1e620c16 100644 --- a/ensemble_md/tests/test_ensemble_EXE.py +++ b/ensemble_md/tests/test_ensemble_EXE.py @@ -13,6 +13,7 @@ import os import sys import yaml +import copy import random import shutil import pytest @@ -321,6 +322,7 @@ def test_print_params(self, capfd, params_dict): L += f"Python version: {sys.version}\ngmxapi version: {gmx.__version__}\nensemble_md version: {ensemble_md.__version__}\n" # noqa: E501 L += "Simulation inputs: ensemble_md/tests/data/sys.gro, ensemble_md/tests/data/sys.top, ensemble_md/tests/data/expanded.mdp\n" # noqa: E501 L += "Verbose log file: True\nWhether the replicas run in parallel: False\n" + L += "Proposal scheme: exhaustive\n" L += "Acceptance scheme for swapping simulations: metropolis\nScheme for combining weights: None\n" L += "Histogram cutoff: 1000\nNumber of replicas: 4\nNumber of iterations: 10\n" L += "Number of attempted swaps in one exchange interval: N^3\n" @@ -461,9 +463,8 @@ def test_propose_swap(self, params_dict): assert swap_2 == (1, 2) def test_get_swapping_pattern(self, params_dict): - EEXE = get_EEXE_instance(params_dict) # state_ranges: 0-5, 1-6, ..., 3-8 - - # weights are obtained from the log files in data/log + # weights are obtained from the log files in data/log, where the last states are 5, 2, 2, 8 (global indices) + # state_ranges are: 0-5, 1-6, ..., 3-8 weights = [ [0, 1.03101, 2.55736, 3.63808, 4.47220, 6.13408], [0, 1.22635, 2.30707, 2.44120, 4.10308, 6.03106], @@ -472,54 +473,82 @@ def test_get_swapping_pattern(self, params_dict): dhdl_files = 
[os.path.join(input_path, f"dhdl/dhdl_{i}.xvg") for i in range(4)] # Case 1: Empty swap list + EEXE = get_EEXE_instance(params_dict) EEXE.verbose = False states = [0, 6, 7, 8] # No swappable pairs - pattern_1 = EEXE.get_swapping_pattern(dhdl_files, states, weights) + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_1 = EEXE.get_swapping_pattern(f, states, w) assert EEXE.n_swap_attempts == 1 - assert EEXE.n_rejected == 0 + assert EEXE.n_rejected == 1 assert pattern_1 == [0, 1, 2, 3] # Case 2: Single swap (proposal = 'single') random.seed(0) + EEXE = get_EEXE_instance(params_dict) EEXE.verbose = True EEXE.proposal = 'single' # n_ex will be set to 1 automatically. states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (0, 2), (1, 2)], swap = (1, 2), accept - pattern_2 = EEXE.get_swapping_pattern(dhdl_files, states, weights) - assert EEXE.n_swap_attempts == 2 + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_2 = EEXE.get_swapping_pattern(f, states, w) + assert EEXE.n_swap_attempts == 1 assert EEXE.n_rejected == 0 assert pattern_2 == [0, 2, 1, 3] # Case 3: Neighboring swap random.seed(0) + EEXE = get_EEXE_instance(params_dict) EEXE.proposal = 'neighboring' # n_ex will be set to 1 automatically. states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (1, 2)], swap = (1, 2), accept - pattern_3 = EEXE.get_swapping_pattern(dhdl_files, states, weights) - assert EEXE.n_swap_attempts == 3 + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_3 = EEXE.get_swapping_pattern(f, states, w) + assert EEXE.n_swap_attempts == 1 assert EEXE.n_rejected == 0 assert pattern_3 == [0, 2, 1, 3] - # Case 4: Exhaustive swap + # Case 4-1: Exhaustive swaps that end up in a single swap random.seed(0) - EEXE.proposal = 'exhaustive' # n_ex will be set to 1 automatically. - # To be added. 
:) + EEXE = get_EEXE_instance(params_dict) + EEXE.proposal = 'exhaustive' + states = [5, 2, 2, 8] # swappable pairs: [(0, 1), (0, 2), (1, 2)], swap = (1, 2), accept + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_4_1 = EEXE.get_swapping_pattern(f, states, w) + assert EEXE.n_swap_attempts == 1 + assert EEXE.n_rejected == 0 + assert pattern_4_1 == [0, 2, 1, 3] + + # Case 4-2: Exhaustive swaps that involve multiple attempted swaps + random.seed(0) + EEXE = get_EEXE_instance(params_dict) + EEXE.proposal = 'exhaustive' + states = [4, 2, 4, 3] # swappable pairs: [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3)]; swap 1: (2, 3), accepted; swap 2: (0, 1), accept # noqa: E501 + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_4_2 = EEXE.get_swapping_pattern(f, states, w) + assert EEXE.n_swap_attempts == 2 # \Delta is negative for both swaps -> both accepted + assert EEXE.n_rejected == 0 + assert pattern_4_2 == [1, 0, 3, 2] # Case 5-1: Multiple swaps (proposal = 'multiple', n_ex = 5) + print('test 5-1') random.seed(0) + EEXE = get_EEXE_instance(params_dict) EEXE.n_ex = 5 EEXE.proposal = 'multiple' states = [3, 1, 4, 6] # swappable pairs: [(0, 1), (0, 2), (1, 2)], first swap = (1, 2), accept - pattern_5_1 = EEXE.get_swapping_pattern(dhdl_files, states, weights) - assert EEXE.n_swap_attempts == 8 + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_5_1 = EEXE.get_swapping_pattern(f, states, w) + assert EEXE.n_swap_attempts == 5 assert EEXE.n_rejected == 4 assert pattern_5_1 == [2, 1, 0, 3] - # Case 5-2: Multiple swaps but only one swappable pair (proposal = 'multiple') + # Case 5-2: Multiple swaps but with only one swappable pair (proposal = 'multiple') # This is specifically for testing n_swap_attempts in the case where n_ex reduces to 1. 
random.seed(0) + EEXE = get_EEXE_instance(params_dict) states = [0, 2, 3, 8] # The only swappable pair is [(1, 2)] --> accept - pattern_5_2 = EEXE.get_swapping_pattern(dhdl_files, states, weights) - assert EEXE.n_swap_attempts == 9 - assert EEXE.n_rejected == 4 + w, f = copy.deepcopy(weights), copy.deepcopy(dhdl_files) + pattern_5_2 = EEXE.get_swapping_pattern(f, states, w) + assert EEXE.n_swap_attempts == 1 # since there is only 1 swappable pair + assert EEXE.n_rejected == 0 assert pattern_5_2 == [0, 2, 1, 3] def test_calc_prob_acc(self, capfd, params_dict):