Skip to content

Commit

Permalink
Merging of equivalent molecules is now optional, but still defaults t…
Browse files Browse the repository at this point in the history
…o true.
  • Loading branch information
ndickson-nvidia committed Jul 24, 2024
1 parent 35f456a commit f678911
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 30 deletions.
3 changes: 3 additions & 0 deletions graphium/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,10 +917,12 @@ def encode_feature_options(options, name, encoding_function):

explicit_H = featurization["explicit_H"] if "explicit_H" in featurization else False
add_self_loop = featurization["add_self_loop"] if "add_self_loop" in featurization else False
merge_equivalent_mols = featurization["merge_equivalent_mols"] if "merge_equivalent_mols" in featurization else True

# Save these for calling graphium_cpp.prepare_and_save_data later
self.add_self_loop = add_self_loop
self.explicit_H = explicit_H
self.merge_equivalent_mols = merge_equivalent_mols

self.preprocessing_n_jobs = preprocessing_n_jobs

Expand Down Expand Up @@ -1137,6 +1139,7 @@ def prepare_data(self):
self.add_self_loop,
self.explicit_H,
self.preprocessing_n_jobs,
self.merge_equivalent_mols,
)

for task, stats in all_stats.items():
Expand Down
74 changes: 45 additions & 29 deletions graphium/graphium_cpp/labels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,33 +158,38 @@ struct MolBriefData {

static MolBriefData smiles_to_brief_data(
const std::string& smiles_string,
bool add_self_loop = false,
bool explicit_H = false) {
bool add_self_loop,
bool explicit_H,
bool compute_inchi_key) {

// Don't add explicit_H here, in case it affects MolToInchiKey (though it really shouldn't)
std::unique_ptr<RDKit::RWMol> mol{ parse_mol(smiles_string, false) };
if (!mol) {
return MolBriefData{ {0,0}, 0, 0 };
}

const std::string inchiKeyString = MolToInchiKey(*mol, "/FixedH /SUU /RecMet /KET /15T");
size_t n = inchiKeyString.size();
// Format: AAAAAAAAAAAAAA-BBBBBBBBFV-P
// According to https://www.inchi-trust.org/technical-faq/
assert(n == 27 && inchiKeyString[14] == '-' && inchiKeyString[25] == '-');
// Convert from capital letter characters to 64-bit integers:
// 13 characters for first integer, 12 characters for 2nd integer.
// Neither should overflow a 64-bit unsigned integer.
uint64_t id0 = (n > 0) ? (inchiKeyString[0] - 'A') : 0;
for (size_t i = 1; i < 13 && i < n; ++i) {
id0 = 26*id0 + (inchiKeyString[i] - 'A');
}
uint64_t id1 = (13 < n) ? (inchiKeyString[13] - 'A') : 0;
for (size_t i = 15; i < 25 && i < n; ++i) {
id1 = 26*id1 + (inchiKeyString[i] - 'A');
}
if (26 < n) {
id1 = 26*id1 + (inchiKeyString[26] - 'A');
uint64_t id0 = 0;
uint64_t id1 = 0;
if (compute_inchi_key) {
const std::string inchiKeyString = MolToInchiKey(*mol, "/FixedH /SUU /RecMet /KET /15T");
size_t n = inchiKeyString.size();
// Format: AAAAAAAAAAAAAA-BBBBBBBBFV-P
// According to https://www.inchi-trust.org/technical-faq/
assert(n == 27 && inchiKeyString[14] == '-' && inchiKeyString[25] == '-');
// Convert from capital letter characters to 64-bit integers:
// 13 characters for first integer, 12 characters for 2nd integer.
// Neither should overflow a 64-bit unsigned integer.
uint64_t id0 = (n > 0) ? (inchiKeyString[0] - 'A') : 0;
for (size_t i = 1; i < 13 && i < n; ++i) {
id0 = 26*id0 + (inchiKeyString[i] - 'A');
}
uint64_t id1 = (13 < n) ? (inchiKeyString[13] - 'A') : 0;
for (size_t i = 15; i < 25 && i < n; ++i) {
id1 = 26*id1 + (inchiKeyString[i] - 'A');
}
if (26 < n) {
id1 = 26*id1 + (inchiKeyString[26] - 'A');
}
}

// Now handle explicit_H
Expand Down Expand Up @@ -639,7 +644,8 @@ std::tuple<
const pybind11::dict& task_test_indices,
bool add_self_loop,
bool explicit_H,
int max_threads) {
int max_threads,
bool merge_equivalent_mols) {

ensure_numpy_array_module_initialized();

Expand Down Expand Up @@ -989,7 +995,7 @@ std::tuple<
num_threads = size_t(max_threads);
}

auto&& get_single_mol_key = [&task_mol_start,add_self_loop,explicit_H,&task_mol_indices,&smiles_strings,num_tasks](size_t mol_index) -> MolKey {
auto&& get_single_mol_key = [&task_mol_start,add_self_loop,explicit_H,&task_mol_indices,&smiles_strings,num_tasks,merge_equivalent_mols](size_t mol_index) -> MolKey {
// Find which task this mol is in. If there could be many tasks,
// this could be a binary search, but for small numbers of tasks,
// a linear search is fine.
Expand All @@ -1000,7 +1006,14 @@ std::tuple<
const size_t task_mol_index = task_mol_indices[mol_index];

const std::string& smiles_str = smiles_strings[mol_index];
MolBriefData mol_data = smiles_to_brief_data(smiles_str, add_self_loop, explicit_H);
MolBriefData mol_data = smiles_to_brief_data(smiles_str, add_self_loop, explicit_H, merge_equivalent_mols);

if (!merge_equivalent_mols) {
// mol_index is, by definition, distinct for each input index,
// so no molecules will be identified as equivalent below.
mol_data.unique_id[0] = mol_index;
mol_data.unique_id[1] = 0;
}

return MolKey{mol_data.unique_id[0], mol_data.unique_id[1], mol_data.num_nodes, mol_data.num_edges, task_index % num_tasks, task_mol_index, mol_index};
};
Expand Down Expand Up @@ -1209,12 +1222,14 @@ std::tuple<
}
}

// Sort train, val, and test separately, since they need to be stored separately.
// Don't sort until after accumulating stats, because the code above currently assumes that the tasks
// aren't interleaved.
std::sort(keys.get(), keys.get() + task_mol_start[num_tasks]);
std::sort(keys.get() + task_mol_start[num_tasks], keys.get() + task_mol_start[2*num_tasks]);
std::sort(keys.get() + task_mol_start[2*num_tasks], keys.get() + total_num_mols);
if (merge_equivalent_mols) {
// Sort train, val, and test separately, since they need to be stored separately.
// Don't sort until after accumulating stats, because the code above currently assumes that the tasks
// aren't interleaved.
std::sort(keys.get(), keys.get() + task_mol_start[num_tasks]);
std::sort(keys.get() + task_mol_start[num_tasks], keys.get() + task_mol_start[2*num_tasks]);
std::sort(keys.get() + task_mol_start[2*num_tasks], keys.get() + total_num_mols);
}

std::unordered_map<std::string, std::vector<at::Tensor>> per_stage_return_data;

Expand Down Expand Up @@ -1364,6 +1379,7 @@ std::tuple<
++sorted_index;
}
assert(mol_num_tasks <= num_tasks);
assert(!merge_equivalent_mols || mol_num_tasks == 1);

// TODO: Double data capacity as needed if resizing is slow
assert(data.size() == data_offset);
Expand Down
3 changes: 2 additions & 1 deletion graphium/graphium_cpp/labels.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ std::tuple<
const pybind11::dict& task_test_indices,
bool add_self_loop = false,
bool explicit_H = false,
int max_threads = 0);
int max_threads = 0,
bool merge_equivalent_mols = true);

void load_labels_from_index(
const std::string stage_directory,
Expand Down

0 comments on commit f678911

Please sign in to comment.