From a8c1bbae9aea78974e84b59c08d15a70304946a1 Mon Sep 17 00:00:00 2001 From: David Prihoda Date: Mon, 10 Jun 2024 10:24:50 +0200 Subject: [PATCH] Add Chain.batch() for parsing multiple seqs at once --- abnumber/__version__.py | 2 +- abnumber/chain.py | 53 +++++++++++++++++++++++++++++++++++++---- abnumber/common.py | 41 ++++++++++++++++--------------- test/test_chain.py | 21 ++++++++++++++++ 4 files changed, 92 insertions(+), 25 deletions(-) diff --git a/abnumber/__version__.py b/abnumber/__version__.py index 73e3bb4..80eb7f9 100644 --- a/abnumber/__version__.py +++ b/abnumber/__version__.py @@ -1 +1 @@ -__version__ = '0.3.2' +__version__ = '0.3.3' diff --git a/abnumber/chain.py b/abnumber/chain.py index b3ba624..e6c87d5 100644 --- a/abnumber/chain.py +++ b/abnumber/chain.py @@ -1,7 +1,5 @@ from collections import OrderedDict from typing import Union, List, Generator, Tuple -from Bio import SeqIO -from Bio.SeqRecord import SeqRecord import pandas as pd from abnumber.alignment import Alignment @@ -9,6 +7,8 @@ is_integer, SCHEME_BORDERS, _get_unique_chains from abnumber.exceptions import ChainParseError import numpy as np +from Bio import SeqIO +from Bio.SeqRecord import SeqRecord from Bio.Seq import Seq from abnumber.position import Position @@ -83,6 +83,8 @@ def __init__(self, sequence, scheme, cdr_definition=None, name=None, assign_germ else: if sequence is None: raise ChainParseError('Expected sequence, got None') + if isinstance(sequence, list): + raise ChainParseError('Expected string or Seq, got list. Please use Chain.batch() to parse multiple sequences') if not isinstance(sequence, str) and not isinstance(sequence, Seq): raise ChainParseError(f'Expected string or Seq, got {type(sequence)}: {sequence}') if '-' in sequence: @@ -93,7 +95,9 @@ def __init__(self, sequence, scheme, cdr_definition=None, name=None, assign_germ raise ChainParseError('Do not use tail= when providing sequence=, it will be inferred automatically') if isinstance(sequence, Seq): sequence = str(sequence) - results = _anarci_align(sequence, scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline) + results = _anarci_align([sequence], scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline)[0] + if not results: + raise ChainParseError(f'Variable chain sequence not recognized: "{sequence}"') if len(results) > 1: raise ChainParseError(f'Found {len(results)} antibody domains in sequence: "{sequence}"') aa_dict, chain_type, tail, species, v_gene, j_gene = results[0] @@ -157,10 +161,10 @@ def _init_from_dict(self, aa_dict, allowed_species): else: seq = ''.join(aa_dict[pos] for pos in sorted_positions) renumbered_aa_dict = _anarci_align( - seq, + [seq], scheme=self.cdr_definition if self.cdr_definition != 'north' else 'chothia', allowed_species=allowed_species - )[0][0] + )[0][0][0] cdr_definition_positions = [pos.number for pos in sorted(renumbered_aa_dict.keys())] combined_aa_dict = {} for orig_pos, cdr_definition_position in zip(sorted_positions, cdr_definition_positions): @@ -178,6 +182,45 @@ def _init_from_dict(self, aa_dict, allowed_species): region_idx += 1 regions_list[region_idx][pos] = aa + @classmethod + def batch(cls, seq_dict: dict, scheme: str, cdr_definition=None, assign_germline=False, allowed_species=None): + """Create multiple Chain objects from dict of sequences + + :param seq_dict: Dictionary of sequence strings, keys are sequence identifiers + :param scheme: Numbering scheme to align the sequences + :param cdr_definition: Numbering scheme to be used for definition of CDR regions. Same as ``scheme`` by default. + :param assign_germline: Assign germline name using ANARCI based on best sequence identity + :param allowed_species: Allowed species for germline assignment. Use ``None`` to allow all species, or one or more of: ``'human', 'mouse','rat','rabbit','rhesus','pig','alpaca'`` + :return: tuple with (dict of Chain objects, dict of error strings) + """ + assert isinstance(seq_dict, dict), f'Expected dictionary of sequences, got: {type(seq_dict).__name__}' + names = list(seq_dict.keys()) + seq_list = list(seq_dict.values()) + all_results = _anarci_align(seq_list, scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline) + names = names or ([None] * len(seq_list)) + chains = {} + errors = {} + for sequence, results, name in zip(seq_list, all_results, names): + if not results: + errors[name] = f'Variable chain sequence not recognized: "{sequence}"' + elif len(results) > 1: + errors[name] = f'Found {len(results)} antibody domains: "{sequence}"' + else: + aa_dict, chain_type, tail, species, v_gene, j_gene = results[0] + chains[name] = Chain( + sequence=None, + aa_dict=aa_dict, + name=name, + scheme=scheme, + chain_type=chain_type, + cdr_definition=cdr_definition, + tail=tail, + species=species, + v_gene=v_gene, + j_gene=j_gene + ) + return chains, errors + def __repr__(self): return self.format() diff --git a/abnumber/common.py b/abnumber/common.py index 8259602..944da0f 100644 --- a/abnumber/common.py +++ b/abnumber/common.py @@ -20,31 +20,34 @@ def _validate_chain_type(chain_type): f'Invalid chain type "{chain_type}", it should be "H" (heavy), "L" (lambda light chian) or "K" (kappa light chain)' -def _anarci_align(sequence, scheme, allowed_species, assign_germline=False) -> List[Tuple]: +def _anarci_align(sequences, scheme, allowed_species, assign_germline=False) -> List[List[Tuple]]: from abnumber.position import Position - sequence = re.sub(WHITESPACE, '', sequence) + assert isinstance(sequences, list), f'Expected list of sequences, got: {type(sequences)}' all_numbered, all_ali, all_hits = anarci( - [('id', sequence)], + [(f'id{i}', re.sub(WHITESPACE, '', sequence)) for i, sequence in enumerate(sequences)], scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline ) - seq_numbered = all_numbered[0] - seq_ali = all_ali[0] - if seq_numbered is None: - raise ChainParseError(f'Variable chain sequence not recognized: "{sequence}"') - assert len(seq_numbered) == len(seq_ali), 'Unexpected ANARCI output' - results = [] - for (positions, start, end), ali in zip(seq_numbered, seq_ali): - chain_type = ali['chain_type'] - species = ali['species'] - v_gene = ali['germlines']['v_gene'][0][1] if assign_germline else None - j_gene = ali['germlines']['j_gene'][0][1] if assign_germline else None - aa_dict = {Position(chain_type=chain_type, number=num, letter=letter, scheme=scheme): aa - for (num, letter), aa in positions if aa != '-'} - tail = sequence[end+1:] - results.append((aa_dict, chain_type, tail, species, v_gene, j_gene)) - return results + all_results = [] + for sequence, seq_numbered, seq_ali in zip(sequences, all_numbered, all_ali): + if seq_numbered is None: + # Variable chain sequence not recognized + all_results.append([]) + continue + assert len(seq_numbered) == len(seq_ali), 'Unexpected ANARCI output' + results = [] + for (positions, start, end), ali in zip(seq_numbered, seq_ali): + chain_type = ali['chain_type'] + species = ali['species'] + v_gene = ali['germlines']['v_gene'][0][1] if assign_germline else None + j_gene = ali['germlines']['j_gene'][0][1] if assign_germline else None + aa_dict = {Position(chain_type=chain_type, number=num, letter=letter, scheme=scheme): aa + for (num, letter), aa in positions if aa != '-'} + tail = sequence[end+1:] + results.append((aa_dict, chain_type, tail, species, v_gene, j_gene)) + all_results.append(results) + return all_results def _get_unique_chains(chains): diff --git a/test/test_chain.py b/test/test_chain.py index d21c0d9..cd925fb 100644 --- a/test/test_chain.py +++ b/test/test_chain.py @@ -121,6 +121,11 @@ def test_invalid_chain_raises_error(scheme): Chain('AAA', scheme=scheme) +def test_multiple_chains_raises_error(): + with pytest.raises(ChainParseError): + Chain('QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSSQVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSS', scheme='imgt') + + def test_aho_without_cdr_definition_raises_error(): with pytest.raises(ValueError): Chain('QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSS', scheme='aho') @@ -247,3 +252,19 @@ def test_nearest_j_region(): assert nearest_j[0].name == 'IGHJ6*01' + +def test_batch(): + chains, errors = Chain.batch({ + 'A': 'QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTVTVSS', + 'B': 'EVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSS', + 'C': 'FOO', + 'D': 'EVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSSEVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSS' + }, scheme='imgt') + assert len(chains) == 2 + assert chains['A'].raw[0] == 'Q' + assert chains['B'].raw[0] == 'E' + assert 'C' not in chains + assert errors['C'] == 'Variable chain sequence not recognized: "FOO"' + assert 'D' not in chains + assert errors['D'] == 'Found 2 antibody domains: "EVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSSEVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSS"' +