diff --git a/proteinflow/data/torch.py b/proteinflow/data/torch.py index 1224d99..28fbf59 100644 --- a/proteinflow/data/torch.py +++ b/proteinflow/data/torch.py @@ -914,6 +914,8 @@ def set_cdr(self, cdr): """ if not self.sabdab: cdr = None + if isinstance(cdr, str): + cdr = [cdr] if cdr == self.cdr: return self.cdr = cdr @@ -924,12 +926,12 @@ def set_cdr(self, cdr): print(f"Setting CDR to {cdr}...") for i, data in tqdm(enumerate(self.data)): if self.clusters is not None: - if data.split("__")[1] == cdr: + if data.split("__")[1] in cdr: self.indices.append(i) else: add = False for chain in self.files[data]: - if chain.split("__")[1] == cdr: + if chain.split("__")[1] in cdr: add = True break if add: