fix: splitting based on existing excluded.pickle
elkoz committed Nov 10, 2023
1 parent 4fd2984 commit bdabe46
Showing 2 changed files with 51 additions and 48 deletions.
2 changes: 1 addition & 1 deletion proteinflow/__init__.py
@@ -530,7 +530,7 @@ def split_data(
warnings.warn(
"Found an existing dictionary for excluded chains. proteinflow will load it and ignore the exclusion parameters! Run with --ignore_existing to overwrite the splitting."
)
excluded_biounits = []
excluded_biounits = None
else:
if exclude_chains_file is not None or exclude_chains is not None:
excluded_biounits = _get_excluded_files(
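The one-line change above switches the sentinel that split_data passes down when an existing excluded.pickle is found: None now means "reuse the stored exclusion dictionary", whereas the previous empty list made the downstream code believe no chains were excluded at all. A minimal sketch of that sentinel handling, with a hypothetical helper name and simplified arguments (the real logic lives in proteinflow/split/__init__.py and is shown in the diff below):

import os
import pickle

def _resolve_excluded_biounits(dataset_path, excluded_biounits):
    # Hypothetical helper: reload exclusions from an existing excluded.pickle
    # when the caller passes None, otherwise fall back to the given list.
    pickle_path = os.path.join(dataset_path, "splits_dict", "excluded.pickle")
    if os.path.exists(pickle_path):
        with open(pickle_path, "rb") as f:
            excluded_clusters_dict = pickle.load(f)
        # Flatten the stored cluster -> chains mapping into biounit file names
        # (the diff does this with _biounits_in_clusters_dict).
        return sorted(
            {chain[0] for chains in excluded_clusters_dict.values() for chain in chains}
        )
    return [] if excluded_biounits is None else list(excluded_biounits)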
97 changes: 50 additions & 47 deletions proteinflow/split/__init__.py
@@ -1400,7 +1400,14 @@ def _split_data(
"""
if excluded_biounits is None:
excluded_biounits = []
if os.path.exists(os.path.join(dataset_path, "splits_dict", "excluded.pickle")):
with open(
os.path.join(dataset_path, "splits_dict", "excluded.pickle"), "rb"
) as f:
excluded_clusters_dict = pickle.load(f)
excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, [])
else:
excluded_biounits = []

dict_folder = os.path.join(dataset_path, "splits_dict")
with open(os.path.join(dict_folder, "train.pickle"), "rb") as f:
@@ -1428,42 +1435,48 @@ def _split_data(
os.makedirs(test_path)

if len(excluded_biounits) > 0:
set_to_exclude = set(excluded_biounits)
excluded_biounits = set()
excluded_clusters_dict = defaultdict(list)
for clusters_dict in [
train_clusters_dict,
valid_clusters_dict,
test_clusters_dict,
]:
for cluster in list(clusters_dict.keys()):
idx_to_exclude = []
exclude_whole_cluster = False
for i, chain in enumerate(clusters_dict[cluster]):
if chain[0] in set_to_exclude:
if exclude_clusters:
if exclude_based_on_cdr is not None and cluster.endswith(
exclude_based_on_cdr
):
exclude_whole_cluster = True
elif exclude_based_on_cdr is None:
exclude_whole_cluster = True
if exclude_whole_cluster:
break
excluded_clusters_dict[cluster].append(chain)
idx_to_exclude.append(i)
if exclude_whole_cluster:
excluded_clusters_dict[cluster] = clusters_dict.pop(cluster)
else:
clusters_dict[cluster] = [
x
for i, x in enumerate(clusters_dict[cluster])
if i not in idx_to_exclude
]
if len(clusters_dict[cluster]) == 0:
clusters_dict.pop(cluster)
excluded_biounits.update(set_to_exclude)
excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()}
if not os.path.exists(
os.path.join(dataset_path, "splits_dict", "excluded.pickle")
):
set_to_exclude = set(excluded_biounits)
excluded_biounits = set()
excluded_clusters_dict = defaultdict(list)
for clusters_dict in [
train_clusters_dict,
valid_clusters_dict,
test_clusters_dict,
]:
for cluster in list(clusters_dict.keys()):
idx_to_exclude = []
exclude_whole_cluster = False
for i, chain in enumerate(clusters_dict[cluster]):
if chain[0] in set_to_exclude:
if exclude_clusters:
if (
exclude_based_on_cdr is not None
and cluster.endswith(exclude_based_on_cdr)
):
exclude_whole_cluster = True
elif exclude_based_on_cdr is None:
exclude_whole_cluster = True
if exclude_whole_cluster:
break
excluded_clusters_dict[cluster].append(chain)
idx_to_exclude.append(i)
if exclude_whole_cluster:
excluded_clusters_dict[cluster] = clusters_dict.pop(cluster)
else:
clusters_dict[cluster] = [
x
for i, x in enumerate(clusters_dict[cluster])
if i not in idx_to_exclude
]
if len(clusters_dict[cluster]) == 0:
clusters_dict.pop(cluster)
excluded_clusters_dict = {
k: list(v) for k, v in excluded_clusters_dict.items()
}
excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, [])
excluded_path = os.path.join(dataset_path, "excluded")
if not os.path.exists(excluded_path):
os.makedirs(excluded_path)
@@ -1479,16 +1492,6 @@ def _split_data(
print("Moving excluded files...")
for biounit in tqdm(excluded_biounits):
shutil.move(os.path.join(dataset_path, biounit), excluded_path)
elif os.path.exists(os.path.join(dict_folder, "excluded.pickle")):
with open(os.path.join(dict_folder, "excluded.pickle"), "rb") as f:
excluded_clusters_dict = pickle.load(f)
excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, [])
excluded_path = os.path.join(dataset_path, "excluded")
if not os.path.exists(excluded_path):
os.makedirs(excluded_path)
print("Moving excluded files...")
for biounit in tqdm(excluded_biounits):
shutil.move(os.path.join(dataset_path, biounit), excluded_path)
print("Moving files in the train set...")
for biounit in tqdm(train_biounits):
shutil.move(os.path.join(dataset_path, biounit), train_path)
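For readers unfamiliar with the data layout: the blocks above assume that a clusters dictionary maps each cluster id to a list of entries whose first element is a biounit file name, and that _biounits_in_clusters_dict flattens such a mapping into the unique file names that are then moved with shutil.move. A toy illustration under those assumptions (shapes and example values are inferred from the diff, not taken from the library):

# Toy illustration of the assumed cluster-dictionary shape and flattening step.
excluded_clusters_dict = {
    "cluster_1": [("1abc-1.pickle", "A"), ("1abc-1.pickle", "B")],
    "cluster_2": [("2xyz-2.pickle", "H")],
}

def _flatten_biounits(clusters_dict, exclude=()):
    # Collect the unique biounit file names referenced by a clusters dictionary,
    # skipping anything in the optional exclude set.
    biounits = set()
    for chains in clusters_dict.values():
        for chain in chains:
            if chain[0] not in exclude:
                biounits.add(chain[0])
    return sorted(biounits)

print(_flatten_biounits(excluded_clusters_dict))  # ['1abc-1.pickle', '2xyz-2.pickle']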
