Add a `merge_equivalent_mols` featurization option (default: true) and pass it
from the Python datamodule through to graphium_cpp.prepare_and_save_data.

When the option is false, the InChIKey computation in smiles_to_brief_data is
skipped, each molecule's unique_id is set to its input index (so no two inputs
compare as equivalent), and the per-stage sort of the keys is skipped.

Inside the `compute_inchi_key` branch, id0/id1 are assigned rather than
redeclared, so the values computed there are stored in the function-scope
id0/id1 instead of in shadowing locals that vanish at the closing brace.
The assertion on mol_num_tasks requires exactly one task per molecule only
when merging is disabled.

NOTE(review): this patch text was re-wrapped from a whitespace-collapsed copy.
Indentation is reconstructed by convention and blank context lines are not
restored, so the hunk line counts should be re-checked (or the patch
regenerated from the working tree) before applying.

diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py
index e2368d3b3..30d8e42ae 100644
--- a/graphium/data/datamodule.py
+++ b/graphium/data/datamodule.py
@@ -917,10 +917,12 @@ def encode_feature_options(options, name, encoding_function):
         explicit_H = featurization["explicit_H"] if "explicit_H" in featurization else False
         add_self_loop = featurization["add_self_loop"] if "add_self_loop" in featurization else False
+        merge_equivalent_mols = featurization["merge_equivalent_mols"] if "merge_equivalent_mols" in featurization else True
         # Save these for calling graphium_cpp.prepare_and_save_data later
         self.add_self_loop = add_self_loop
         self.explicit_H = explicit_H
+        self.merge_equivalent_mols = merge_equivalent_mols
         self.preprocessing_n_jobs = preprocessing_n_jobs
@@ -1137,6 +1139,7 @@ def prepare_data(self):
                 self.add_self_loop,
                 self.explicit_H,
                 self.preprocessing_n_jobs,
+                self.merge_equivalent_mols,
             )
             for task, stats in all_stats.items():
diff --git a/graphium/graphium_cpp/labels.cpp b/graphium/graphium_cpp/labels.cpp
index db0ccfa6a..43660e59c 100644
--- a/graphium/graphium_cpp/labels.cpp
+++ b/graphium/graphium_cpp/labels.cpp
@@ -158,8 +158,9 @@ struct MolBriefData {
 static MolBriefData smiles_to_brief_data(
     const std::string& smiles_string,
-    bool add_self_loop = false,
-    bool explicit_H = false) {
+    bool add_self_loop,
+    bool explicit_H,
+    bool compute_inchi_key) {
     // Don't add explicit_H here, in case it affects MolToInchiKey (though it really shouldn't)
     std::unique_ptr mol{ parse_mol(smiles_string, false) };
@@ -167,24 +168,28 @@ static MolBriefData smiles_to_brief_data(
         return MolBriefData{ {0,0}, 0, 0 };
     }
-    const std::string inchiKeyString = MolToInchiKey(*mol, "/FixedH /SUU /RecMet /KET /15T");
-    size_t n = inchiKeyString.size();
-    // Format: AAAAAAAAAAAAAA-BBBBBBBBFV-P
-    // According to https://www.inchi-trust.org/technical-faq/
-    assert(n == 27 && inchiKeyString[14] == '-' && inchiKeyString[25] == '-');
-    // Convert from capital letter characters to 64-bit integers:
-    // 13 characters for first integer, 12 characters for 2nd integer.
-    // Neither should overflow a 64-bit unsigned integer.
-    uint64_t id0 = (n > 0) ? (inchiKeyString[0] - 'A') : 0;
-    for (size_t i = 1; i < 13 && i < n; ++i) {
-        id0 = 26*id0 + (inchiKeyString[i] - 'A');
-    }
-    uint64_t id1 = (13 < n) ? (inchiKeyString[13] - 'A') : 0;
-    for (size_t i = 15; i < 25 && i < n; ++i) {
-        id1 = 26*id1 + (inchiKeyString[i] - 'A');
-    }
-    if (26 < n) {
-        id1 = 26*id1 + (inchiKeyString[26] - 'A');
+    uint64_t id0 = 0;
+    uint64_t id1 = 0;
+    if (compute_inchi_key) {
+        const std::string inchiKeyString = MolToInchiKey(*mol, "/FixedH /SUU /RecMet /KET /15T");
+        size_t n = inchiKeyString.size();
+        // Format: AAAAAAAAAAAAAA-BBBBBBBBFV-P
+        // According to https://www.inchi-trust.org/technical-faq/
+        assert(n == 27 && inchiKeyString[14] == '-' && inchiKeyString[25] == '-');
+        // Convert from capital letter characters to 64-bit integers:
+        // 13 characters for first integer, 12 characters for 2nd integer.
+        // Neither should overflow a 64-bit unsigned integer.
+        id0 = (n > 0) ? (inchiKeyString[0] - 'A') : 0;
+        for (size_t i = 1; i < 13 && i < n; ++i) {
+            id0 = 26*id0 + (inchiKeyString[i] - 'A');
+        }
+        id1 = (13 < n) ? (inchiKeyString[13] - 'A') : 0;
+        for (size_t i = 15; i < 25 && i < n; ++i) {
+            id1 = 26*id1 + (inchiKeyString[i] - 'A');
+        }
+        if (26 < n) {
+            id1 = 26*id1 + (inchiKeyString[26] - 'A');
+        }
+    }
     // Now handle explicit_H
@@ -639,7 +644,8 @@ std::tuple<
     const pybind11::dict& task_test_indices,
     bool add_self_loop,
     bool explicit_H,
-    int max_threads) {
+    int max_threads,
+    bool merge_equivalent_mols) {
     ensure_numpy_array_module_initialized();
@@ -989,7 +995,7 @@ std::tuple<
         num_threads = size_t(max_threads);
     }
-    auto&& get_single_mol_key = [&task_mol_start,add_self_loop,explicit_H,&task_mol_indices,&smiles_strings,num_tasks](size_t mol_index) -> MolKey {
+    auto&& get_single_mol_key = [&task_mol_start,add_self_loop,explicit_H,&task_mol_indices,&smiles_strings,num_tasks,merge_equivalent_mols](size_t mol_index) -> MolKey {
         // Find which task this mol is in. If there could be many tasks,
         // this could be a binary search, but for small numbers of tasks,
         // a linear search is fine.
@@ -1000,7 +1006,14 @@ std::tuple<
         const size_t task_mol_index = task_mol_indices[mol_index];
         const std::string& smiles_str = smiles_strings[mol_index];
-        MolBriefData mol_data = smiles_to_brief_data(smiles_str, add_self_loop, explicit_H);
+        MolBriefData mol_data = smiles_to_brief_data(smiles_str, add_self_loop, explicit_H, merge_equivalent_mols);
+
+        if (!merge_equivalent_mols) {
+            // mol_index is, by definition, distinct for each input index,
+            // so no molecules will be identified as equivalent below.
+            mol_data.unique_id[0] = mol_index;
+            mol_data.unique_id[1] = 0;
+        }
         return MolKey{mol_data.unique_id[0], mol_data.unique_id[1], mol_data.num_nodes, mol_data.num_edges, task_index % num_tasks, task_mol_index, mol_index};
     };
@@ -1209,12 +1222,14 @@ std::tuple<
         }
     }
-    // Sort train, val, and test separately, since they need to be stored separately.
-    // Don't sort until after accumulating stats, because the code above currently assumes that the tasks
-    // aren't interleaved.
-    std::sort(keys.get(), keys.get() + task_mol_start[num_tasks]);
-    std::sort(keys.get() + task_mol_start[num_tasks], keys.get() + task_mol_start[2*num_tasks]);
-    std::sort(keys.get() + task_mol_start[2*num_tasks], keys.get() + total_num_mols);
+    if (merge_equivalent_mols) {
+        // Sort train, val, and test separately, since they need to be stored separately.
+        // Don't sort until after accumulating stats, because the code above currently assumes that the tasks
+        // aren't interleaved.
+        std::sort(keys.get(), keys.get() + task_mol_start[num_tasks]);
+        std::sort(keys.get() + task_mol_start[num_tasks], keys.get() + task_mol_start[2*num_tasks]);
+        std::sort(keys.get() + task_mol_start[2*num_tasks], keys.get() + total_num_mols);
+    }
     std::unordered_map> per_stage_return_data;
@@ -1364,6 +1379,7 @@ std::tuple<
                 ++sorted_index;
             }
             assert(mol_num_tasks <= num_tasks);
+            assert(merge_equivalent_mols || mol_num_tasks == 1);
             // TODO: Double data capacity as needed if resizing is slow
             assert(data.size() == data_offset);
diff --git a/graphium/graphium_cpp/labels.h b/graphium/graphium_cpp/labels.h
index 30498750d..3978686a5 100644
--- a/graphium/graphium_cpp/labels.h
+++ b/graphium/graphium_cpp/labels.h
@@ -55,7 +55,8 @@ std::tuple<
     const pybind11::dict& task_test_indices,
     bool add_self_loop = false,
     bool explicit_H = false,
-    int max_threads = 0);
+    int max_threads = 0,
+    bool merge_equivalent_mols = true);
 void load_labels_from_index(
     const std::string stage_directory,