Skip to content

Commit

Permalink
Add Chain.batch() for parsing multiple seqs at once
Browse files Browse the repository at this point in the history
  • Loading branch information
prihoda committed Jun 10, 2024
1 parent e5d8e2b commit a8c1bba
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 25 deletions.
2 changes: 1 addition & 1 deletion abnumber/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.2'
__version__ = '0.3.3'
53 changes: 48 additions & 5 deletions abnumber/chain.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections import OrderedDict
from typing import Union, List, Generator, Tuple
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
import pandas as pd

from abnumber.alignment import Alignment
from abnumber.common import _anarci_align, _validate_chain_type, SUPPORTED_SCHEMES, SUPPORTED_CDR_DEFINITIONS, \
is_integer, SCHEME_BORDERS, _get_unique_chains
from abnumber.exceptions import ChainParseError
import numpy as np
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq

from abnumber.position import Position
Expand Down Expand Up @@ -83,6 +83,8 @@ def __init__(self, sequence, scheme, cdr_definition=None, name=None, assign_germ
else:
if sequence is None:
raise ChainParseError('Expected sequence, got None')
if isinstance(sequence, list):
raise ChainParseError('Expected string or Seq, got list. Please use Chain.batch() to parse multiple sequences')
if not isinstance(sequence, str) and not isinstance(sequence, Seq):
raise ChainParseError(f'Expected string or Seq, got {type(sequence)}: {sequence}')
if '-' in sequence:
Expand All @@ -93,7 +95,9 @@ def __init__(self, sequence, scheme, cdr_definition=None, name=None, assign_germ
raise ChainParseError('Do not use tail= when providing sequence=, it will be inferred automatically')
if isinstance(sequence, Seq):
sequence = str(sequence)
results = _anarci_align(sequence, scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline)
results = _anarci_align([sequence], scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline)[0]
if not results:
raise ChainParseError(f'Variable chain sequence not recognized: "{sequence}"')
if len(results) > 1:
raise ChainParseError(f'Found {len(results)} antibody domains in sequence: "{sequence}"')
aa_dict, chain_type, tail, species, v_gene, j_gene = results[0]
Expand Down Expand Up @@ -157,10 +161,10 @@ def _init_from_dict(self, aa_dict, allowed_species):
else:
seq = ''.join(aa_dict[pos] for pos in sorted_positions)
renumbered_aa_dict = _anarci_align(
seq,
[seq],
scheme=self.cdr_definition if self.cdr_definition != 'north' else 'chothia',
allowed_species=allowed_species
)[0][0]
)[0][0][0]
cdr_definition_positions = [pos.number for pos in sorted(renumbered_aa_dict.keys())]
combined_aa_dict = {}
for orig_pos, cdr_definition_position in zip(sorted_positions, cdr_definition_positions):
Expand All @@ -178,6 +182,45 @@ def _init_from_dict(self, aa_dict, allowed_species):
region_idx += 1
regions_list[region_idx][pos] = aa

@classmethod
def batch(cls, seq_dict: dict, scheme: str, cdr_definition=None, assign_germline=False, allowed_species=None):
    """Create multiple Chain objects from dict of sequences

    All sequences are aligned with a single ``_anarci_align`` call. A sequence that
    cannot be turned into exactly one chain does not raise; it is reported in the
    returned error dict instead.

    :param seq_dict: Dictionary of sequence strings, keys are sequence identifiers
    :param scheme: Numbering scheme to align the sequences
    :param cdr_definition: Numbering scheme to be used for definition of CDR regions. Same as ``scheme`` by default.
    :param assign_germline: Assign germline name using ANARCI based on best sequence identity
    :param allowed_species: Allowed species for germline assignment. Use ``None`` to allow all species, or one or more of: ``'human', 'mouse','rat','rabbit','rhesus','pig','alpaca'``
    :return: tuple with (dict of Chain objects, dict of error strings), both keyed by the identifiers of ``seq_dict``
    """
    assert isinstance(seq_dict, dict), f'Expected dictionary of sequences, got: {type(seq_dict).__name__}'
    if not seq_dict:
        # Nothing to align, so do not invoke the aligner on an empty batch.
        return {}, {}
    if isinstance(allowed_species, str):
        # Accept a single species name as shorthand for a one-element list.
        allowed_species = [allowed_species]
    seq_list = list(seq_dict.values())
    all_results = _anarci_align(seq_list, scheme=scheme, allowed_species=allowed_species, assign_germline=assign_germline)
    chains = {}
    errors = {}
    # _anarci_align returns one result list per input sequence, in input order.
    for (name, sequence), results in zip(seq_dict.items(), all_results):
        if not results:
            errors[name] = f'Variable chain sequence not recognized: "{sequence}"'
        elif len(results) > 1:
            errors[name] = f'Found {len(results)} antibody domains: "{sequence}"'
        else:
            aa_dict, chain_type, tail, species, v_gene, j_gene = results[0]
            # Build through cls so subclasses get instances of their own type, and
            # forward allowed_species so it also applies when the chain is
            # renumbered for a cdr_definition that differs from scheme.
            chains[name] = cls(
                sequence=None,
                aa_dict=aa_dict,
                name=name,
                scheme=scheme,
                chain_type=chain_type,
                cdr_definition=cdr_definition,
                allowed_species=allowed_species,
                tail=tail,
                species=species,
                v_gene=v_gene,
                j_gene=j_gene
            )
    return chains, errors

def __repr__(self):
return self.format()

Expand Down
41 changes: 22 additions & 19 deletions abnumber/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,34 @@ def _validate_chain_type(chain_type):
f'Invalid chain type "{chain_type}", it should be "H" (heavy), "L" (lambda light chian) or "K" (kappa light chain)'


def _anarci_align(sequences, scheme, allowed_species, assign_germline=False) -> List[List[Tuple]]:
    """Number a list of sequences with a single ANARCI call.

    :param sequences: list of sequence strings; whitespace is removed before alignment
    :param scheme: numbering scheme passed to ANARCI and stored on every created Position
    :param allowed_species: species restriction, passed to ANARCI unchanged
    :param assign_germline: passed to ANARCI; when true, V and J gene names are read from the alignment details
    :return: one list per input sequence, in input order. Each inner list holds one
        ``(aa_dict, chain_type, tail, species, v_gene, j_gene)`` tuple per numbered domain
        and is empty when ANARCI returned no numbering for that sequence.
    """
    from abnumber.position import Position
    assert isinstance(sequences, list), f'Expected list of sequences, got: {type(sequences)}'
    # Strip whitespace once and reuse the cleaned strings below: the end offsets
    # reported by ANARCI index into the cleaned sequence, not the raw input.
    cleaned_sequences = [re.sub(WHITESPACE, '', sequence) for sequence in sequences]
    all_numbered, all_ali, _all_hits = anarci(
        [(f'id{i}', sequence) for i, sequence in enumerate(cleaned_sequences)],
        scheme=scheme,
        allowed_species=allowed_species,
        assign_germline=assign_germline
    )
    all_results = []
    for sequence, seq_numbered, seq_ali in zip(cleaned_sequences, all_numbered, all_ali):
        if seq_numbered is None:
            # Variable chain sequence not recognized
            all_results.append([])
            continue
        assert len(seq_numbered) == len(seq_ali), 'Unexpected ANARCI output'
        results = []
        for (positions, _start, end), ali in zip(seq_numbered, seq_ali):
            chain_type = ali['chain_type']
            species = ali['species']
            v_gene = ali['germlines']['v_gene'][0][1] if assign_germline else None
            j_gene = ali['germlines']['j_gene'][0][1] if assign_germline else None
            # Gap positions ('-') are left out of the numbered residue dict.
            aa_dict = {Position(chain_type=chain_type, number=num, letter=letter, scheme=scheme): aa
                       for (num, letter), aa in positions if aa != '-'}
            # Everything after the numbered domain in the cleaned sequence.
            tail = sequence[end+1:]
            results.append((aa_dict, chain_type, tail, species, v_gene, j_gene))
        all_results.append(results)
    return all_results


def _get_unique_chains(chains):
Expand Down
21 changes: 21 additions & 0 deletions test/test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def test_invalid_chain_raises_error(scheme):
Chain('AAA', scheme=scheme)


def test_multiple_chains_raises_error():
    """Chain() must raise ChainParseError for this long two-domain input sequence."""
    two_domain_sequence = 'QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSSQVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSS'
    with pytest.raises(ChainParseError):
        Chain(two_domain_sequence, scheme='imgt')


def test_aho_without_cdr_definition_raises_error():
with pytest.raises(ValueError):
Chain('QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSS', scheme='aho')
Expand Down Expand Up @@ -247,3 +252,19 @@ def test_nearest_j_region():

assert nearest_j[0].name == 'IGHJ6*01'


def test_batch():
    """Chain.batch() returns chains for parseable inputs and error strings for the rest."""
    sequences = {
        'A': 'QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTVTVSS',
        'B': 'EVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSS',
        'C': 'FOO',
        'D': 'EVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSSEVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSS'
    }
    chains, errors = Chain.batch(sequences, scheme='imgt')

    # Only 'A' and 'B' produce Chain objects.
    assert len(chains) == 2
    assert 'C' not in chains
    assert 'D' not in chains
    assert chains['A'].raw[0] == 'Q'
    assert chains['B'].raw[0] == 'E'

    # 'C' and 'D' are reported as error strings instead of raising.
    assert errors['C'] == 'Variable chain sequence not recognized: "FOO"'
    assert errors['D'] == 'Found 2 antibody domains: "EVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSSEVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYSEDDERGHYCLDYWGQGTTLTVSS"'

0 comments on commit a8c1bba

Please sign in to comment.