Skip to content

Commit

Permalink
Merge pull request #58 from wehs7661/pr-57
Browse files Browse the repository at this point in the history
MT-REXEE development
  • Loading branch information
wehs7661 authored Oct 17, 2024
2 parents 9958558 + 98af948 commit 7d5d699
Show file tree
Hide file tree
Showing 49 changed files with 57,786 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: lint

on:
push:
pull_request:

jobs:

Expand Down
2 changes: 2 additions & 0 deletions ensemble_md/cli/run_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def main():
os.mkdir(f'{REXEE.working_dir}/sim_{i}/iteration_0')
MDP = REXEE.initialize_MDP(i)
MDP.write(f"{REXEE.working_dir}/sim_{i}/iteration_0/expanded.mdp", skipempty=True)
if REXEE.modify_coords == 'default' and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')): # noqa: E501
REXEE.process_top()

# 2-2. Run the first set of simulations
REXEE.run_REXEE(0)
Expand Down
166 changes: 156 additions & 10 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import warnings
import importlib
import subprocess
import mdtraj as md
import pandas as pd
import numpy as np
from mpi4py import MPI
from itertools import combinations
Expand All @@ -29,6 +31,7 @@
import ensemble_md
from ensemble_md.utils import utils
from ensemble_md.utils import gmx_parser
from ensemble_md.utils import coordinate_swap
from ensemble_md.utils.exceptions import ParameterError

comm = MPI.COMM_WORLD
Expand Down Expand Up @@ -157,6 +160,8 @@ def set_params(self, analysis):
optional_args = {
"add_swappables": None,
"modify_coords": None,
"resname_list": None,
"swap_rep_pattern": None,
"nst_sim": None,
"proposal": 'exhaustive',
"w_combine": False,
Expand Down Expand Up @@ -254,6 +259,10 @@ def set_params(self, analysis):
raise ParameterError(f"The parameter '{i}' should be a boolean variable.")

params_list = ['add_swappables', 'df_ref']
if self.resname_list is not None:
params_list.append('resname_list')
if self.swap_rep_pattern is not None:
params_list.append('swap_rep_pattern')
for i in params_list:
if getattr(self, i) is not None and not isinstance(getattr(self, i), list):
raise ParameterError(f"The parameter '{i}' should be a list.")
Expand Down Expand Up @@ -441,17 +450,24 @@ def set_params(self, analysis):

# 7-12. External module for coordinate modification
if self.modify_coords is not None:
module_file = os.path.basename(self.modify_coords)
module_dir = os.path.dirname(self.modify_coords)
if module_dir not in sys.path:
sys.path.append(module_dir) # so that the module can be imported
module_name = os.path.splitext(module_file)[0]
module = importlib.import_module(module_name)
if not hasattr(module, module_name):
err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501
raise ParameterError(err_msg)
if self.modify_coords == 'default':
if self.swap_rep_pattern is None and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')): # noqa: E501
raise ParameterError('swap_rep_pattern option must be filled in if using default swapping function and not swap guide') # noqa: E501
if self.resname_list is None and (not os.path.exists('residue_connect.csv') or not os.path.exists('residue_swap_map.csv')): # noqa: E501
raise ParameterError('resname_list option must be filled in if using default swapping function and not swap guide') # noqa: E501
self.modify_coords_fn = self.default_coords_fn
else:
self.modify_coords_fn = getattr(module, module_name)
module_file = os.path.basename(self.modify_coords)
module_dir = os.path.dirname(self.modify_coords)
if module_dir not in sys.path:
sys.path.append(module_dir) # so that the module can be imported
module_name = os.path.splitext(module_file)[0]
module = importlib.import_module(module_name)
if not hasattr(module, module_name):
err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501
raise ParameterError(err_msg)
else:
self.modify_coords_fn = getattr(module, module_name)
else:
self.modify_coords_fn = None

Expand Down Expand Up @@ -1496,3 +1512,133 @@ def run_REXEE(self, n, swap_pattern=None):
# want it to start parsing the dhdl file (in the if condition of if rank == 0) of simulation 3 being run by
# rank 3 that has not been generated, which will lead to an I/O error.
comm.barrier()

def default_coords_fn(self, molA_file_name, molB_file_name):
"""
Swaps coordinates between two GRO files.
Parameters
----------
molA_file_name : str
GRO file name for the moleucle to be swapped.
molB_file_name : str
GRO file name for the other moleucle to be swapped.
"""
# Determine name for transformed residue
molA_dir = molA_file_name.rsplit('/', 1)[0] + '/'
molB_dir = molB_file_name.rsplit('/', 1)[0] + '/'

# Load trajectory trr for higher precison coordinates
molA = md.load_trr(f'{molA_dir}/traj.trr', top=molA_file_name).slice(-1) # Load last frame of trr trajectory
molB = md.load_trr(f'{molB_dir}/traj.trr', top=molB_file_name).slice(-1)

# Load the coordinate swapping map
connection_map = pd.read_csv('residue_connect.csv')
swap_map = pd.read_csv('residue_swap_map.csv')

# Step 1: Read the GRO input coordinate files and open temporary Output files
molA_file = open(molA_file_name, 'r').readlines() # open input file
molB_new_file_name = 'B_hybrid_swap.gro'
molB_new = open(molB_new_file_name, 'w')
molB_file = open(molB_file_name, 'r').readlines() # open input file
molA_new_file_name = 'A_hybrid_swap.gro'
molA_new = open(molA_new_file_name, 'w')

# Step 2: Determine atoms for alignment and swapping
residue_options = swap_map['Swap A'].to_list() + swap_map['Swap B'].to_list()
nameA = coordinate_swap.identify_res(molA.topology, residue_options)
nameB = coordinate_swap.identify_res(molB.topology, residue_options)
df_atom_swap = coordinate_swap.find_common(molA_file, molB_file, nameA, nameB)

# Step 3: Fix break if present for solvated systems only
if len(molA.topology.select('water')) != 0:
A_dimensions = coordinate_swap.get_dimensions(molA_file)
B_dimensions = coordinate_swap.get_dimensions(molB_file)
molA = coordinate_swap.fix_break(molA, nameA, A_dimensions, connection_map[connection_map['Resname'] == nameA]) # noqa: E501
molB = coordinate_swap.fix_break(molB, nameB, B_dimensions, connection_map[connection_map['Resname'] == nameB]) # noqa: E501

# Step 4: Determine coordinates of atoms which need to be reconstructed as we swap coordinates between molecules # noqa: E501
miss_B = df_atom_swap[(df_atom_swap['Swap'] == 'B2A') & (df_atom_swap['Direction'] == 'miss')]['Name'].to_list() # noqa: E501
miss_A = df_atom_swap[(df_atom_swap['Swap'] == 'A2B') & (df_atom_swap['Direction'] == 'miss')]['Name'].to_list() # noqa: E501
if len(miss_B) != 0:
df_atom_swap = coordinate_swap.get_miss_coord(molB, molA, nameB, nameA, df_atom_swap, 'B2A', swap_map[(swap_map['Swap A'] == nameB) & (swap_map['Swap B'] == nameA)]) # noqa: E501
if len(miss_A) != 0:
df_atom_swap = coordinate_swap.get_miss_coord(molA, molB, nameA, nameB, df_atom_swap, 'A2B', swap_map[(swap_map['Swap A'] == nameA) & (swap_map['Swap B'] == nameB)]) # noqa: E501

# Step 5: Parse Current file to ensure atoms are added in the correct order
atom_order_A = gmx_parser.deter_atom_order(molA_file, nameA)
atom_order_B = gmx_parser.deter_atom_order(molB_file, nameB)

# Step 6: Write the new file
# Reprint preamble text
line_start = coordinate_swap.print_preamble(molA_file, molB_new, len(miss_B), len(miss_A))

# Print new coordinates to file for molB
coordinate_swap.write_new_file(df_atom_swap, 'A2B', 'B2A', line_start, molA_file, molB_new, nameA, nameB, copy.deepcopy(molA.xyz[0]), miss_A, atom_order_B) # noqa: E501

# Print new coordinates to file
# Reprint preamble text
line_start = coordinate_swap.print_preamble(molB_file, molA_new, len(miss_A), len(miss_B))

# Print new coordinates for molA
coordinate_swap.write_new_file(df_atom_swap, 'B2A', 'A2B', line_start, molB_file, molA_new, nameB, nameA, copy.deepcopy(molB.xyz[0]), miss_B, atom_order_A) # noqa: E501

# Rename temp files
os.rename('A_hybrid_swap.gro', molB_dir + '/confout.gro')
os.rename('B_hybrid_swap.gro', molA_dir + '/confout.gro')

def process_top(self):
"""
Processes the input topologies in order to determine the atoms for alignment in the default GRO swapping
function. Output as csv files to prevent needing to re-run this step.
"""
if not os.path.exists('residue_connect.csv'):
df_top = pd.DataFrame()
for f, file_name in enumerate(self.top):
# Read file
input_file = gmx_parser.read_top(file_name, self.resname_list[f])

# Determine the atom names corresponding to the atom numbers
start_line, atom_name, state = coordinate_swap.get_names(input_file)

# Determine the connectivity of all atoms
connect_1, connect_2, state_1, state_2 = [], [], [], [] # Atom 1 and atom 2 which are connected and which state they are dummy atoms # noqa: E501
for l, line in enumerate(input_file[start_line:]): # noqa: E741
line_sep = line.split(' ')
if line_sep[0] == ';':
continue
if line_sep[0] == '\n':
break
while '' in line_sep:
line_sep.remove('')
connect_1.append(atom_name[int(line_sep[0])-1])
connect_2.append(atom_name[int(line_sep[1])-1])
state_1.append(state[int(line_sep[0])-1])
state_2.append(state[int(line_sep[1])-1])
df = pd.DataFrame({'Resname': self.resname_list[f], 'Connect 1': connect_1, 'Connect 2': connect_2, 'State 1': state_1, 'State 2': state_2}) # noqa: E501
df_top = pd.concat([df_top, df])
df_top.to_csv('residue_connect.csv')
else:
df_top = pd.read_csv('residue_connect.csv')

if not os.path.exists('residue_swap_map.csv'):
df_map = pd.DataFrame()
for swap in self.swap_rep_pattern:
# Determine atoms not present in both molecules
X, Y = [int(swap[0][0]), int(swap[1][0])]
lam = {X: int(swap[0][1]), Y: int(swap[1][1])}
for A, B in zip([X, Y], [Y, X]):
input_A = gmx_parser.read_top(self.top[A], self.resname_list[A])
start_line, A_name, state = coordinate_swap.get_names(input_A)
input_B = gmx_parser.read_top(self.top[B], self.resname_list[B])
start_line, B_name, state = coordinate_swap.get_names(input_B)

A_only = [x for x in A_name if x not in B_name]
B_only = [x for x in B_name if x not in A_name]

# Seperate real to dummy switches
df = coordinate_swap.determine_connection(A_only, B_only, self.resname_list[A], self.resname_list[B], df_top, lam[A]) # noqa: E501

df_map = pd.concat([df_map, df])

df_map.to_csv('residue_swap_map.csv')
Loading

0 comments on commit 7d5d699

Please sign in to comment.