diff --git a/proteinflow/__init__.py b/proteinflow/__init__.py index 94ae7bb..4509910 100644 --- a/proteinflow/__init__.py +++ b/proteinflow/__init__.py @@ -530,7 +530,7 @@ def split_data( warnings.warn( "Found an existing dictionary for excluded chains. proteinflow will load it and ignore the exclusion parameters! Run with --ignore_existing to overwrite the splitting." ) - excluded_biounits = [] + excluded_biounits = None else: if exclude_chains_file is not None or exclude_chains is not None: excluded_biounits = _get_excluded_files( diff --git a/proteinflow/split/__init__.py b/proteinflow/split/__init__.py index 5be1515..28a5fec 100644 --- a/proteinflow/split/__init__.py +++ b/proteinflow/split/__init__.py @@ -1400,7 +1400,14 @@ def _split_data( """ if excluded_biounits is None: - excluded_biounits = [] + if os.path.exists(os.path.join(dataset_path, "splits_dict", "excluded.pickle")): + with open( + os.path.join(dataset_path, "splits_dict", "excluded.pickle"), "rb" + ) as f: + excluded_clusters_dict = pickle.load(f) + excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, []) + else: + excluded_biounits = [] dict_folder = os.path.join(dataset_path, "splits_dict") with open(os.path.join(dict_folder, "train.pickle"), "rb") as f: @@ -1428,42 +1435,48 @@ def _split_data( os.makedirs(test_path) if len(excluded_biounits) > 0: - set_to_exclude = set(excluded_biounits) - excluded_biounits = set() - excluded_clusters_dict = defaultdict(list) - for clusters_dict in [ - train_clusters_dict, - valid_clusters_dict, - test_clusters_dict, - ]: - for cluster in list(clusters_dict.keys()): - idx_to_exclude = [] - exclude_whole_cluster = False - for i, chain in enumerate(clusters_dict[cluster]): - if chain[0] in set_to_exclude: - if exclude_clusters: - if exclude_based_on_cdr is not None and cluster.endswith( - exclude_based_on_cdr - ): - exclude_whole_cluster = True - elif exclude_based_on_cdr is None: - exclude_whole_cluster = True - if exclude_whole_cluster: - break - excluded_clusters_dict[cluster].append(chain) - idx_to_exclude.append(i) - if exclude_whole_cluster: - excluded_clusters_dict[cluster] = clusters_dict.pop(cluster) - else: - clusters_dict[cluster] = [ - x - for i, x in enumerate(clusters_dict[cluster]) - if i not in idx_to_exclude - ] - if len(clusters_dict[cluster]) == 0: - clusters_dict.pop(cluster) - excluded_biounits.update(set_to_exclude) - excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()} + if not os.path.exists( + os.path.join(dataset_path, "splits_dict", "excluded.pickle") + ): + set_to_exclude = set(excluded_biounits) + excluded_biounits = set() + excluded_clusters_dict = defaultdict(list) + for clusters_dict in [ + train_clusters_dict, + valid_clusters_dict, + test_clusters_dict, + ]: + for cluster in list(clusters_dict.keys()): + idx_to_exclude = [] + exclude_whole_cluster = False + for i, chain in enumerate(clusters_dict[cluster]): + if chain[0] in set_to_exclude: + if exclude_clusters: + if ( + exclude_based_on_cdr is not None + and cluster.endswith(exclude_based_on_cdr) + ): + exclude_whole_cluster = True + elif exclude_based_on_cdr is None: + exclude_whole_cluster = True + if exclude_whole_cluster: + break + excluded_clusters_dict[cluster].append(chain) + idx_to_exclude.append(i) + if exclude_whole_cluster: + excluded_clusters_dict[cluster] = clusters_dict.pop(cluster) + else: + clusters_dict[cluster] = [ + x + for i, x in enumerate(clusters_dict[cluster]) + if i not in idx_to_exclude + ] + if len(clusters_dict[cluster]) == 0: + clusters_dict.pop(cluster) + excluded_clusters_dict = { + k: list(v) for k, v in excluded_clusters_dict.items() + } + excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, []) excluded_path = os.path.join(dataset_path, "excluded") if not os.path.exists(excluded_path): os.makedirs(excluded_path) @@ -1479,16 +1492,6 @@ def _split_data( print("Moving excluded files...") for biounit in tqdm(excluded_biounits): shutil.move(os.path.join(dataset_path, biounit), excluded_path) - elif os.path.exists(os.path.join(dict_folder, "excluded.pickle")): - with open(os.path.join(dict_folder, "excluded.pickle"), "rb") as f: - excluded_clusters_dict = pickle.load(f) - excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, []) - excluded_path = os.path.join(dataset_path, "excluded") - if not os.path.exists(excluded_path): - os.makedirs(excluded_path) - print("Moving excluded files...") - for biounit in tqdm(excluded_biounits): - shutil.move(os.path.join(dataset_path, biounit), excluded_path) print("Moving files in the train set...") for biounit in tqdm(train_biounits): shutil.move(os.path.join(dataset_path, biounit), train_path)