diff --git a/.gitignore b/.gitignore index 1de96b7..a6f66b6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,10 @@ rcsbsearch/ .pytest_cache/ *ipynb ./*zip -pulchra304 tmp/ all_structures/ -all_structures.zip \ No newline at end of file +all_structures.zip +esmfold_output/ +igfold_output/ +*.csv +run_metrics.py \ No newline at end of file diff --git a/proteinflow/split/__init__.py b/proteinflow/split/__init__.py index 3cadc4a..a14e9a2 100644 --- a/proteinflow/split/__init__.py +++ b/proteinflow/split/__init__.py @@ -1430,18 +1430,38 @@ def _split_data( if len(excluded_files) > 0: set_to_exclude = set(excluded_files) excluded_files = set() - excluded_clusters_dict = defaultdict(set) - if exclude_clusters: - for clusters_dict in [ - train_clusters_dict, - valid_clusters_dict, - test_clusters_dict, - ]: - subset_excluded_set, subset_excluded_dict = _exclude( - clusters_dict, set_to_exclude, exclude_based_on_cdr - ) - excluded_files.update(subset_excluded_set) - excluded_clusters_dict.update(subset_excluded_dict) + excluded_clusters_dict = defaultdict(list) + for clusters_dict in [ + train_clusters_dict, + valid_clusters_dict, + test_clusters_dict, + ]: + for cluster in list(clusters_dict.keys()): + idx_to_exclude = [] + exclude_whole_cluster = False + for i, chain in enumerate(clusters_dict[cluster]): + if chain[0] in set_to_exclude: + if exclude_clusters: + if exclude_based_on_cdr is not None and cluster.endswith( + exclude_based_on_cdr + ): + exclude_whole_cluster = True + elif exclude_based_on_cdr is None: + exclude_whole_cluster = True + if exclude_whole_cluster: + break + excluded_clusters_dict[cluster].append(chain) + idx_to_exclude.append(i) + if exclude_whole_cluster: + excluded_clusters_dict[cluster] = clusters_dict.pop(cluster) + else: + clusters_dict[cluster] = [ + x + for i, x in enumerate(clusters_dict[cluster]) + if i not in idx_to_exclude + ] + if len(clusters_dict[cluster]) == 0: + clusters_dict.pop(cluster) excluded_files.update(set_to_exclude) excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()} excluded_path = os.path.join(dataset_path, "excluded") @@ -1450,6 +1470,10 @@ def _split_data( print("Updating the split dictionaries...") with open(os.path.join(dict_folder, "train.pickle"), "wb") as f: pickle.dump(train_clusters_dict, f) + with open(os.path.join(dict_folder, "valid.pickle"), "wb") as f: + pickle.dump(valid_clusters_dict, f) + with open(os.path.join(dict_folder, "test.pickle"), "wb") as f: + pickle.dump(test_clusters_dict, f) with open(os.path.join(dict_folder, "excluded.pickle"), "wb") as f: pickle.dump(excluded_clusters_dict, f) print("Moving excluded files...") diff --git a/proteinflow/split/utils.py b/proteinflow/split/utils.py index 996fbea..2f9408b 100644 --- a/proteinflow/split/utils.py +++ b/proteinflow/split/utils.py @@ -184,7 +184,7 @@ def _exclude(clusters_dict, set_to_exclude, exclude_based_on_cdr=None): files = clusters_dict[cluster] exclude = False for biounit in files: - if biounit in set_to_exclude: + if biounit[0] in set_to_exclude: exclude = True break if exclude: