Skip to content

Commit

Permalink
Merge pull request #24 from adaptyvbio/shuffle_option
Browse files Browse the repository at this point in the history
Fix ProteinLoader bugs
  • Loading branch information
elkoz authored Feb 27, 2023
2 parents 0cf99cc + 7d01371 commit 192a4a9
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 41 deletions.
113 changes: 85 additions & 28 deletions docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ <h3 id="using-the-data">Using the data</h3>
)
paths = [item for sublist in paths for item in sublist]
error_ids = [x for x in paths if not x.endswith(&#34;.gz&#34;)]
paths = [x for x in paths if x.endswith(&#34;.gz&#34;)]
if load_live:
print(&#34;Download newest structure files...&#34;)
live_paths = p_map(download_live, error_ids)
Expand Down Expand Up @@ -1238,6 +1239,7 @@ <h3 id="using-the-data">Using the data</h3>
debug_file_path=None,
entry_type=&#34;biounit&#34;, # biounit, chain, pair
classes_to_exclude=None, # heteromers, homomers, single_chains
shuffle_clusters=True,
):
&#34;&#34;&#34;
Parameters
Expand Down Expand Up @@ -1269,6 +1271,8 @@ <h3 id="using-the-data">Using the data</h3>
for chain-chain pairs (all pairs that are seen in the same biounit))
classes_to_exclude : list of str, optional
a list of classes to exclude from the dataset (select from `&#34;single_chains&#34;`, `&#34;heteromers&#34;`, `&#34;homomers&#34;`)
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
&#34;&#34;&#34;

alphabet = ALPHABET
Expand All @@ -1280,8 +1284,11 @@ <h3 id="using-the-data">Using the data</h3>
self.loaded = None
self.dataset_folder = dataset_folder
self.features_folder = features_folder
self.feature_types = node_features_type.split(&#34;+&#34;)
self.feature_types = []
if node_features_type is not None:
self.feature_types = node_features_type.split(&#34;+&#34;)
self.entry_type = entry_type
self.shuffle_clusters = shuffle_clusters

if debug_file_path is not None:
self.dataset_folder = os.path.dirname(debug_file_path)
Expand Down Expand Up @@ -1329,13 +1336,18 @@ <h3 id="using-the-data">Using the data</h3>
with open(file, &#34;rb&#34;) as f:
data = pickle.load(f)
if len(data[&#34;S&#34;]) &gt; max_length:
to_remove.append((id, chain, file))
for id, chain, file in to_remove:
self.files[id][chain].remove(file)
if len(self.files[id][chain]) == 0:
self.files[id].pop(chain)
if len(self.files[id]) == 0:
self.files.pop(id)
to_remove.append(file)
for id in list(self.files.keys()):
chain_dict = self.files[id]
for chain in list(chain_dict.keys()):
file_list = chain_dict[chain]
for file in file_list:
if file in to_remove:
self.files[id][chain].remove(file)
if len(self.files[id][chain]) == 0:
self.files[id].pop(chain)
if len(self.files[id]) == 0:
self.files.pop(id)
# load the clusters
if classes_to_exclude is None:
classes_to_exclude = []
Expand Down Expand Up @@ -1612,8 +1624,14 @@ <h3 id="using-the-data">Using the data</h3>
else:
cluster = self.data[idx]
id = None
while id not in self.files: # some IDs can be filtered out by length
chain_n = random.randint(0, len(self.clusters[cluster]) - 1)
chain_n = -1
while (
id is None or len(self.files[id][chain_id]) == 0
): # some IDs can be filtered out by length
if self.shuffle_clusters:
chain_n = random.randint(0, len(self.clusters[cluster]) - 1)
else:
chain_n += 1
id, chain_id = self.clusters[cluster][
chain_n
] # get id and chain from cluster
Expand Down Expand Up @@ -1660,7 +1678,7 @@ <h3 id="using-the-data">Using the data</h3>
load_to_ram=False,
debug=False,
interpolate=&#34;none&#34;,
node_features_type=&#34;zeros&#34;,
node_features_type=None,
batch_size=4,
entry_type=&#34;biounit&#34;, # biounit, chain, pair
classes_to_exclude=None,
Expand All @@ -1670,6 +1688,8 @@ <h3 id="using-the-data">Using the data</h3>
mask_whole_chains=False,
mask_frac=None,
force_binding_sites_frac=0,
shuffle_clusters=True,
shuffle_batches=True,
) -&gt; None:
&#34;&#34;&#34;
Parameters
Expand All @@ -1692,7 +1712,7 @@ <h3 id="using-the-data">Using the data</h3>
only process 1000 files
interpolate : {&#34;none&#34;, &#34;only_middle&#34;, &#34;all&#34;}
`&#34;none&#34;` for no interpolation, `&#34;only_middle&#34;` for only linear interpolation in the middle, `&#34;all&#34;` for linear interpolation + ends generation
node_features_type : {&#34;zeros&#34;, &#34;dihedral&#34;, &#34;sidechain_orientation&#34;, &#34;chemical&#34;, &#34;secondary_structure&#34; or combinations with &#34;+&#34;}
node_features_type : {&#34;dihedral&#34;, &#34;sidechain_orientation&#34;, &#34;chemical&#34;, &#34;secondary_structure&#34; or combinations with &#34;+&#34;}, optional
the type of node features, e.g. `&#34;dihedral&#34;` or `&#34;sidechain_orientation+chemical&#34;`
batch_size : int, default 4
the batch size
Expand All @@ -1711,6 +1731,10 @@ <h3 id="using-the-data">Using the data</h3>
force_binding_sites_frac : float, default 0
if &gt; 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be
forced to be in a binding site
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
shuffle_batches : bool, default True
if `True`, the batches are shuffled at each epoch
&#34;&#34;&#34;

dataset = ProteinDataset(
Expand All @@ -1726,6 +1750,7 @@ <h3 id="using-the-data">Using the data</h3>
node_features_type=node_features_type,
entry_type=entry_type,
classes_to_exclude=classes_to_exclude,
shuffle_clusters=shuffle_clusters,
)
super().__init__(
dataset,
Expand All @@ -1738,6 +1763,7 @@ <h3 id="using-the-data">Using the data</h3>
force_binding_sites_frac=force_binding_sites_frac,
),
batch_size=batch_size,
shuffle=shuffle_batches,
)


Expand Down Expand Up @@ -2330,7 +2356,7 @@ <h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="proteinflow.ProteinDataset"><code class="flex name class">
<span>class <span class="ident">ProteinDataset</span></span>
<span>(</span><span>dataset_folder, features_folder='./data/tmp/', clustering_dict_path=None, max_length=None, rewrite=False, use_fraction=1, load_to_ram=False, debug=False, interpolate='none', node_features_type='zeros', debug_file_path=None, entry_type='biounit', classes_to_exclude=None)</span>
<span>(</span><span>dataset_folder, features_folder='./data/tmp/', clustering_dict_path=None, max_length=None, rewrite=False, use_fraction=1, load_to_ram=False, debug=False, interpolate='none', node_features_type='zeros', debug_file_path=None, entry_type='biounit', classes_to_exclude=None, shuffle_clusters=True)</span>
</code></dt>
<dd>
<div class="desc"><p>Dataset to load proteinflow data</p>
Expand Down Expand Up @@ -2385,6 +2411,8 @@ <h2 id="parameters">Parameters</h2>
for chain-chain pairs (all pairs that are seen in the same biounit))</dd>
<dt><strong><code>classes_to_exclude</code></strong> :&ensp;<code>list</code> of <code>str</code>, optional</dt>
<dd>a list of classes to exclude from the dataset (select from <code>"single_chains"</code>, <code>"heteromers"</code>, <code>"homomers"</code>)</dd>
<dt><strong><code>shuffle_clusters</code></strong> :&ensp;<code>bool</code>, default <code>True</code></dt>
<dd>if <code>True</code>, a new representative is randomly selected for each cluster at each epoch (if <code>clustering_dict_path</code> is given)</dd>
</dl></div>
<details class="source">
<summary>
Expand Down Expand Up @@ -2434,6 +2462,7 @@ <h2 id="parameters">Parameters</h2>
debug_file_path=None,
entry_type=&#34;biounit&#34;, # biounit, chain, pair
classes_to_exclude=None, # heteromers, homomers, single_chains
shuffle_clusters=True,
):
&#34;&#34;&#34;
Parameters
Expand Down Expand Up @@ -2465,6 +2494,8 @@ <h2 id="parameters">Parameters</h2>
for chain-chain pairs (all pairs that are seen in the same biounit))
classes_to_exclude : list of str, optional
a list of classes to exclude from the dataset (select from `&#34;single_chains&#34;`, `&#34;heteromers&#34;`, `&#34;homomers&#34;`)
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
&#34;&#34;&#34;

alphabet = ALPHABET
Expand All @@ -2476,8 +2507,11 @@ <h2 id="parameters">Parameters</h2>
self.loaded = None
self.dataset_folder = dataset_folder
self.features_folder = features_folder
self.feature_types = node_features_type.split(&#34;+&#34;)
self.feature_types = []
if node_features_type is not None:
self.feature_types = node_features_type.split(&#34;+&#34;)
self.entry_type = entry_type
self.shuffle_clusters = shuffle_clusters

if debug_file_path is not None:
self.dataset_folder = os.path.dirname(debug_file_path)
Expand Down Expand Up @@ -2525,13 +2559,18 @@ <h2 id="parameters">Parameters</h2>
with open(file, &#34;rb&#34;) as f:
data = pickle.load(f)
if len(data[&#34;S&#34;]) &gt; max_length:
to_remove.append((id, chain, file))
for id, chain, file in to_remove:
self.files[id][chain].remove(file)
if len(self.files[id][chain]) == 0:
self.files[id].pop(chain)
if len(self.files[id]) == 0:
self.files.pop(id)
to_remove.append(file)
for id in list(self.files.keys()):
chain_dict = self.files[id]
for chain in list(chain_dict.keys()):
file_list = chain_dict[chain]
for file in file_list:
if file in to_remove:
self.files[id][chain].remove(file)
if len(self.files[id][chain]) == 0:
self.files[id].pop(chain)
if len(self.files[id]) == 0:
self.files.pop(id)
# load the clusters
if classes_to_exclude is None:
classes_to_exclude = []
Expand Down Expand Up @@ -2808,8 +2847,14 @@ <h2 id="parameters">Parameters</h2>
else:
cluster = self.data[idx]
id = None
while id not in self.files: # some IDs can be filtered out by length
chain_n = random.randint(0, len(self.clusters[cluster]) - 1)
chain_n = -1
while (
id is None or len(self.files[id][chain_id]) == 0
): # some IDs can be filtered out by length
if self.shuffle_clusters:
chain_n = random.randint(0, len(self.clusters[cluster]) - 1)
else:
chain_n += 1
id, chain_id = self.clusters[cluster][
chain_n
] # get id and chain from cluster
Expand All @@ -2831,7 +2876,7 @@ <h3>Ancestors</h3>
</dd>
<dt id="proteinflow.ProteinLoader"><code class="flex name class">
<span>class <span class="ident">ProteinLoader</span></span>
<span>(</span><span>dataset_folder, features_folder='./data/tmp/', clustering_dict_path=None, max_length=None, rewrite=False, use_fraction=1, load_to_ram=False, debug=False, interpolate='none', node_features_type='zeros', batch_size=4, entry_type='biounit', classes_to_exclude=None, lower_limit=15, upper_limit=100, mask_residues=True, mask_whole_chains=False, mask_frac=None, force_binding_sites_frac=0)</span>
<span>(</span><span>dataset_folder, features_folder='./data/tmp/', clustering_dict_path=None, max_length=None, rewrite=False, use_fraction=1, load_to_ram=False, debug=False, interpolate='none', node_features_type=None, batch_size=4, entry_type='biounit', classes_to_exclude=None, lower_limit=15, upper_limit=100, mask_residues=True, mask_whole_chains=False, mask_frac=None, force_binding_sites_frac=0, shuffle_clusters=True, shuffle_batches=True)</span>
</code></dt>
<dd>
<div class="desc"><p>A subclass of <code>torch.data.utils.DataLoader</code> tuned for the <code><a title="proteinflow" href="#proteinflow">proteinflow</a></code> dataset</p>
Expand Down Expand Up @@ -2867,7 +2912,7 @@ <h2 id="parameters">Parameters</h2>
<dd>only process 1000 files</dd>
<dt><strong><code>interpolate</code></strong> :&ensp;<code>{"none", "only_middle", "all"}</code></dt>
<dd><code>"none"</code> for no interpolation, <code>"only_middle"</code> for only linear interpolation in the middle, <code>"all"</code> for linear interpolation + ends generation</dd>
<dt><strong><code>node_features_type</code></strong> :&ensp;<code>{"zeros", "dihedral", "sidechain_orientation", "chemical", "secondary_structure"</code> or <code>combinations with "+"}</code></dt>
<dt><strong><code>node_features_type</code></strong> :&ensp;<code>{"dihedral", "sidechain_orientation", "chemical", "secondary_structure"</code> or <code>combinations with "+"}</code>, optional</dt>
<dd>the type of node features, e.g. <code>"dihedral"</code> or <code>"sidechain_orientation+chemical"</code></dd>
<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, default <code>4</code></dt>
<dd>the batch size</dd>
Expand All @@ -2886,6 +2931,10 @@ <h2 id="parameters">Parameters</h2>
<dt><strong><code>force_binding_sites_frac</code></strong> :&ensp;<code>float</code>, default <code>0</code></dt>
<dd>if &gt; 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be
forced to be in a binding site</dd>
<dt><strong><code>shuffle_clusters</code></strong> :&ensp;<code>bool</code>, default <code>True</code></dt>
<dd>if <code>True</code>, a new representative is randomly selected for each cluster at each epoch (if <code>clustering_dict_path</code> is given)</dd>
<dt><strong><code>shuffle_batches</code></strong> :&ensp;<code>bool</code>, default <code>True</code></dt>
<dd>if <code>True</code>, the batches are shuffled at each epoch</dd>
</dl></div>
<details class="source">
<summary>
Expand Down Expand Up @@ -2923,7 +2972,7 @@ <h2 id="parameters">Parameters</h2>
load_to_ram=False,
debug=False,
interpolate=&#34;none&#34;,
node_features_type=&#34;zeros&#34;,
node_features_type=None,
batch_size=4,
entry_type=&#34;biounit&#34;, # biounit, chain, pair
classes_to_exclude=None,
Expand All @@ -2933,6 +2982,8 @@ <h2 id="parameters">Parameters</h2>
mask_whole_chains=False,
mask_frac=None,
force_binding_sites_frac=0,
shuffle_clusters=True,
shuffle_batches=True,
) -&gt; None:
&#34;&#34;&#34;
Parameters
Expand All @@ -2955,7 +3006,7 @@ <h2 id="parameters">Parameters</h2>
only process 1000 files
interpolate : {&#34;none&#34;, &#34;only_middle&#34;, &#34;all&#34;}
`&#34;none&#34;` for no interpolation, `&#34;only_middle&#34;` for only linear interpolation in the middle, `&#34;all&#34;` for linear interpolation + ends generation
node_features_type : {&#34;zeros&#34;, &#34;dihedral&#34;, &#34;sidechain_orientation&#34;, &#34;chemical&#34;, &#34;secondary_structure&#34; or combinations with &#34;+&#34;}
node_features_type : {&#34;dihedral&#34;, &#34;sidechain_orientation&#34;, &#34;chemical&#34;, &#34;secondary_structure&#34; or combinations with &#34;+&#34;}, optional
the type of node features, e.g. `&#34;dihedral&#34;` or `&#34;sidechain_orientation+chemical&#34;`
batch_size : int, default 4
the batch size
Expand All @@ -2974,6 +3025,10 @@ <h2 id="parameters">Parameters</h2>
force_binding_sites_frac : float, default 0
if &gt; 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be
forced to be in a binding site
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
shuffle_batches : bool, default True
if `True`, the batches are shuffled at each epoch
&#34;&#34;&#34;

dataset = ProteinDataset(
Expand All @@ -2989,6 +3044,7 @@ <h2 id="parameters">Parameters</h2>
node_features_type=node_features_type,
entry_type=entry_type,
classes_to_exclude=classes_to_exclude,
shuffle_clusters=shuffle_clusters,
)
super().__init__(
dataset,
Expand All @@ -3001,6 +3057,7 @@ <h2 id="parameters">Parameters</h2>
force_binding_sites_frac=force_binding_sites_frac,
),
batch_size=batch_size,
shuffle=shuffle_batches,
)</code></pre>
</details>
<h3>Ancestors</h3>
Expand Down Expand Up @@ -3105,4 +3162,4 @@ <h4><code><a title="proteinflow.ProteinLoader" href="#proteinflow.ProteinLoader"
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.10.0</a>.</p>
</footer>
</body>
</html>
</html>
Loading

0 comments on commit 192a4a9

Please sign in to comment.