Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 5, 2021
1 parent 9b64dab commit 9cd5e65
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gzip
import click
from math import ceil
from functools import partial
from itertools import islice, chain
from Bio import SeqIO

Expand All @@ -21,39 +22,40 @@
GCS_WRITE_TIMEOUT = 60 * 30
TMP_DIR = Path('./.tmp')

# DAG functions
# functions

@solid
def fasta_to_tmp_files(context):
config = context.solid_config
clear_directory_(TMP_DIR)

it = SeqIO.parse(config['read_from'], 'fasta')
it = filter(lambda t: len(t.seq) + len(t.description) + 10 <= config['max_seq_len'], it)
it = islice(it, 0, config['num_samples'])
def fasta_row_to_sequence_strings(config, sample):
seq = str(sample.seq)
annotation = f'[{sample.description}]'
sequences = []

def fasta_row_to_sequence_strings(sample):
seq = str(sample.seq)
annotation = f'[{sample.description}]'
sequences = []
seq_annot_pair = (annotation, seq)

seq_annot_pair = (annotation, seq)
if random() <= config['prob_invert_seq_annotation']:
seq_annot_pair = tuple(reversed(seq_annot_pair))

if random() <= config['prob_invert_seq_annotation']:
seq_annot_pair = tuple(reversed(seq_annot_pair))
sequence = ' # '.join(seq_annot_pair)
sequence = sequence.encode('utf-8')
sequences.append(sequence)

sequence = ' # '.join(seq_annot_pair)
if random() <= config['prob_seq_only']:
sequence = f'# {seq}'
sequence = sequence.encode('utf-8')
sequences.append(sequence)

if random() <= config['prob_seq_only']:
sequence = f'# {seq}'
sequence = sequence.encode('utf-8')
sequences.append(sequence)
return sequences

return sequences
# DAG functions

it = map(fasta_row_to_sequence_strings, it)
@solid
def fasta_to_tmp_files(context):
config = context.solid_config
clear_directory_(TMP_DIR)

it = SeqIO.parse(config['read_from'], 'fasta')
it = filter(lambda t: len(t.seq) + len(t.description) + 10 <= config['max_seq_len'], it)
it = islice(it, 0, config['num_samples'])
it = map(partial(fasta_row_to_sequence_strings, config), it)
it = chain.from_iterable(it)

for index, data in enumerate(it):
Expand Down

0 comments on commit 9cd5e65

Please sign in to comment.