Skip to content

Commit

Permalink
fix: update splitting to deal with existing excluded.pickle
Browse files Browse the repository at this point in the history
  • Loading branch information
elkoz committed Nov 10, 2023
1 parent 9b1a9ad commit c828dd2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 34 deletions.
52 changes: 29 additions & 23 deletions proteinflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,34 +520,40 @@ def split_data(
temp_folder = os.path.join(tempfile.gettempdir(), "proteinflow")
if not os.path.exists(temp_folder):
os.makedirs(temp_folder)
if exclude_chains_file is not None or exclude_chains is not None:
excluded_biounits = _get_excluded_files(
tag,
local_datasets_folder,
temp_folder,
exclude_chains,
exclude_chains_file,
exclude_threshold,
)
else:
excluded_biounits = []
if exclude_chains_without_ligands:
excluded_biounits += _exclude_files_with_no_ligand(
tag,
local_datasets_folder,
)

output_folder = os.path.join(local_datasets_folder, f"proteinflow_{tag}")
out_split_dict_folder = os.path.join(output_folder, "splits_dict")
exists = False
if ignore_existing:
shutil.rmtree(out_split_dict_folder)

if os.path.exists(out_split_dict_folder):
if not ignore_existing:
warnings.warn(
f"Found an existing dictionary for tag {tag}. proteinflow will load it and ignore the parameters! Run with --ignore_existing to overwrite."
if os.path.join(output_folder, "splits_dict", "excluded.pickle"):
warnings.warn(
"Found an existing dictionary for excluded chains. proteinflow will load it and ignore the exclusion parameters! Run with --ignore_existing to overwrite the splitting."
)
excluded_biounits = []
else:
if exclude_chains_file is not None or exclude_chains is not None:
excluded_biounits = _get_excluded_files(
tag,
local_datasets_folder,
temp_folder,
exclude_chains,
exclude_chains_file,
exclude_threshold,
)
else:
excluded_biounits = []
if exclude_chains_without_ligands:
excluded_biounits += _exclude_files_with_no_ligand(
tag,
local_datasets_folder,
)
exists = True
if not exists:

if os.path.exists(out_split_dict_folder):
warnings.warn(
f"Found an existing dictionary for tag {tag}. proteinflow will load it and ignore the parameters! Run with --ignore_existing to overwrite."
)
else:
_check_mmseqs()
random.seed(random_seed)
np.random.seed(random_seed)
Expand Down
32 changes: 21 additions & 11 deletions proteinflow/split/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ def _get_excluded_files(

def _split_data(
dataset_path="./data/proteinflow_20221110/",
excluded_files=None,
excluded_biounits=None,
exclude_clusters=False,
exclude_based_on_cdr=None,
):
Expand All @@ -1399,8 +1399,8 @@ def _split_data(
If not `None`, exclude all files in a cluster if the cluster name does not end with `exclude_based_on_cdr`
"""
if excluded_files is None:
excluded_files = []
if excluded_biounits is None:
excluded_biounits = []

dict_folder = os.path.join(dataset_path, "splits_dict")
with open(os.path.join(dict_folder, "train.pickle"), "rb") as f:
Expand All @@ -1410,9 +1410,9 @@ def _split_data(
with open(os.path.join(dict_folder, "test.pickle"), "rb") as f:
test_clusters_dict = pickle.load(f)

train_biounits = _biounits_in_clusters_dict(train_clusters_dict, excluded_files)
valid_biounits = _biounits_in_clusters_dict(valid_clusters_dict, excluded_files)
test_biounits = _biounits_in_clusters_dict(test_clusters_dict, excluded_files)
train_biounits = _biounits_in_clusters_dict(train_clusters_dict, excluded_biounits)
valid_biounits = _biounits_in_clusters_dict(valid_clusters_dict, excluded_biounits)
test_biounits = _biounits_in_clusters_dict(test_clusters_dict, excluded_biounits)
train_path = os.path.join(dataset_path, "train")
valid_path = os.path.join(dataset_path, "valid")
test_path = os.path.join(dataset_path, "test")
Expand All @@ -1427,9 +1427,9 @@ def _split_data(
if not os.path.exists(test_path):
os.makedirs(test_path)

if len(excluded_files) > 0:
set_to_exclude = set(excluded_files)
excluded_files = set()
if len(excluded_biounits) > 0:
set_to_exclude = set(excluded_biounits)
excluded_biounits = set()
excluded_clusters_dict = defaultdict(list)
for clusters_dict in [
train_clusters_dict,
Expand Down Expand Up @@ -1462,7 +1462,7 @@ def _split_data(
]
if len(clusters_dict[cluster]) == 0:
clusters_dict.pop(cluster)
excluded_files.update(set_to_exclude)
excluded_biounits.update(set_to_exclude)
excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()}
excluded_path = os.path.join(dataset_path, "excluded")
if not os.path.exists(excluded_path):
Expand All @@ -1477,7 +1477,17 @@ def _split_data(
with open(os.path.join(dict_folder, "excluded.pickle"), "wb") as f:
pickle.dump(excluded_clusters_dict, f)
print("Moving excluded files...")
for biounit in tqdm(excluded_files):
for biounit in tqdm(excluded_biounits):
shutil.move(os.path.join(dataset_path, biounit), excluded_path)
elif os.path.exists(os.path.join(dict_folder, "excluded.pickle")):
with open(os.path.join(dict_folder, "excluded.pickle"), "rb") as f:
excluded_clusters_dict = pickle.load(f)
excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, [])
excluded_path = os.path.join(dataset_path, "excluded")
if not os.path.exists(excluded_path):
os.makedirs(excluded_path)
print("Moving excluded files...")
for biounit in tqdm(excluded_biounits):
shutil.move(os.path.join(dataset_path, biounit), excluded_path)
print("Moving files in the train set...")
for biounit in tqdm(train_biounits):
Expand Down

0 comments on commit c828dd2

Please sign in to comment.