diff --git a/generate_data.py b/generate_data.py index b724852..a3a8e5a 100644 --- a/generate_data.py +++ b/generate_data.py @@ -2,6 +2,7 @@ import gzip import click from math import ceil +from functools import partial from itertools import islice, chain from Bio import SeqIO @@ -21,39 +22,40 @@ GCS_WRITE_TIMEOUT = 60 * 30 TMP_DIR = Path('./.tmp') -# DAG functions +# functions -@solid -def fasta_to_tmp_files(context): - config = context.solid_config - clear_directory_(TMP_DIR) - - it = SeqIO.parse(config['read_from'], 'fasta') - it = filter(lambda t: len(t.seq) + len(t.description) + 10 <= config['max_seq_len'], it) - it = islice(it, 0, config['num_samples']) +def fasta_row_to_sequence_strings(config, sample): + seq = str(sample.seq) + annotation = f'[{sample.description}]' + sequences = [] - def fasta_row_to_sequence_strings(sample): - seq = str(sample.seq) - annotation = f'[{sample.description}]' - sequences = [] + seq_annot_pair = (annotation, seq) - seq_annot_pair = (annotation, seq) + if random() <= config['prob_invert_seq_annotation']: + seq_annot_pair = tuple(reversed(seq_annot_pair)) - if random() <= config['prob_invert_seq_annotation']: - seq_annot_pair = tuple(reversed(seq_annot_pair)) + sequence = ' # '.join(seq_annot_pair) + sequence = sequence.encode('utf-8') + sequences.append(sequence) - sequence = ' # '.join(seq_annot_pair) + if random() <= config['prob_seq_only']: + sequence = f'# {seq}' sequence = sequence.encode('utf-8') sequences.append(sequence) - if random() <= config['prob_seq_only']: - sequence = f'# {seq}' - sequence = sequence.encode('utf-8') - sequences.append(sequence) + return sequences - return sequences +# DAG functions - it = map(fasta_row_to_sequence_strings, it) +@solid +def fasta_to_tmp_files(context): + config = context.solid_config + clear_directory_(TMP_DIR) + + it = SeqIO.parse(config['read_from'], 'fasta') + it = filter(lambda t: len(t.seq) + len(t.description) + 10 <= config['max_seq_len'], it) + it = islice(it, 0, config['num_samples']) + it = map(partial(fasta_row_to_sequence_strings, config), it) it = chain.from_iterable(it) for index, data in enumerate(it):