diff --git a/configs/data/default.yml b/configs/data/default.yml
index c9c0cf4..3bc6f9e 100644
--- a/configs/data/default.yml
+++ b/configs/data/default.yml
@@ -1,8 +1,9 @@
 read_from: './uniref50.fasta'
-write_to: 'gs://progen-train-data'
+write_to: './train_data'
 num_samples: 25000
 max_seq_len: 1024
 prob_seq_only: 0.1
 prob_invert_seq_annotation: 0.5
 fraction_valid_data: 0.025
 num_sequences_per_file: 100000
+sort_annotations: true
diff --git a/generate_data.py b/generate_data.py
index a3a8e5a..3f3285d 100644
--- a/generate_data.py
+++ b/generate_data.py
@@ -1,6 +1,8 @@
 import os
 import gzip
 import click
+import re
+from random import shuffle
 from math import ceil
 from functools import partial
 from itertools import islice, chain
@@ -24,19 +26,50 @@
 
 # functions
 
+def order_dict_by(d, fn):
+    """Return a new dict with the keys of `d` in the order given by `fn(d.keys())`."""
+    return {key: d[key] for key in fn(d.keys())}
+
+def shuffled(items):
+    """Return the items as a new list in random order."""
+    items = list(items)
+    shuffle(items)  # shuffles in place and returns None, so it cannot be passed as `fn` directly
+    return items
+
+def get_annotations_from_description(config, description):
+    """Collect annotations (currently only 'tax') from a FASTA description string."""
+    # '+' rather than '*' so that e.g. 'Tax=[Clostridium]' does not yield an empty annotation
+    taxonomy_matches = re.findall(r'Tax=([a-zA-Z]+)', description)
+    annotations = dict()
+
+    if len(taxonomy_matches) > 0:
+        annotations['tax'] = taxonomy_matches[0]
+
+    return annotations
+
 def fasta_row_to_sequence_strings(config, sample):
     seq = str(sample.seq)
-    annotation = f'[{sample.description}]'
     sequences = []
 
-    seq_annot_pair = (annotation, seq)
+    annotations = get_annotations_from_description(config, sample.description)
+    # todo: gather annotations from GO
+
+    if len(annotations) > 0:
+        # sort for a canonical order, otherwise randomize the annotation order
+        sort_annot_by = sorted if config['sort_annotations'] else shuffled
+        annotations = order_dict_by(annotations, sort_annot_by)
 
-    if random() <= config['prob_invert_seq_annotation']:
-        seq_annot_pair = tuple(reversed(seq_annot_pair))
+        annotation_str = [f"[{annot_name}={annot}]" for annot_name, annot in annotations.items()]
+        annotation_str = ' '.join(annotation_str)
 
-    sequence = ' # '.join(seq_annot_pair)
-    sequence = sequence.encode('utf-8')
-    sequences.append(sequence)
+        seq_annot_pair = (annotation_str, seq)
+
+        if random() <= config['prob_invert_seq_annotation']:
+            seq_annot_pair = tuple(reversed(seq_annot_pair))
+
+        sequence = ' # '.join(seq_annot_pair)
+        sequence = sequence.encode('utf-8')
+        sequences.append(sequence)
 
     if random() <= config['prob_seq_only']:
         sequence = f'# {seq}'