Skip to content

Commit

Permalink
Update tests to match changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Nov 4, 2024
1 parent 48ddbbb commit 1af9cd5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 19 deletions.
3 changes: 0 additions & 3 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,9 +1609,6 @@ def process_top(self):

# Determine the atom names corresponding to the atom numbers
start_line, atom_name, atom_num, state = coordinate_swap.get_names(input_file, self.resname_list[f])
print(f'start line #: {start_line}')
print(f'atom name: {atom_name}')
print(f'atom #: {atom_num}')

# Determine the connectivity of all atoms
connect_1, connect_2, state_1, state_2 = [], [], [], [] # Atom 1 and atom 2 which are connected and which state they are dummy atoms # noqa: E501
Expand Down
3 changes: 2 additions & 1 deletion ensemble_md/tests/test_coordinate_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def test_swap_name():

def test_get_names():
top_files = ['A-B.itp', 'B-C.itp', 'C-D.itp', 'D-E.itp', 'E-F.itp']
resnames = ['A2B', 'B2C', 'C2D', 'D2E', 'E2F']

start_lines = [26, 29, 33, 32, 36]
names = [['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'H1', 'H2', 'H3', 'H4', 'H17', 'DC7', 'HV5', 'HV6', 'HV7'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'H1', 'H2', 'H3', 'H4', 'H5', 'H6', 'H7', 'DC8', 'HV8', 'HV9', 'HV10'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'C8', 'H1', 'H2', 'H3', 'H4', 'H6', 'H7', 'H8', 'H9', 'H10', 'DC9', 'HV5', 'HV11', 'HV12', 'HV13'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'C9', 'H1', 'H2', 'H3', 'H5', 'H6', 'H7', 'H11', 'H12', 'H13', 'DC8', 'HV8', 'HV9', 'HV10'], ['S1', 'C2', 'N3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'H1', 'H2', 'H3', 'H6', 'H7', 'H8', 'H9', 'H10', 'H11', 'H12', 'H13', 'DC10', 'HV4', 'HV14', 'HV15', 'HV16']] # noqa: E501
Expand All @@ -357,7 +358,7 @@ def test_get_names():
[-1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, 1, 0, 0, 0, 0, 0]]
for i, top_file in enumerate(top_files):
top = open(f'{input_path}/coord_swap/{top_file}', 'r').readlines()
test_start_line, test_names, test_lambda_states = coordinate_swap.get_names(top)
test_start_line, test_names, test_nums, test_lambda_states = coordinate_swap.get_names(top, resnames[i])
assert test_start_line == start_lines[i]
assert test_names == names[i]
assert test_lambda_states == lambda_states[i]
Expand Down
46 changes: 31 additions & 15 deletions ensemble_md/utils/coordinate_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,16 +1099,26 @@ def _swap_name(init_atom_names, new_resname, df_top):
continue
atom_num = re.findall(r'[0-9]+', atom)[0]
atom_identifier = re.findall(r'[a-zA-Z]+', atom)[0]
if 'D' in atom_identifier:
atom_identifier = atom_identifier.strip('D')
if 'V' in atom_identifier:
atom_identifier = atom_identifier.strip('V')
if f'{atom_identifier}V{atom_num}' in new_names:
new_atom_names.append(f'{atom_identifier}V{atom_num}')
elif f'{atom_identifier}{atom_num}' in new_names:
new_atom_names.append(f'{atom_identifier}{atom_num}')
elif f'D{atom_identifier}{atom_num}' in new_names:
new_atom_names.append(f'D{atom_identifier}{atom_num}')
if list(atom_identifier)[0] == 'D':
element = list(atom_identifier)[1]
if len(list(atom_identifier)) > 2:
extra = ''.join(list(atom_identifier)[2:])
else:
extra = ''
else:
element = list(atom_identifier)[0]
if len(list(atom_identifier)) > 1:
extra = ''.join(list(atom_identifier)[1:])
else:
extra = ''
if 'V' in extra:
extra = extra.strip('V')
if f'{element}V{extra}{atom_num}' in new_names:
new_atom_names.append(f'{element}V{extra}{atom_num}')
elif f'{element}{extra}{atom_num}' in new_names:
new_atom_names.append(f'{element}{extra}{atom_num}')
elif f'D{element}{extra}{atom_num}' in new_names:
new_atom_names.append(f'D{element}{extra}{atom_num}')
else:
raise Exception(f'Compatible atom could not be found for {atom}')
return new_atom_names
Expand Down Expand Up @@ -1192,17 +1202,23 @@ def determine_connection(main_only, other_only, main_name, other_name, df_top, m
miss, D2R, R2D = [], [], []
align_atom, angle_atom = [], []
for atom in main_only:
element = atom.strip('0123456789')
e_split = list(element)
raw_element = atom.strip('0123456789')
e_split = list(raw_element)
if e_split[0] == 'D':
real_element = ''.join(e_split[1:])
elif len(e_split) > 2 and e_split[1] == 'V':
del e_split[1]
real_element = ''.join(e_split)
else:
real_element = element
real_element = raw_element
if len(real_element) != 1:
element = list(real_element)[0]
identifier = ''.join(list(real_element)[1:])
else:
element = real_element
identifier = ''
num = atom.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
if f'D{atom}' in other_only or f'{element}V{num}' in other_only:
if f'D{atom}' in other_only or f'{element}V{identifier}{num}' in other_only:
D2R.append(atom)
elif f'{real_element}{num}' in other_only:
R2D.append(atom)
Expand All @@ -1223,7 +1239,7 @@ def determine_connection(main_only, other_only, main_name, other_name, df_top, m

# If the atom connects to non-missing atoms than keep these as anchors
for a in connected_atoms:
if a not in miss:
if a not in miss and a not in anchor_atoms:
anchor_atoms.append(a)

# Seperate missing atoms connected to each anchor
Expand Down

0 comments on commit 1af9cd5

Please sign in to comment.