-
Notifications
You must be signed in to change notification settings - Fork 30
/
evaluate.py
158 lines (133 loc) · 6.93 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import json
import argparse
import numpy as np
import multiprocessing
import pandas as pd
import rdkit
from rdkit import Chem, DataStructs
rdkit.RDLogger.DisableLog('rdApp.*')
from SmilesPE.pretokenizer import atomwise_tokenizer
def get_args():
    """Parse the evaluation script's command-line options.

    Returns:
        argparse.Namespace with ``gold_file``, ``pred_file``, ``pred_field``,
        ``num_workers``, ``tanimoto`` and ``keep_main``.
    """
    cli = argparse.ArgumentParser()
    # Both input CSV paths are mandatory.
    for required_flag in ('--gold_file', '--pred_file'):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument('--pred_field', type=str, default='SMILES')
    cli.add_argument('--num_workers', type=int, default=16)
    # Boolean switches, off unless given on the command line.
    for switch in ('--tanimoto', '--keep_main'):
        cli.add_argument(switch, action='store_true')
    return cli.parse_args()
def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    """Canonicalize one SMILES string with RDKit.

    Args:
        smiles: The SMILES to canonicalize. Anything that is not a non-empty
            string (e.g. ``None`` or NaN from a CSV) is treated as missing.
        ignore_chiral: If True, canonicalize with ``useChiral=False``.
        ignore_cistrans: If True, strip the forward- and back-slash bond
            markers before canonicalizing.
        replace_rgroup: If True, rewrite bracket tokens before canonicalizing:
            ``[R<digits>]`` becomes ``[<digits>*]`` and any other bracket token
            that ``Chem.AtomFromSmiles`` cannot parse becomes ``*``.

    Returns:
        ``(canon_smiles, success)``. For missing input this is ``('', False)``.
        If canonicalization raises, the (pre-processed) input string is
        returned with ``success=False``.
    """
    # isinstance (rather than an exact type check) also accepts str subclasses
    # such as numpy.str_, which would otherwise be scored as empty.
    if not isinstance(smiles, str) or smiles == '':
        return '', False
    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')
    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                # symbol[:1] instead of symbol[0] so an empty bracket body
                # cannot raise IndexError.
                if symbol[:1] == 'R' and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
    except Exception:
        # Best effort: keep the uncanonicalized string, but do not swallow
        # KeyboardInterrupt/SystemExit the way a bare ``except:`` would.
        return smiles, False
    return canon_smiles, True
def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    """Canonicalize a collection of SMILES in parallel.

    Args:
        smiles_list: Iterable of SMILES entries (each is passed unchanged to
            ``canonicalize_smiles``).
        ignore_chiral, ignore_cistrans, replace_rgroup: Forwarded to
            ``canonicalize_smiles`` for every entry.
        num_workers: Size of the ``multiprocessing.Pool``.

    Returns:
        ``(canon_smiles, success_rate)``: the list of canonicalized strings in
        input order and the mean of the per-entry success flags. Empty input
        returns ``([], 0.0)`` without starting a pool.
    """
    jobs = [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list]
    if not jobs:
        # zip(*[]) yields nothing, so the two-name unpacking below would raise
        # ValueError on empty input; short-circuit instead.
        return [], 0.0
    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles, jobs, chunksize=128)
    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)
def _keep_main_molecule(smiles, debug=False):
try:
mol = Chem.MolFromSmiles(smiles)
frags = Chem.GetMolFrags(mol, asMols=True)
if len(frags) > 1:
num_atoms = [m.GetNumAtoms() for m in frags]
main_mol = frags[np.argmax(num_atoms)]
smiles = Chem.MolToSmiles(main_mol)
except Exception as e:
pass
return smiles
def keep_main_molecule(smiles, num_workers=16):
    """Apply ``_keep_main_molecule`` to every entry of ``smiles`` in parallel.

    Args:
        smiles: Iterable of SMILES entries.
        num_workers: Size of the ``multiprocessing.Pool``.

    Returns:
        List of processed SMILES, in input order.
    """
    worker_pool = multiprocessing.Pool(num_workers)
    with worker_pool:
        # map() blocks until every chunk is done, so returning from inside the
        # ``with`` is safe.
        return worker_pool.map(_keep_main_molecule, smiles, chunksize=128)
def tanimoto_similarity(smiles1, smiles2):
    """Compute the fingerprint similarity of two SMILES strings.

    Both inputs are parsed with ``Chem.MolFromSmiles``, fingerprinted with
    ``Chem.RDKFingerprint`` and compared with
    ``DataStructs.FingerprintSimilarity``.

    Returns:
        The similarity value, or 0 when either input cannot be parsed or any
        RDKit call fails.
    """
    try:
        mol1 = Chem.MolFromSmiles(smiles1)
        mol2 = Chem.MolFromSmiles(smiles2)
        if mol1 is None or mol2 is None:
            # Unparsable SMILES: score it as no similarity explicitly instead
            # of relying on the fingerprint call to raise.
            return 0
        fp1 = Chem.RDKFingerprint(mol1)
        fp2 = Chem.RDKFingerprint(mol2)
        return DataStructs.FingerprintSimilarity(fp1, fp2)
    except Exception:
        # Best effort for malformed input (e.g. NaN from a CSV), without
        # swallowing KeyboardInterrupt/SystemExit as a bare ``except:`` would.
        return 0
def compute_tanimoto_similarities(gold_smiles, pred_smiles, num_workers=16):
    """Compute pairwise similarities between gold and predicted SMILES.

    Args:
        gold_smiles: Iterable of reference SMILES.
        pred_smiles: Iterable of predicted SMILES, paired positionally with
            ``gold_smiles`` (pairing stops at the shorter of the two).
        num_workers: Size of the ``multiprocessing.Pool``.

    Returns:
        List with one ``tanimoto_similarity`` result per pair.
    """
    pairs = list(zip(gold_smiles, pred_smiles))
    with multiprocessing.Pool(num_workers) as pool:
        return pool.starmap(tanimoto_similarity, pairs)
class SmilesEvaluator(object):
    """Score predicted SMILES against a fixed set of gold SMILES.

    The gold SMILES are canonicalized once at construction time, in two
    variants: ignoring cis/trans markers only, and ignoring both cis/trans
    markers and chirality.
    """

    def __init__(self, gold_smiles, num_workers=16, tanimoto=False):
        """
        Args:
            gold_smiles: Sequence of reference SMILES.
            num_workers: Pool size used for every parallel step.
            tanimoto: If True, ``evaluate`` also reports the mean fingerprint
                similarity.
        """
        self.gold_smiles = gold_smiles
        self.num_workers = num_workers
        self.tanimoto = tanimoto
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(
            gold_smiles, ignore_cistrans=True, num_workers=num_workers)
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(
            gold_smiles, ignore_chiral=True, ignore_cistrans=True, num_workers=num_workers)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty."""
        return [smiles if isinstance(smiles, str) and smiles != "" else "<empty>"
                for smiles in smiles_list]

    def evaluate(self, pred_smiles, include_details=False):
        """Compute accuracy metrics for ``pred_smiles``.

        Args:
            pred_smiles: Sequence of predicted SMILES, aligned positionally
                with the gold SMILES and of the same length.
            include_details: If True, also return the per-example boolean
                match array under ``'canon_smiles_details'``.

        Returns:
            dict with:
              * ``'tanimoto'`` (only when enabled): mean fingerprint similarity.
              * ``'canon_smiles'``: exact-match rate ignoring cis/trans.
              * ``'graph'``: exact-match rate ignoring cis/trans and chirality.
              * ``'chiral'``: ``'canon_smiles'``-style match rate restricted to
                gold entries containing ``'@'``, or -1 if there are none.

        Raises:
            ValueError: If ``pred_smiles`` and the gold SMILES differ in length.
        """
        if len(pred_smiles) != len(self.gold_smiles_cistrans):
            # Element-wise comparison of unequal-length arrays is not
            # meaningful; fail with a clear message instead.
            raise ValueError(
                f"pred_smiles has {len(pred_smiles)} entries but gold_smiles has "
                f"{len(self.gold_smiles_cistrans)}")
        results = {}
        if self.tanimoto:
            # Honor the configured pool size here too (previously the helper's
            # default was always used).
            results['tanimoto'] = np.mean(compute_tanimoto_similarities(
                self.gold_smiles, pred_smiles, num_workers=self.num_workers))
        # Ignore double bond cis/trans
        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(
            pred_smiles, ignore_cistrans=True, num_workers=self.num_workers)
        cistrans_match = np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)
        results['canon_smiles'] = np.mean(cistrans_match)
        if include_details:
            results['canon_smiles_details'] = cistrans_match
        # Ignore chirality (Graph exact match)
        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(
            pred_smiles, ignore_chiral=True, ignore_cistrans=True, num_workers=self.num_workers)
        results['graph'] = np.mean(np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral))
        # Evaluate on molecules with chiral centers
        chiral_hits = [g == p for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g]
        results['chiral'] = np.mean(chiral_hits) if len(chiral_hits) > 0 else -1
        return results
if __name__ == "__main__":
    # Script entry point: read gold and predicted CSVs, align predictions to
    # the gold order by image_id, and print the scores as JSON.
    # NOTE(review): --keep_main is parsed by get_args() but is not used
    # anywhere in this block — confirm whether that is intended.
    args = get_args()
    gold_df = pd.read_csv(args.gold_file)
    pred_df = pd.read_csv(args.pred_file)

    if len(pred_df) != len(gold_df):
        print(f"Pred ({len(pred_df)}) and Gold ({len(gold_df)}) have different lengths!")

    # If an image_id doesn't exist in pred_df, add an empty prediction for it.
    # All missing ids go in with one concat: DataFrame.append was removed in
    # pandas 2.0, and appending row by row was quadratic.
    predicted_ids = set(pred_df['image_id'])
    missing_ids = list(dict.fromkeys(
        image_id for image_id in gold_df['image_id'] if image_id not in predicted_ids))
    if missing_ids:
        filler = pd.DataFrame({'image_id': missing_ids, args.pred_field: [""] * len(missing_ids)})
        pred_df = pd.concat([pred_df, filler], ignore_index=True)

    # Re-order pred_df to have the same order with gold_df
    image2predidx = {image_id: idx for idx, image_id in enumerate(pred_df['image_id'])}
    pred_df = pred_df.reindex([image2predidx[image_id] for image_id in gold_df['image_id']])

    evaluator = SmilesEvaluator(gold_df['SMILES'], args.num_workers, args.tanimoto)
    scores = evaluator.evaluate(pred_df[args.pred_field])
    print(json.dumps(scores, indent=4))