forked from MolFilterGAN/MolFilterGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Utils.py
146 lines (135 loc) · 5.28 KB
/
Utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import re
from collections import Counter

import numpy as np
import rdkit
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from torch.autograd import Variable
# Absolute path of the directory containing this module (referenced by the
# commented-out Voc-file dump inside construct_voc).
file_path = os.path.abspath(os.path.dirname(__file__))
def replace_halogen(string):
    """Replace the two-character halogens Br/Cl with single letters.

    'Br' -> 'R' and 'Cl' -> 'L', so the rest of the pipeline can tokenize
    SMILES one character at a time (Vocabulary.decode inverts the mapping).

    Args:
        string: a SMILES string.

    Returns:
        The SMILES with halogens replaced.
    """
    # Literal substring substitution — no need to compile a fresh regex on
    # every call as the original did.
    return string.replace('Br', 'R').replace('Cl', 'L')
def construct_voc(smiles_list, min_str_freq=1):
    """Count the tokens present in a list of SMILES and build a vocabulary.

    Each SMILES is tokenized into bracket atoms '[...]' (kept whole, found
    via regex) plus single characters (after replace_halogen maps Br/Cl to
    single letters).

    Args:
        smiles_list: iterable of SMILES strings.
        min_str_freq: a token is kept only if it occurs strictly more than
            this many times.

    Returns:
        (voc_ls, less_ls): kept tokens and dropped tokens, each ordered by
        descending frequency.
    """
    # Raw string avoids invalid-escape SyntaxWarnings; the capture group
    # makes re.split return the bracket tokens alongside the plain chunks.
    token_regex = re.compile(r'(\[[^\[\]]{1,10}\])')
    counts = Counter()
    for smiles in smiles_list:
        for chunk in token_regex.split(replace_halogen(smiles)):
            if chunk.startswith('['):
                counts[chunk] += 1
            else:
                # Plain chunk: every character is its own token.
                counts.update(chunk)
    print("Number of characters: {}".format(len(counts)))
    # most_common() sorts by descending count with a stable sort, matching
    # the original sorted(..., reverse=True) ordering.
    res = counts.most_common()
    print(res)
    voc_ls = []
    less_ls = []
    for token, freq in res:
        if freq > min_str_freq:
            voc_ls.append(token)
        else:
            less_ls.append(token)
    # with open(os.path.join(file_path,'Voc'), 'w') as f:
    #     for char in voc_ls:
    #         f.write(char + "\n")
    return voc_ls, less_ls
def rm_voc_less(smiles_list, voc_ls):
    """Filter out SMILES that contain tokens absent from the vocabulary.

    Tokenization mirrors construct_voc: bracket atoms '[...]' are whole
    tokens, everything else is per-character (after replace_halogen).
    Rejected SMILES are printed, once each (the original could print the
    same SMILES several times because its inner `break` only exited the
    character loop, not the token loop).

    Args:
        smiles_list: iterable of SMILES strings.
        voc_ls: collection of allowed tokens.

    Returns:
        List of SMILES whose tokens are all present in voc_ls.
    """
    token_regex = re.compile(r'(\[[^\[\]]{1,10}\])')  # raw string: no escape warning
    allowed = set(voc_ls)  # O(1) membership instead of list scans
    smiles_list_final = []
    for smiles in smiles_list:
        chunks = token_regex.split(replace_halogen(smiles))
        ok = all(
            chunk in allowed if chunk.startswith('[')
            else all(unit in allowed for unit in chunk)
            for chunk in chunks
        )
        if ok:
            smiles_list_final.append(smiles)
        else:
            print(smiles)
    return smiles_list_final
class Vocabulary(object):
    """Encodes/decodes between SMILES token lists and arrays of indices.

    Indices 0..2 are reserved for the special tokens PAD, GO and EOS;
    tokens added via add_characters follow in sorted order.
    """
    def __init__(self, init_from_file=None, max_length=140):
        """
        Args:
            init_from_file: optional path to a whitespace-separated token
                file used to populate the vocabulary.
            max_length: maximum sequence length (stored; not enforced here).
        """
        self.special_tokens = ['PAD', 'GO', 'EOS']
        self.additional_chars = set()
        # Copy, so rebinding/mutating chars can never alias special_tokens
        # (the original assigned the same list object to both names).
        self.chars = list(self.special_tokens)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_length = max_length
        if init_from_file:
            self.init_from_file(init_from_file)

    def encode(self, char_list):
        """Encode a list of tokens (e.g. ['C', '[NH]']) to an index array.

        Returns:
            A float32 numpy array of vocabulary indices (dtype kept for
            backward compatibility with downstream code).

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix

    def decode(self, matrix):
        """Decode an array of indices back into a SMILES string.

        Stops at the first EOS or PAD index; halogen placeholder letters
        are expanded back to Cl/Br (inverse of replace_halogen).
        """
        chars = []
        for i in matrix:
            if (i == self.vocab['EOS']) or (i == self.vocab['PAD']):
                break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles

    def tokenize(self, smiles):
        """Split a SMILES into tokens, wrapped in GO ... EOS markers."""
        regex = r'(\[[^\[\]]{1,10}\])'  # raw string: no invalid-escape warning
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = ['GO']
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)   # bracket atoms stay whole tokens
            else:
                tokenized.extend(char)   # everything else is per-character
        tokenized.append('EOS')
        return tokenized

    def add_characters(self, chars):
        """Add tokens to the vocabulary and rebuild the index tables."""
        self.additional_chars.update(chars)
        char_list = sorted(self.additional_chars)
        self.chars = self.special_tokens + char_list
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}

    def init_from_file(self, file):
        """Initialize the vocabulary from a file of whitespace-separated tokens."""
        with open(file, 'r') as f:
            chars = f.read().split()
        self.add_characters(chars)

    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)