diff --git a/.gitignore b/.gitignore index da94a5be..f298e2c8 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ nosetests.xml TreeSAPP.Rproj .Rhistory .RData +.python-version* # Data diff --git a/CHANGELOG.md b/CHANGELOG.md index 088cf507..4ddb20d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## [0.11.5] - 2022-06 +### Added +### Fixed +### Changed +- Switched BMGE for ClipKit (#67) + ## [0.11.4] - 2022-05-22 ### Added - Centroid inference for pOTUs based on the midpoint, or balance point, of all cluster members. diff --git a/requirements.txt b/requirements.txt index 2fa0933a..6091d3f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ tqdm >=4.50.0 pytest >=6.2.5 pandas >=1.1.0 matplotlib >=3.3.0 +clipkit >= 1.3.0 diff --git a/setup.py b/setup.py index cebfb043..f80cd53d 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ -from setuptools import Extension from setuptools import setup, find_packages with open("README.md", "r") as readme: @@ -22,7 +21,7 @@ SETUP_METADATA = \ { "name": "treesapp", - "version": "0.11.4", + "version": "0.11.5", "description": "TreeSAPP is a functional and taxonomic annotation tool for genomes and metagenomes.", "long_description": LONG_DESCRIPTION, "long_description_content_type": "text/markdown", @@ -34,16 +33,12 @@ "include_package_data": True, "entry_points": {'console_scripts': ['treesapp = treesapp.__main__:main']}, "classifiers": CLASSIFIERS, - "ext_modules": [Extension("_tree_parser", - sources=["treesapp/extensions/tree_parsermodule.cpp"], - language="c++") - ], "install_requires": open("requirements.txt").read().splitlines(), "setup_requires": [ "setuptools>=50.0.0" ], "extras_require": { - 'test': ['pytest', 'pytest-cov'], + 'tests': ['pytest', 'pytest-cov'], } } diff --git a/tests/test_assign.py b/tests/test_assign.py index 8db92a2a..cd0cbc29 100644 --- a/tests/test_assign.py +++ b/tests/test_assign.py @@ -243,6 +243,55 @@ def test_load_refpkg_classifiers(self): return + def 
test_bin_hmm_matches_by_region(self): + + return + + def test_bin_hmm_matches_by_identity(self): + return + + def test_write_grouped_fastas(self): + from treesapp import assign + from treesapp import fasta + import random + from string import ascii_uppercase + seq_dict = {} + seq_name_idx = {} + rp_name = "PuhA" + scaler = 3 + n_seqs = self.refpkg_dict[rp_name].num_seqs * scaler + + # Test with empty inputs + fasta_file_group_map = assign.write_grouped_fastas(extracted_seq_dict=seq_dict, + seq_name_index=seq_name_idx, + refpkg_dict=self.refpkg_dict, + output_dir=self.output_dir) + self.assertEqual({}, fasta_file_group_map) + + # Test real condition + seq_dict.update({rp_name: {"99": {-1*n: ''.join(random.choice(ascii_uppercase) for _ in range(50)) + for n in range(n_seqs)}}}) + seq_name_idx.update({rp_name: {-1*x: "seq_{}".format(x) for x in range(n_seqs)}}) + fasta_file_group_map = assign.write_grouped_fastas(extracted_seq_dict=seq_dict, + seq_name_index=seq_name_idx, + refpkg_dict=self.refpkg_dict, + output_dir=self.output_dir) + fasta_files = list(fasta_file_group_map[rp_name]) + self.assertEqual(scaler, len(fasta_files)) + self.assertEqual([os.path.join(self.output_dir, + "{}_hmm_purified_group{}.faa".format(rp_name, n)) for n in range(scaler)], + fasta_files) + # Ensure there are the right number of sequences in each file + for file_path in fasta_files: + self.assertEqual(self.refpkg_dict[rp_name].num_seqs, + len(fasta.get_headers(file_path))) + + return + + def test_bin_hmm_matches(self): + self.assertTrue(False) + return + if __name__ == '__main__': unittest.main() diff --git a/tests/test_classy.py b/tests/test_classy.py index 75b38d84..b487673d 100644 --- a/tests/test_classy.py +++ b/tests/test_classy.py @@ -69,11 +69,12 @@ def test_furnish_with_arguments(self): args.input = [self.fasta] args.output = self.output_dir args.molecule = "prot" - args.executables = {'prodigal': '/home/connor/bin/prodigal', 'BMGE.jar': '/usr/local/bin/BMGE.jar', + 
args.executables = {'prodigal': '/home/connor/bin/prodigal', 'hmmbuild': '/usr/local/bin/hmmbuild', 'hmmalign': '/usr/local/bin/hmmalign', 'hmmsearch': '/usr/local/bin/hmmsearch', - 'epa-ng': '/usr/local/bin/epa-ng', 'raxml-ng': '/usr/local/bin/raxml-ng'} + 'epa-ng': '/usr/local/bin/epa-ng', + 'raxml-ng': '/usr/local/bin/raxml-ng'} self.db.furnish_with_arguments(args) self.assertEqual(len(args.executables), len(self.db.executables)) self.assertEqual(self.fasta, self.db.input_sequences) diff --git a/tests/test_clipkit_helper.py b/tests/test_clipkit_helper.py new file mode 100644 index 00000000..ea7f8de8 --- /dev/null +++ b/tests/test_clipkit_helper.py @@ -0,0 +1,43 @@ +import os +import unittest + +from .testing_utils import get_test_data + + +class MyTestCase(unittest.TestCase): + def setUp(self) -> None: + self.test_fa = get_test_data('PuhA.mfa') + self.output_fa = 'PuhA.trim.mfa' + + def tearDown(self) -> None: + if os.path.isfile(self.output_fa): + os.remove(self.output_fa) + + def test_run(self): + from treesapp import clipkit_helper + from clipkit import modes as ck_modes + ck = clipkit_helper.ClipKitHelper(fasta_in=self.test_fa, + output_dir='./', + mode="smart-gap", + min_len=200) + ck.run() + ck.compare_original_and_trimmed_multiple_alignments() + ck.summarise_trimming() + self.assertTrue(os.path.isfile(self.output_fa)) + + ck.mode = ck_modes.TrimmingMode("kpi-smart-gap") + ck.run() + ck.compare_original_and_trimmed_multiple_alignments() + ck.summarise_trimming() + self.assertTrue(os.path.isfile(self.output_fa)) + + ck.mode = ck_modes.TrimmingMode("gappy") + ck.run() + ck.compare_original_and_trimmed_multiple_alignments() + ck.summarise_trimming() + self.assertTrue(os.path.isfile(self.output_fa)) + return + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_commands.py b/tests/test_commands.py index 5ca63fc3..4b9b1376 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -106,11 +106,17 @@ def test_assign(self): 
"--placement_summary", "aelw", "--trim_align", "--svm"] assign.assign(assign_commands_list) - self.assertEqual(13, len(read_classification_table(assignments_tbl))) + assigned_queries = 15 + self.assertEqual(assigned_queries, + len(read_classification_table(assignments_tbl))) self.assertTrue(os.path.isfile(classified_seqs_faa)) - assign.assign(assign_commands_list + ["--targets", "McrA,McrB,XmoA"]) - self.assertEqual(18, len(read_classification_table(assignments_tbl))) - self.assertEqual(18, len(fasta.get_headers(classified_seqs_faa))) + assign.assign(assign_commands_list + ["--targets", "McrA,McrB,XmoA", + "--min_seq_length", str(30)]) + assigned_queries = 17 + self.assertEqual(assigned_queries, + len(read_classification_table(assignments_tbl))) + self.assertEqual(assigned_queries, + len(fasta.get_headers(classified_seqs_faa))) # Test nucleotide sequence input WITHOUT targets listed assign_commands_list = ["--fastx_input", self.nt_test_fa, @@ -124,8 +130,10 @@ def test_assign(self): "--reads", get_test_data("SRR3669912_1.fastq"), "--reverse", get_test_data("SRR3669912_2.fastq")] assign.assign(assign_commands_list) + assigned_queries = 8 lines = read_classification_table(assignments_tbl) - self.assertEqual(8, len(lines)) + self.assertEqual(assigned_queries, + len(lines)) classified_seqs = set() pqueries = assignments_to_pqueries(lines) for rp_pqs in pqueries.values(): diff --git a/tests/test_file_parsers.py b/tests/test_file_parsers.py index a18ad007..86393570 100644 --- a/tests/test_file_parsers.py +++ b/tests/test_file_parsers.py @@ -134,45 +134,6 @@ def test_read_lineage_ids(self): read_lineage_ids(get_test_data("McrA_lineage_ids - GTDB_map.tsv")) return - def test_check_seq_name_integer_compatibility(self): - from treesapp import file_parsers - msa, n_refs = file_parsers.check_seq_name_integer_compatibility(seq_dict=self.test_fasta_data) - self.assertEqual(-1, n_refs) - self.assertEqual("Bad_header_name", msa.popitem()[0]) - return - - def 
test_validate_alignment_trimming(self): - from treesapp import file_parsers - from treesapp import fasta - test_msa = get_test_data("PuhA.mfa") - headers = set([str(i) + "_PuhA" for i in range(1, 48)]) - tmp_fasta = fasta.FASTA(file_name=test_msa) - tmp_fasta.load_fasta() - - # Test bad file extension - with pytest.raises(SystemExit): - file_parsers.validate_alignment_trimming(["PuhA.stk"], headers) - - # Fail due to a bad sequence name in a fasta - with pytest.raises(SystemExit): - file_parsers.validate_alignment_trimming([self.test_data_file], set(self.test_fasta_data)) - - # Ensure success - success, fail, msg = file_parsers.validate_alignment_trimming(msa_files=[test_msa], - unique_ref_headers=set(tmp_fasta.get_seq_names())) - self.assertEqual(32, len(success[test_msa])) - self.assertEqual([], fail) - self.assertIsInstance(msg, str) - - # MSA fails due to more sequence names in unique_ref_headers than MSA - success, fail, msg = file_parsers.validate_alignment_trimming(msa_files=[test_msa], - unique_ref_headers=headers) - self.assertEqual({}, success) - self.assertEqual([test_msa], fail) - self.assertIsInstance(msg, str) - - return - if __name__ == '__main__': unittest.main() diff --git a/tests/test_graftm_utils.py b/tests/test_graftm_utils.py index f2bb3fe0..86536c79 100644 --- a/tests/test_graftm_utils.py +++ b/tests/test_graftm_utils.py @@ -71,7 +71,7 @@ def test_prep_graftm_ref_files(self): from treesapp import utilities # Find the executables exe_map = {} - for dep in ["hmmbuild", "hmmalign", "raxml-ng", "mafft", "BMGE.jar"]: + for dep in ["hmmbuild", "hmmalign", "raxml-ng", "mafft"]: exe_map[dep] = utilities.fetch_executable_path(dep, self.ts_dir) taxon_str = 'd__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; o__Rhizobiales' diff --git a/tests/test_multiple_alignment.py b/tests/test_multiple_alignment.py new file mode 100644 index 00000000..7e030d80 --- /dev/null +++ b/tests/test_multiple_alignment.py @@ -0,0 +1,39 @@ +import os +import unittest + 
+from .testing_utils import get_test_data + + +class MyTestCase(unittest.TestCase): + def test_trim_multiple_alignments(self): + from treesapp import multiple_alignment + from treesapp import refpkg + test_fa = get_test_data('PuhA.mfa') + trim_file = os.path.join("tests", "test_data", "PuhA.trim.mfa") + qc_file = os.path.join("tests", "test_data", "PuhA.trim.qc.mfa") + test_rp = refpkg.ReferencePackage(refpkg_name="PuhA") + test_rp.f__pkl = get_test_data(filename=os.path.join("refpkgs", "PuhA_build.pkl")) + test_rp.slurp() + + result = multiple_alignment.trim_multiple_alignment_farmer([{"qry_ref_mfa": test_fa, + "refpkg_name": "PuhA", + "gap_tuned": True, + "avg_id": 88}], + min_seq_length=10, + n_proc=1, + ref_pkgs={"PuhA": test_rp}, + for_placement=False) + self.assertTrue(os.path.isfile(trim_file)) + self.assertIsInstance(result, dict) + self.assertTrue("PuhA" in result.keys()) + self.assertEqual(os.path.basename(qc_file), + os.path.basename(result["PuhA"].pop())) + + for f_path in [trim_file, qc_file]: + if os.path.isfile(f_path): + os.remove(f_path) + return + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_placement_trainer.py b/tests/test_placement_trainer.py index 8428db72..e6117324 100644 --- a/tests/test_placement_trainer.py +++ b/tests/test_placement_trainer.py @@ -27,7 +27,7 @@ def setUp(self) -> None: # Executables dictionary self.exes = {} self.treesapp_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__))) + os.sep - for dep in ["hmmbuild", "hmmalign", "raxml-ng", "mafft", "BMGE.jar"]: + for dep in ["hmmbuild", "hmmalign", "raxml-ng", "mafft", "epa-ng"]: self.exes[dep] = fetch_executable_path(dep, get_treesapp_root()) return @@ -80,6 +80,8 @@ def test_clade_exclusion_phylo_placement(self): '634498.mru_1924'], 'd__Bacteria; p__Proteobacteria; c__Gammaproteobacteria': ['523846.Mfer_0784', '79929.MTBMA_c15480']}} + + # This will fail since query sequences (McrA) are unrelated to reference package (PuhA) pqueries = 
clade_exclusion_phylo_placement(rank_training_seqs=train_seqs, test_fasta=self.bad_fasta, ref_pkg=self.test_refpkg, executables=self.exes, min_seqs=3, output_dir=self.output_dir) diff --git a/tests/test_refpkg.py b/tests/test_refpkg.py index 61077eb8..7ac1fdc5 100644 --- a/tests/test_refpkg.py +++ b/tests/test_refpkg.py @@ -40,8 +40,9 @@ def setUp(self) -> None: # Find the executables self.exe_map = {} self.treesapp_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__))) + os.sep - for dep in ["hmmbuild", "hmmalign", "raxml-ng", "mafft", "BMGE.jar", "FastTree"]: + for dep in ["hmmbuild", "hmmalign", "raxml-ng", "mafft", "FastTree"]: self.exe_map[dep] = fetch_executable_path(dep, utils.get_treesapp_root()) + self.n_cpu = utils.NUM_THREADS return def tearDown(self) -> None: @@ -64,6 +65,31 @@ def test_band(self): self.db.band() return + def test_blast(self): + import Bio.Align + import timeit + + qseq = "AMQIGMSFISXYKVCAGEAAVADLAFAAKHAGVIQMADILPARRARGPNEPGGIKFGHFC" \ + "DMIQGDRKYPNDPVRANLEVVAAGAMLFDQIWLGSYMSGGVGFTQYATAAYTDNILDDYC" \ + "EYGVDYIKKKHGGIAKAKSTQEVVSDIATEVNLYGMEQYESYPTALESHFGGSQRASVLA" \ + "AASGLTCSLATANSNAGLNGWYLSMLMHKEGWSRLGFFGYDLQDQCGSANSMSIRPDEGL" \ + "LGELRGPNYPNYAI" + ref_pkg = self.db + + exec_time = timeit.timeit(stmt="ref_pkg.blast(qseq=qseq, n_proc=n_proc)", + globals={'ref_pkg': ref_pkg, + 'qseq': qseq, + 'n_proc': self.n_cpu}, + number=10) + self.assertTrue(0 < exec_time < 10) + + aln, seq_id, g_seq_id = ref_pkg.blast(qseq, + n_proc=self.n_cpu) # type: Bio.Align.PairwiseAlignment + self.assertEqual(100, aln.score) + self.assertEqual(62, round(seq_id, 0)) + self.assertEqual(89, round(g_seq_id, 0)) + return + def test_disband(self): # Basic disband self.db.disband(output_dir="./tests/") diff --git a/tests/test_training_utils.py b/tests/test_training_utils.py index dad1766a..245f8ddd 100644 --- a/tests/test_training_utils.py +++ b/tests/test_training_utils.py @@ -53,7 +53,7 @@ def test_generate_pquery_data_for_trainer(self): 
treesapp_dir = get_treesapp_root() executables = {} - for dep in ["hmmbuild", "hmmalign", "hmmsearch", "epa-ng", "raxml-ng", "FastTree", "mafft", "BMGE.jar"]: + for dep in ["hmmbuild", "hmmalign", "hmmsearch", "epa-ng", "raxml-ng", "FastTree", "mafft"]: executables[dep] = fetch_executable_path(dep, treesapp_dir) pbar = tqdm() test_taxon_one = "f__Bradyrhizobiaceae; g__Bradyrhizobium; s__Bradyrhizobium 'sp.' BTAi1" @@ -68,10 +68,9 @@ def test_generate_pquery_data_for_trainer(self): def test_fetch_executable_path(self): from treesapp.utilities import fetch_executable_path - from re import sub treesapp_dir = get_treesapp_root() - exe_path = fetch_executable_path("BMGE.jar", treesapp_dir) - self.assertEqual("/sub_binaries/BMGE.jar", sub(treesapp_dir, '', exe_path)) + exe_path = fetch_executable_path("epa-ng", treesapp_dir) + self.assertEqual("epa-ng", os.path.basename(exe_path)) return def test_load_training_data_frame(self): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 4a61d371..40c40767 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -6,6 +6,9 @@ import ete3 +NUM_THREADS = os.cpu_count() + + def random_ete_tree(leaf_names: list, branch_len_dist=None) -> ete3.Tree: if not branch_len_dist: branch_len_dist = (0, 1) diff --git a/tox.ini b/tox.ini index 642b9b4c..84ebd69c 100644 --- a/tox.ini +++ b/tox.ini @@ -40,7 +40,7 @@ passenv = deps= .[tests] codecov>=2.0.0 - setuptools>=50.3.1 + setuptools>=60.4.0 setenv = {[default]setenv} @@ -76,8 +76,8 @@ skip_install = True deps = codecov>=2.0.0 - coverage==5.0.3 - setuptools>=50.3.1 + coverage>=5.0.3 + setuptools>=60.4.0 setenv = {[default]setenv} diff --git a/treesapp/__init__.py b/treesapp/__init__.py index 9d007770..e08f3305 100644 --- a/treesapp/__init__.py +++ b/treesapp/__init__.py @@ -27,7 +27,7 @@ __status__ = "Production/Stable" __title__ = "TreeSAPP" __url__ = "https://github.com/hallamlab/TreeSAPP" -__version__ = "0.11.4" +__version__ = "0.11.5" __all__ = ['abundance', 
'annotate_extra', 'assign', 'clade_annotation', 'classy', 'commands', 'create_refpkg', 'entish', 'entrez_utils', diff --git a/treesapp/assign.py b/treesapp/assign.py index e7805095..40383ca1 100755 --- a/treesapp/assign.py +++ b/treesapp/assign.py @@ -15,20 +15,21 @@ from treesapp import abundance from treesapp import classy +from treesapp import entish +from treesapp import external_command_interface as eci +from treesapp import fasta +from treesapp import file_parsers +from treesapp import lca_calculations as ts_lca from treesapp import logger +from treesapp import multiple_alignment from treesapp import phylo_seq from treesapp import refpkg from treesapp import treesapp_args -from treesapp import entish -from treesapp import lca_calculations as ts_lca from treesapp import jplace_utils -from treesapp import file_parsers from treesapp import phylo_dist from treesapp import utilities from treesapp import wrapper -from treesapp import fasta from treesapp import training_utils -from treesapp import external_command_interface as eci from treesapp.hmmer_tbl_parser import HmmMatch LOGGER = logging.getLogger(logger.logger_name()) @@ -363,37 +364,31 @@ def search(self, ref_pkg_dict: dict, hmm_parsing_thresholds, num_threads=2) -> d return file_parsers.parse_domain_tables(hmm_parsing_thresholds, refpkg_hmmer_tables) - def align(self, refpkg_dict: dict, homolog_seq_files: list, min_seq_length: int, n_proc: int, - trim_align=True, verbose=False) -> dict: + def align(self, refpkg_dict: dict, pquery_groups: list, + min_seq_length: int, n_proc: int, + trim_align=True) -> dict: if self.past_last_stage("align"): return {} MSAs = namedtuple("MSAs", "ref query") - ref_alignment_dimensions = get_alignment_dims(refpkg_dict) split_msa_files = self.fetch_multiple_alignments() target_refpkgs = {prefix: rp for prefix, rp in refpkg_dict.items() if prefix not in split_msa_files} if self.stage_status("align") or target_refpkgs: - # create_ref_phy_files(refpkg_dict, align_output_dir, - # 
homolog_seq_files, ref_alignment_dimensions) - concatenated_msa_files = multiple_alignments(self.executables, homolog_seq_files, - target_refpkgs, "hmmalign", - output_dir=self.stage_lookup(name="align").dir_path, - num_proc=n_proc, silent=self.silent) - if concatenated_msa_files: + msa_files = multiple_alignments(self.executables, pquery_groups, + target_refpkgs, "hmmalign", + output_dir=self.stage_lookup(name="align").dir_path, + num_proc=n_proc, silent=self.silent) + if msa_files: combined_msa_files = {} - file_type = utilities.find_msa_type(concatenated_msa_files) - alignment_length_dict = get_sequence_counts(concatenated_msa_files, ref_alignment_dimensions, - verbose, file_type) if trim_align: - tool = "BMGE" - trimmed_mfa_files = wrapper.filter_multiple_alignments(self.executables, concatenated_msa_files, - target_refpkgs, n_proc, tool, self.silent) - qc_ma_dict = check_for_removed_sequences(trimmed_mfa_files, concatenated_msa_files, - target_refpkgs, min_seq_length) - evaluate_trimming_performance(qc_ma_dict, alignment_length_dict, concatenated_msa_files, tool) - combined_msa_files.update(qc_ma_dict) + trimmed_mfa_files = multiple_alignment.trim_multiple_alignment_farmer(pquery_groups, + min_seq_length=min_seq_length, + ref_pkgs=refpkg_dict, + n_proc=n_proc, + silent=self.silent) + combined_msa_files.update(trimmed_mfa_files) else: - combined_msa_files.update(concatenated_msa_files) + combined_msa_files.update(msa_files) # Subset the multiple alignment of reference sequences and queries to just contain query sequences for refpkg_name in combined_msa_files: @@ -447,7 +442,7 @@ def write_classified_orfs(self, pqueries: dict, extracted_seqs: dict) -> None: molecule="dna", subset=classified_seq_names, full_name=self.fasta_full_name) - nuc_orfs.header_registry = fasta.register_headers(nuc_orfs.fasta_dict.keys()) + nuc_orfs.header_registry = fasta.register_headers(header_list=list(nuc_orfs.fasta_dict.keys())) nuc_orfs.change_dict_keys() if not 
os.path.isfile(self.classified_nuc_seqs): self.ts_logger.info("Creating nucleotide FASTA file of classified sequences '{}'... " @@ -568,7 +563,52 @@ def load_homologs(hmm_matches: dict, hmmsearch_query_fasta: str, query_seq_fasta return -def bin_hmm_matches(hmm_matches: dict, fasta_dict: dict) -> (dict, dict): +def bin_hmm_matches_by_region(ref_pkg_hmm_matches: list) -> dict: + """ + Algorithm for binning sequences: + 1. Sort HmmMatches by the proportion of the HMM profile they covered in increasing order (full-length last) + 2. For HmmMatch in sorted matches, determine overlap between HmmMatch and each bin's representative HmmMatch + 3. If overlap exceeds 80% of representative's aligned length add it to the bin, else continue + 4. When bins are exhausted create new bin with HmmMatch + """ + bins = dict() + for hmm_match in sorted(ref_pkg_hmm_matches, key=lambda x: x.end - x.start): # type: HmmMatch + # Add the FASTA record of the trimmed sequence - this one moves on for placement + binned = False + for bin_num in sorted(bins): + bin_rep = bins[bin_num][0] + overlap = min(hmm_match.pend, bin_rep.pend) - max(hmm_match.pstart, bin_rep.pstart) + if (100 * overlap) / (bin_rep.pend - bin_rep.pstart) > 80: # 80 refers to overlap proportion with seed + bins[bin_num].append(hmm_match) + binned = True + break + + if not binned: + bin_num = len(bins) + bins[bin_num] = list() + bins[bin_num].append(hmm_match) + + return bins + + +def bin_hmm_matches_by_identity(ref_pkg_hmm_matches: list, fasta_dict: dict, ref_pkg: refpkg.ReferencePackage) -> dict: + """ + Group query sequences based on their identity to the closest sequence in the reference package. 
+ """ + bins = dict() + for hmm_match in sorted(ref_pkg_hmm_matches, key=lambda x: x.end - x.start): # type: HmmMatch + match_sequence = fasta_dict[hmm_match.sequence_name()][hmm_match.start - 1:hmm_match.end] + _aln, _seq_id, g_seq_id = ref_pkg.blast(qseq=match_sequence) + hmm_match.aln_pident = round(g_seq_id, 2) + # Round using -1 to group into bins of width 10 + try: + bins[int(round(g_seq_id, -1))].append(hmm_match) + except KeyError: + bins[int(round(g_seq_id, -1))] = [hmm_match] + return bins + + +def bin_hmm_matches(hmm_matches: dict, fasta_dict: dict, refpkg_dict: dict, method="region") -> (dict, dict): """ Used for extracting query sequences that mapped to reference package HMM profiles. These are binned into groups based on the location on the HMM profile they mapped to such that MSAs downstream will have more conserved positions @@ -576,130 +616,113 @@ def bin_hmm_matches(hmm_matches: dict, fasta_dict: dict) -> (dict, dict): The first nested dictionary returned "extracted_seq_dict" contains marker (i.e. ref_pkg names) strings mapped to bin numbers mapped to query sequence negative integer code names mapped to their extracted, or sliced, sequence. - The second dictionary returned "numeric_contig_index" is used for mapping query sequence negative integer code names + The second dictionary returned "seq_name_index" is used for mapping query sequence negative integer code names mapped to their original header names with the alignment coordinates appended at the end for each marker. :param hmm_matches: Contains lists of HmmMatch objects mapped to the marker they matched :param fasta_dict: Stores either the original or ORF-predicted input FASTA. Headers are keys, sequences are values + :param refpkg_dict: Dictionary of reference packages indexed by their prefix attributes. + :param method: How should the sequences be binned? Options are 'region' or 'identity'. 
:return: List of files that go on to placement stage, dictionary mapping marker-specific numbers to contig names """ LOGGER.info("Extracting and grouping the quality-controlled sequences... ") extracted_seq_dict = dict() # Keys are markers -> bin_num -> negative integers -> extracted sequences numeric_contig_index = dict() # Keys are markers -> negative integers -> headers - bins = dict() + bin_identities = {} for marker in hmm_matches: if len(hmm_matches[marker]) == 0: continue if marker not in numeric_contig_index.keys(): numeric_contig_index[marker] = dict() - numeric_decrementor = -1 if marker not in extracted_seq_dict: extracted_seq_dict[marker] = dict() - # Algorithm for binning sequences: - # 1. Sort HmmMatches by the proportion of the HMM profile they covered in increasing order (full-length last) - # 2. For HmmMatch in sorted matches, determine overlap between HmmMatch and each bin's representative HmmMatch - # 3. If overlap exceeds 80% of representative's aligned length add it to the bin, else continue - # 4. 
When bins are exhausted create new bin with HmmMatch - for hmm_match in sorted(hmm_matches[marker], key=lambda x: x.end - x.start): # type: HmmMatch - if hmm_match.desc != '-': - contig_name = hmm_match.orf + ' ' + hmm_match.desc - else: - contig_name = hmm_match.orf - # Add the query sequence to the index map - orf_coordinates = str(hmm_match.start) + '_' + str(hmm_match.end) - numeric_contig_index[marker][numeric_decrementor] = contig_name + '|' + marker + '|' + orf_coordinates - # Add the FASTA record of the trimmed sequence - this one moves on for placement - full_sequence = fasta_dict[contig_name] - binned = False - for bin_num in sorted(bins): - bin_rep = bins[bin_num][0] - overlap = min(hmm_match.pend, bin_rep.pend) - max(hmm_match.pstart, bin_rep.pstart) - if (100 * overlap) / (bin_rep.pend - bin_rep.pstart) > 80: # 80 refers to overlap proportion with seed - bins[bin_num].append(hmm_match) - extracted_seq_dict[marker][bin_num][numeric_decrementor] = full_sequence[ - hmm_match.start - 1:hmm_match.end] - binned = True - break - if not binned: - bin_num = len(bins) - bins[bin_num] = list() - extracted_seq_dict[marker][bin_num] = dict() - bins[bin_num].append(hmm_match) - extracted_seq_dict[marker][bin_num][numeric_decrementor] = full_sequence[ - hmm_match.start - 1:hmm_match.end] - numeric_decrementor -= 1 + if method == "region": + binned_matches = bin_hmm_matches_by_region(ref_pkg_hmm_matches=hmm_matches[marker]) + else: + binned_matches = bin_hmm_matches_by_identity(ref_pkg_hmm_matches=hmm_matches[marker], + fasta_dict=fasta_dict, + ref_pkg=refpkg_dict[marker]) - bins.clear() - LOGGER.info("done.\n") + numeric_decrementor = -1 + for bin_num in binned_matches: + if marker not in bin_identities: + bin_identities[marker] = {} + bin_identities[marker][bin_num] = [hmm_match.aln_pident for hmm_match in binned_matches[bin_num]] + for hmm_match in binned_matches[bin_num]: + match_sequence = fasta_dict[hmm_match.sequence_name()][hmm_match.start - 1:hmm_match.end] + 
# Add the query sequence to the index map + numeric_contig_index[marker][ + numeric_decrementor] = hmm_match.sequence_name() + '|' + marker + '|' + hmm_match.coord_string() + try: + extracted_seq_dict[marker][bin_num][numeric_decrementor] = match_sequence + except KeyError: + extracted_seq_dict[marker][bin_num] = {numeric_decrementor: match_sequence} + numeric_decrementor -= 1 - return extracted_seq_dict, numeric_contig_index + LOGGER.info("done.\n") + return extracted_seq_dict, numeric_contig_index, bin_identities -def write_grouped_fastas(extracted_seq_dict: dict, numeric_contig_index: dict, refpkg_dict: dict, output_dir: str): - hmmalign_input_fastas = list() - bulk_marker_fasta = dict() - bin_fasta = dict() +def summarise_hits_to_groups(extracted_seq_dict): group_size_string = "Number of query sequences in each marker's group:\n" for marker in extracted_seq_dict: for group in sorted(extracted_seq_dict[marker]): if extracted_seq_dict[marker][group]: group_size_string += "\t".join([marker, str(group), str(len(extracted_seq_dict[marker][group]))]) + "\n" LOGGER.debug(group_size_string + "\n") + return + + +def write_grouped_fastas(extracted_seq_dict: dict, seq_name_index: dict, refpkg_dict: dict, output_dir: str) -> dict: + hmmalign_input_fastas = dict() + bulk_marker_fasta = dict() + bin_fasta = dict() + + summarise_hits_to_groups(extracted_seq_dict) LOGGER.info("Writing the grouped sequences to FASTA files... ") for marker in extracted_seq_dict: ref_pkg = refpkg_dict[marker] # type: refpkg.ReferencePackage f_acc = 0 # For counting the number of files for a marker. 
Will exceed groups if len(queries) > len(references) + f_path = os.path.join(output_dir, "{}_hmm_purified_group{}.faa".format(marker, f_acc)) + hmmalign_input_fastas[marker] = {} for group in sorted(extracted_seq_dict[marker]): if extracted_seq_dict[marker][group]: group_sequences = extracted_seq_dict[marker][group] for num in group_sequences: # Add the query sequence to the master marker FASTA with the full sequence name - bulk_marker_fasta[numeric_contig_index[marker][num]] = group_sequences[num] + bulk_marker_fasta[seq_name_index[marker][num]] = group_sequences[num] # Add the query sequence to this bin's FASTA file bin_fasta[str(num)] = group_sequences[num] # Ensuring the number of query sequences doesn't exceed the number of reference sequences if len(bin_fasta) >= ref_pkg.num_seqs: - fasta.write_new_fasta(bin_fasta, - output_dir + marker + "_hmm_purified_group" + str(f_acc) + ".faa") - hmmalign_input_fastas.append(output_dir + marker + "_hmm_purified_group" + str(f_acc) + ".faa") + fasta.write_new_fasta(bin_fasta, f_path) + hmmalign_input_fastas[marker][f_path] = group f_acc += 1 + f_path = os.path.join(output_dir, "{}_hmm_purified_group{}.faa".format(marker, f_acc)) bin_fasta.clear() if len(bin_fasta) >= 1: - fasta.write_new_fasta(bin_fasta, output_dir + marker + "_hmm_purified_group" + str(f_acc) + ".faa") - hmmalign_input_fastas.append(output_dir + marker + "_hmm_purified_group" + str(f_acc) + ".faa") + fasta.write_new_fasta(bin_fasta, f_path) + hmmalign_input_fastas[marker][f_path] = group + f_acc += 1 + f_path = os.path.join(output_dir, "{}_hmm_purified_group{}.faa".format(marker, f_acc)) f_acc += 1 bin_fasta.clear() # Now write a single FASTA file with all identified markers if len(bulk_marker_fasta) >= 1: - trimmed_hits_fasta = output_dir + marker + "_hmm_purified.faa" + trimmed_hits_fasta = os.path.join(output_dir, marker + "_hmm_purified.faa") fasta.write_new_fasta(bulk_marker_fasta, trimmed_hits_fasta) bulk_marker_fasta.clear() LOGGER.info("done.\n") 
return hmmalign_input_fastas -def subsequence(fasta_dictionary, contig_name, start, end): - """ - Extracts a sub-sequence from `start` to `end` of `contig_name` in `fasta_dictionary` - with headers for keys and sequences as values. `contig_name` does not contain the '>' character - - :param fasta_dictionary: - :param contig_name: - :param start: - :param end: - :return: A string representing the sub-sequence of interest - """ - subseq = fasta_dictionary['>' + contig_name][start:end] - return subseq - - def get_sequence_counts(concatenated_mfa_files: dict, ref_alignment_dimensions: dict, verbosity: bool, file_type: str): alignment_length_dict = dict() for refpkg_name in concatenated_mfa_files: @@ -781,13 +804,13 @@ def multiple_alignments(executables: dict, query_sequence_files: list, refpkg_di return concatenated_msa_files -def prepare_and_run_hmmalign(execs: dict, single_query_fasta_files: list, refpkg_dict: dict, +def prepare_and_run_hmmalign(execs: dict, pquery_groups_manifest: list, refpkg_dict: dict, output_dir="", n_proc=2, silent=False) -> dict: """ Runs `hmmalign` to add the query sequences into the reference FASTA multiple alignments :param execs: Dictionary of executable file paths indexed by the software names - :param single_query_fasta_files: List of unaligned query sequences in FASTA format + :param pquery_groups_manifest: List of dictionaries, one per query FASTA group, each providing the 'qry_fa' and 'refpkg_name' keys; a 'qry_ref_mfa' key is added to each :param refpkg_dict: A dictionary of ReferencePackage instances indexed by their respective prefix attributes :param output_dir: Where to write the multiple alignment files containing reference and query sequences :param n_proc: The number of alignment jobs to run in parallel @@ -800,32 +823,27 @@ def prepare_and_run_hmmalign(execs: dict, single_query_fasta_files: list, refpkg task_list = list() # Run hmmalign on each fasta file - for query_fa_in in sorted(single_query_fasta_files): - file_name_info = re.match(r"(.*)_hmm_purified.*\.(f.*)$", os.path.basename(query_fa_in)) - if file_name_info: - refpkg_name, 
extension = file_name_info.groups() - else: - LOGGER.error("Unable to parse information from file name:" + "\n" + str(query_fa_in) + "\n") - sys.exit(3) + for pquery_group in pquery_groups_manifest: + # Add to the manifest to ensure all files are available to ClipKitHelper + file_prefix, _ext = os.path.splitext(os.path.basename(pquery_group["qry_fa"])) + query_mfa_out = os.path.join(output_dir, file_prefix + ".sto") + pquery_group["qry_ref_mfa"] = os.path.join(output_dir, file_prefix + ".mfa") try: - ref_pkg = refpkg_dict[refpkg_name] # type: refpkg.ReferencePackage + ref_pkg = refpkg_dict[pquery_group["refpkg_name"]] # type: refpkg.ReferencePackage except KeyError: + # Reference packages are provided only for MSAs that need to be processed continue - if ref_pkg.prefix not in hmmalign_singlehit_files: - hmmalign_singlehit_files[ref_pkg.prefix] = [] - - query_mfa_out = os.path.join(output_dir, - re.sub('.' + re.escape(extension) + r"$", ".sto", os.path.basename(query_fa_in))) - try: - mfa_out_dict[ref_pkg.prefix].append(query_mfa_out) - except KeyError: - mfa_out_dict[ref_pkg.prefix] = [query_mfa_out] + # Stash file name in dictionary for quick look-up + mfa_out_dict[query_mfa_out] = pquery_group # Get the paths to either the HMM or CM profile files task_list.append([wrapper.hmmalign_command(execs["hmmalign"], - ref_pkg.f__msa, ref_pkg.f__profile, query_fa_in, query_mfa_out)]) + ref_pkg.f__msa, + ref_pkg.f__profile, + pquery_group["qry_fa"], + query_mfa_out)]) eci.run_apply_async_multiprocessing(func=eci.launch_write_command, arguments_list=task_list, @@ -833,19 +851,21 @@ def prepare_and_run_hmmalign(execs: dict, single_query_fasta_files: list, refpkg pbar_desc="Profile alignment", disable=silent) - for prefix in mfa_out_dict: - for query_mfa_out in mfa_out_dict[prefix]: - mfa_file = re.sub(r"\.sto$", ".mfa", query_mfa_out) - seq_dict = file_parsers.read_stockholm_to_dict(query_mfa_out) - fasta.write_new_fasta(seq_dict, mfa_file) - 
hmmalign_singlehit_files[prefix].append(mfa_file) - end_time = time.time() hours, remainder = divmod(end_time - start_time, 3600) minutes, seconds = divmod(remainder, 60) LOGGER.debug("\thmmalign time required: " + ':'.join([str(hours), str(minutes), str(round(seconds, 2))]) + "\n") + # Convert from Stockholm to FASTA format + for query_mfa_out, pquery_group in mfa_out_dict.items(): + seq_dict = file_parsers.read_stockholm_to_dict(query_mfa_out) + fasta.write_new_fasta(seq_dict, pquery_group["qry_ref_mfa"]) + try: + hmmalign_singlehit_files[pquery_group["refpkg_name"]].append(pquery_group["qry_ref_mfa"]) + except KeyError: + hmmalign_singlehit_files[pquery_group["refpkg_name"]] = [pquery_group["qry_ref_mfa"]] + return hmmalign_singlehit_files @@ -872,132 +892,6 @@ def gather_split_msa(refpkg_names: list, align_dir: str) -> dict: return split_msa_map -def check_for_removed_sequences(trimmed_msa_files: dict, msa_files: dict, refpkg_dict: dict, min_len=10): - """ - Reads the multiple alignment files (either Phylip or FASTA formatted) and looks for both reference and query - sequences that have been removed. Multiple alignment files are removed from `mfa_files` if: - 1. all query sequences were removed; a DEBUG message is issued - 2. at least one reference sequence was removed - This quality-control function is necessary for placing short query sequences onto reference trees. - - :param trimmed_msa_files: - :param msa_files: A dictionary containing the untrimmed MSA files indexed by reference package code (denominator) - :param refpkg_dict: A dictionary of ReferencePackage objects indexed by their ref_pkg names - :param min_len: The minimum allowable sequence length after trimming (not including gap characters) - :return: A dictionary of denominators, with multiple alignment dictionaries as values. 
Example: - {M0702: { "McrB_hmm_purified.phy-BMGE.fasta": {'1': seq1, '2': seq2}}} - """ - qc_ma_dict = dict() - num_successful_alignments = 0 - discarded_seqs_string = "" - trimmed_away_seqs = dict() - untrimmed_msa_failed = [] - LOGGER.debug("Validating trimmed multiple sequence alignment files... ") - - for refpkg_name in sorted(trimmed_msa_files.keys()): - ref_pkg = refpkg_dict[refpkg_name] # type: refpkg.ReferencePackage - trimmed_away_seqs[ref_pkg.prefix] = 0 - # Create a set of the reference sequence names - ref_headers = fasta.get_headers(ref_pkg.f__msa) - unique_refs = set([re.sub('_' + re.escape(ref_pkg.prefix), '', x)[1:] for x in ref_headers]) - msa_passed, msa_failed, summary_str = file_parsers.validate_alignment_trimming( - trimmed_msa_files[ref_pkg.prefix], - unique_refs, True, min_len) - - # Report the number of sequences that are removed by BMGE - for trimmed_msa_file in trimmed_msa_files[ref_pkg.prefix]: - try: - prefix = re.search('(' + re.escape(ref_pkg.prefix) + r"_.*_group\d+)-(BMGE|trimAl).fasta$", - os.path.basename(trimmed_msa_file)).group(1) - except TypeError: - LOGGER.error("Unexpected file name format for a trimmed MSA.\n") - sys.exit(3) - # Find the untrimmed query sequence MSA file - the trimmed MSA file's 'pair' - pair = "" - for msa_file in msa_files[ref_pkg.prefix]: - if re.search(re.escape(prefix) + r'\.', msa_file): - pair = msa_file - break - if pair: - if trimmed_msa_file in msa_failed: - untrimmed_msa_failed.append(pair) - trimmed_away_seqs[ref_pkg.prefix] += len( - set(fasta.get_headers(pair)).difference(set(fasta.get_headers(trimmed_msa_file)))) - else: - LOGGER.error("Unable to map trimmed MSA file '" + trimmed_msa_file + "' to its original MSA.\n") - sys.exit(5) - - if len(msa_failed) > 0: - if len(untrimmed_msa_failed) != len(msa_failed): - LOGGER.error("Not all of the failed ({}/{})," - " trimmed MSA files were mapped to their original MSAs." 
- "\n".format(len(msa_failed), len(trimmed_msa_files[ref_pkg.prefix]))) - sys.exit(3) - untrimmed_msa_passed, _, _ = file_parsers.validate_alignment_trimming(untrimmed_msa_failed, unique_refs, - True, min_len) - msa_passed.update(untrimmed_msa_passed) - num_successful_alignments += len(msa_passed) - qc_ma_dict[ref_pkg.prefix] = msa_passed - discarded_seqs_string += summary_str - untrimmed_msa_failed.clear() - - LOGGER.debug("done.\n") - LOGGER.debug("\tSequences removed during trimming:\n\t\t" + - '\n\t\t'.join([k + ": " + str(trimmed_away_seqs[k]) for k in trimmed_away_seqs.keys()]) + "\n") - - LOGGER.debug("\tSequences <" + str(min_len) + " characters removed after trimming:" + - discarded_seqs_string + "\n") - - if num_successful_alignments == 0: - LOGGER.error("No quality alignment files to analyze after trimming. Exiting now.\n") - sys.exit(0) # Should be 3, but this allows Clade_exclusion_analyzer to continue after exit - - return qc_ma_dict - - -def evaluate_trimming_performance(qc_ma_dict, alignment_length_dict, concatenated_msa_files, tool): - """ - - :param qc_ma_dict: A dictionary mapping denominators to files to multiple alignment dictionaries - :param alignment_length_dict: - :param concatenated_msa_files: Dictionary with markers indexing original (untrimmed) multiple alignment files - :param tool: The name of the tool that was appended to the original, untrimmed or unmasked alignment files - :return: None - """ - trimmed_length_dict = dict() - for denominator in sorted(qc_ma_dict.keys()): - if len(concatenated_msa_files[denominator]) >= 1: - of_ext = concatenated_msa_files[denominator][0].split('.')[-1] - else: - continue - if denominator not in trimmed_length_dict: - trimmed_length_dict[denominator] = list() - for multi_align_file in qc_ma_dict[denominator]: - file_type = multi_align_file.split('.')[-1] - multi_align = qc_ma_dict[denominator][multi_align_file] - num_seqs, trimmed_seq_length = fasta.multiple_alignment_dimensions(multi_align_file, 
def build_pquery_group_manifest(bin_identities: dict, file_group_map: dict,
                                gap_tune_min_id: float = 0.0) -> list:
    """
    Build one manifest dictionary per grouped query FASTA file.

    Each manifest has the keys:
      - "group": the group name the FASTA file is mapped to in file_group_map
      - "refpkg_name": the marker (outer key of file_group_map)
      - "qry_fa": the FASTA file path (inner key of file_group_map)
      - "qry_ref_mfa": an empty string placeholder, to be filled in downstream
      - "avg_id": mean of the identity values for the marker/group, rounded to two decimals,
        or 0.0 when no values are available
      - "gap_tuned": True when "avg_id" is at least gap_tune_min_id

    :param bin_identities: Nested dictionary, bin_identities[marker][group] -> sequence of numeric identity values
    :param file_group_map: Nested dictionary, file_group_map[marker][fasta_path] -> group name
    :param gap_tune_min_id: Minimum average identity for a group to be flagged "gap_tuned".
     The default (0.0) reproduces the previously hard-coded `avg_id >= 0` test.
    :return: A list of manifest dictionaries, ordered by iteration over file_group_map
    """
    pquery_group_manifest = []
    for marker, group_files in file_group_map.items():
        for file_name, group_name in group_files.items():
            try:
                id_vals = bin_identities[marker][group_name]
                avg_id = round(sum(id_vals) / len(id_vals), 2)
            except (KeyError, ZeroDivisionError):
                # Unknown marker/group or an empty identity list: fall back to zero
                avg_id = 0.0

            pquery_group_manifest.append({"group": group_name,
                                          "refpkg_name": marker,
                                          "qry_fa": file_name,
                                          "qry_ref_mfa": '',
                                          "gap_tuned": avg_id >= gap_tune_min_id,
                                          "avg_id": avg_id})
    return pquery_group_manifest
def assign(sys_args): load_homologs(hmm_matches, ts_assign.formatted_input, query_seqs) pqueries = load_pqueries(hmm_matches, query_seqs) query_seqs.change_dict_keys("num_id") - extracted_seq_dict, numeric_contig_index = bin_hmm_matches(hmm_matches, query_seqs.fasta_dict) + extracted_seq_dict, numeric_contig_index, bin_identities = bin_hmm_matches(hmm_matches, + query_seqs.fasta_dict, + refpkg_dict=refpkg_dict, + method="identity") numeric_contig_index = replace_contig_names(numeric_contig_index, query_seqs) - homolog_seq_files = write_grouped_fastas(extracted_seq_dict, numeric_contig_index, - refpkg_dict, ts_assign.stage_lookup("search").dir_path) + homolog_seq_files = write_grouped_fastas(extracted_seq_dict, + seq_name_index=numeric_contig_index, + refpkg_dict=refpkg_dict, + output_dir=ts_assign.stage_lookup("search").dir_path) + pquery_group_manifest = build_pquery_group_manifest(bin_identities, + homolog_seq_files) # TODO: Replace this merge_fasta_dicts_by_index with FASTA - only necessary for writing the classified sequences extracted_seq_dict = fasta.merge_fasta_dicts_by_index(extracted_seq_dict, numeric_contig_index) delete_files(args.delete, ts_assign.stage_lookup("search").dir_path, 1) ts_assign.increment_stage_dir(checkpoint="search") ## - # STAGE 4: Run hmmalign, and optionally BMGE, to produce the MSAs for phylogenetic placement + # STAGE 4: Run hmmalign, and optionally trim, to produce the MSAs for phylogenetic placement ## - split_msa_files = ts_assign.align(refpkg_dict, homolog_seq_files, + split_msa_files = ts_assign.align(refpkg_dict=refpkg_dict, + pquery_groups=pquery_group_manifest, n_proc=n_proc, trim_align=args.trim_align, - min_seq_length=args.min_seq_length, - verbose=args.verbose) + min_seq_length=args.min_seq_length) delete_files(args.delete, ts_assign.stage_lookup("search").dir_path, 2) ts_assign.increment_stage_dir(checkpoint="align") diff --git a/treesapp/classy.py b/treesapp/classy.py index 61d05fa6..946707e7 100644 --- 
# treesapp/clipkit_helper.py
import sys
import time
import logging
import os.path
from typing import Optional

from clipkit import clipkit as ck
from clipkit import modes as ck_modes

from treesapp import logger
from treesapp import fasta
from treesapp import file_parsers
from treesapp import utilities


class ClipKitHelper:
    """
    Runs ClipKit on one multiple sequence alignment (MSA) file and keeps the counters
    needed to decide whether the trimmed alignment can be used in place of the original.

    Typical call order, as implied by the attributes each method reads and writes:
    run() -> compare_original_and_trimmed_multiple_alignments() -> validate_alignment_trimming()
    -> summarise_trimming() -> write_qc_trimmed_multiple_alignment() / get_qc_output_file().
    """
    # String values of all ClipKit trimming modes; used to validate the `mode` argument
    CLIPKIT_MODES = {member.value for member in ck_modes.TrimmingMode}

    def __init__(self, fasta_in: str, output_dir: str,
                 mode="smart-gap", gap_prop=0.95, min_len=None, for_placement=False):
        """
        :param fasta_in: Path to the MSA file to trim. The process exits (code 1) if it doesn't exist.
        :param output_dir: Directory for the trimmed outputs; created if it doesn't exist
        :param mode: String value of a ClipKit TrimmingMode. The process exits (code 1) if it isn't valid.
        :param gap_prop: Value passed to ClipKit as its `gaps` argument
        :param min_len: Minimum unaligned sequence length used by the quality-control filter.
         None is treated as "no minimum" in comparisons.
        :param for_placement: True when the MSA contains both reference and query sequences
        """
        self.logger = logging.getLogger(logger.logger_name())
        self.input = fasta_in
        if not os.path.isfile(self.input):
            self.logger.error("ClipKit input file '{}' doesn't exist.".format(self.input))
            sys.exit(1)

        if mode not in self.CLIPKIT_MODES:
            self.logger.error("'{}' is not a valid TrimmingMode.\n".format(mode))
            sys.exit(1)
        # makedirs (rather than mkdir) so nested output directories don't raise
        os.makedirs(output_dir, exist_ok=True)

        prefix, ext = os.path.splitext(os.path.basename(fasta_in))
        self.mfa_out = os.path.join(output_dir, prefix + ".trim" + ext)
        self.qc_mfa_out = os.path.join(output_dir, prefix + ".trim.qc" + ext)

        self.mode = ck_modes.TrimmingMode(mode)
        self.gap_prop = gap_prop

        self.ff_in = "fasta"
        self.ff_out = "fasta"
        self.refpkg_name = ''
        self.min_unaligned_seq_length = min_len

        # Attributes used in evaluating trimming performance
        self.success = True
        self.exec_time = 0
        self.num_msa_seqs = 0
        self.num_msa_cols = 0
        self.num_trim_seqs = 0
        self.num_trim_cols = 0
        self.num_qc_seqs = 0
        self.trim_qc_seqs = []  # These sequences passed the min_unaligned_seq_length filter

        # Specific to MSAs for phylogenetic placement
        self.placement = for_placement  # Boolean indicating MSA contained ref and query sequences
        self.num_queries_failed_trimming = 0
        self.num_refs_failed_trimming = 0
        self.num_queries_failed_qc = 0
        self.num_refs_failed_qc = 0
        self.num_queries_retained = 0
        self.num_refs_retained = 0
        return

    def __str__(self) -> str:
        return "ClipKitHelper instance for MSA '{}':\n" \
               "Mode = {}\n" \
               "Gap-proportion = {}\n" \
               "Placement = {}\n" \
               "Execution time = {}s\n" \
               "Success = {}\n".format(os.path.basename(self.input), self.mode,
                                       self.gap_prop, self.placement,
                                       round(self.exec_time, 3), self.success)

    def _min_length(self) -> int:
        """Return the minimum unaligned sequence length, treating an unset (None) threshold as zero."""
        if self.min_unaligned_seq_length is None:
            return 0
        return self.min_unaligned_seq_length

    def run(self, verbose=False, force=False) -> None:
        """
        Trim self.input with ClipKit, writing the result to self.mfa_out and recording the wall time.

        :param verbose: Log the text captured while ClipKit ran at the DEBUG level
        :param force: Re-run ClipKit even if the trimmed output file already exists
        """
        # Skip when the trimmed file is already present.
        # The path to test is mfa_out; ff_out only holds the output format name.
        if os.path.isfile(self.mfa_out) and not force:
            return
        start_time = time.time()

        # Capture all output from print statements within ClipKit
        with utilities.Capturing() as output:
            ck.execute(input_file=self.input,
                       input_file_format=self.ff_in,
                       output_file=self.mfa_out,
                       output_file_format=self.ff_out,
                       gaps=self.gap_prop,
                       complement=False,
                       mode=self.mode,
                       use_log=False)

        self.exec_time = time.time() - start_time

        if verbose:
            self.logger.debug('\n'.join(output))
        return

    def summarise_trimming(self):
        """Log the run time, the percentage of columns removed and the first applicable trimming warning."""
        min_len = self._min_length()
        self.logger.debug("Trimming required {}s".format(round(self.exec_time, 3)))
        # num_trim_cols is the width of the trimmed MSA, so the trimmed share is the difference.
        # Guard against an empty/unread input alignment (zero columns).
        if self.num_msa_cols > 0:
            pct_trimmed = round(100 * (self.num_msa_cols - self.num_trim_cols) / self.num_msa_cols, 2)
        else:
            pct_trimmed = 0.0
        self.logger.debug("Percentage of alignment trimmed = {}%".format(pct_trimmed))
        if self.num_trim_seqs == 0:
            self.logger.warning("No sequences were read from {}.\n".format(self.mfa_out))

        if self.num_trim_cols < min_len:
            # Warn if the final trimmed alignment is shorter than min_seq_length, and therefore empty
            self.logger.warning(
                "Multiple sequence alignment in {} is shorter than minimum sequence length threshold ({}).\n"
                "".format(self.mfa_out, self.min_unaligned_seq_length))
        elif self.num_refs_failed_trimming:
            # Adjacent literals (no '+') so that format() applies to the whole message
            self.logger.warning(
                "{} reference sequences in {} were removed during alignment trimming "
                "suggesting either truncated sequences or the initial reference alignment was terrible.\n"
                "".format(self.num_refs_failed_trimming, self.mfa_out))
        elif self.num_refs_failed_qc:
            self.logger.warning("{} reference sequences in {} were shorter than the minimum character length ({})"
                                " and removed after alignment trimming.\n"
                                "".format(self.num_refs_failed_qc, self.mfa_out, self.min_unaligned_seq_length))

        # Ensure that there is at least 1 query sequence retained after trimming the multiple alignment
        elif self.num_queries_retained == 0:
            self.logger.warning("No query sequences in {} were retained after trimming.\n".format(self.mfa_out))

        if self.success is False:
            self.logger.debug("The untrimmed MSA will be used instead.\n")
        return

    def read_trimmed_msa(self) -> fasta.FASTA:
        """
        Load the trimmed MSA (self.mfa_out) into a FASTA instance according to self.ff_out.
        The process exits (code 1) for output formats other than "phylip" and "fasta".
        """
        msa_records = fasta.FASTA(file_name=self.mfa_out)
        if self.ff_out == "phylip":
            msa_records.fasta_dict = file_parsers.read_phylip_to_dict(self.mfa_out)
        elif self.ff_out == "fasta":
            msa_records.fasta_dict = fasta.read_fasta_to_dict(self.mfa_out)
        else:
            self.logger.error("Unsupported file format ('{}') of {}.\n".format(self.ff_out, self.mfa_out))
            sys.exit(1)

        msa_records.header_registry = fasta.register_headers(list(msa_records.fasta_dict.keys()),
                                                             drop=True)
        return msa_records

    def quantify_refs_and_pqueries(self, unique_ref_headers: set, msa_fasta: fasta.FASTA = None):
        """
        Count the query and reference sequences in the trimmed MSA that passed or failed the
        length filter (i.e. are or are not in self.trim_qc_seqs). Counters are incremented in place.

        :param unique_ref_headers: Names of the reference sequences. Nothing is counted if this is empty.
        :param msa_fasta: The trimmed MSA; read from self.mfa_out when not provided
        :raises RuntimeError: For a sequence that is neither a query nor a known reference
        """
        if not unique_ref_headers:
            return

        if not msa_fasta:
            msa_fasta = self.read_trimmed_msa()

        # Build the set once so each membership test below is O(1)
        passed_qc = set(self.trim_qc_seqs)
        for seq_name in msa_fasta.fasta_dict:
            if seq_name[0] == '-':  # The negative integers indicate this is a query sequence
                if seq_name in passed_qc:
                    self.num_queries_retained += 1
                else:
                    self.num_queries_failed_qc += 1
            elif seq_name in unique_ref_headers:
                if seq_name in passed_qc:
                    self.num_refs_retained += 1
                else:
                    self.num_refs_failed_qc += 1
            else:
                raise RuntimeError("Unsure what to do with sequence '{}'.\n".format(seq_name))
        return

    def validate_alignment_trimming(self):
        """Set self.success to False if any of the trimmed-alignment sanity checks fail."""
        if self.num_trim_seqs == 0:
            self.success = False

        if self.num_trim_cols < self._min_length():
            self.success = False

        if self.num_trim_cols > self.num_msa_cols:
            self.logger.warning("MSA length increased after trimming {}\n".format(self.input))
            self.success = False

        if self.placement:
            self.validate_placement_trimming()
        elif self.num_trim_seqs != self.num_msa_seqs:
            self.success = False
        elif self.num_qc_seqs != self.num_msa_seqs:
            self.success = False
        return

    def validate_placement_trimming(self) -> None:
        """Placement MSAs fail if no query was retained or any reference was lost to trimming or QC."""
        if self.num_queries_retained == 0:
            self.success = False
        if self.num_refs_failed_trimming > 0:
            self.success = False
        if self.num_refs_failed_qc > 0:
            self.success = False
        return

    def quality_control_trimmed_seqs(self) -> None:
        """
        Quality control trimmed sequences according to their unaligned lengths.
        Names of those passing this filter are stored in the list self.trim_qc_seqs.
        """
        min_len = self._min_length()
        msa_fasta = self.read_trimmed_msa()
        msa_fasta.unalign()
        # Rebuild the list instead of appending so repeated calls don't double-count
        self.trim_qc_seqs = [seq_name for seq_name, seq in msa_fasta.fasta_dict.items()
                             if len(seq) >= min_len]
        self.num_qc_seqs = len(self.trim_qc_seqs)
        return

    def compare_original_and_trimmed_multiple_alignments(self, ref_pkg=None):
        """
        Summarises the number of character positions trimmed and new dimensions between the input and output MSA.

        :param ref_pkg: Object providing get_fasta().get_seq_names() for the reference sequence names.
         Only used, and therefore required, when self.placement is True.
        """
        self.num_trim_seqs, self.num_trim_cols = fasta.multiple_alignment_dimensions(self.mfa_out)
        self.num_msa_seqs, self.num_msa_cols = fasta.multiple_alignment_dimensions(self.input)

        self.quality_control_trimmed_seqs()

        if self.placement:
            # Create a set of the reference sequence names
            unique_ref_headers = set(ref_pkg.get_fasta().get_seq_names())
            self.quantify_refs_and_pqueries(unique_ref_headers)

        return

    def get_qc_output_file(self) -> str:
        """Return the path of the QC'd trimmed MSA on success, otherwise the path of the untrimmed input."""
        if self.success:
            return self.qc_mfa_out
        else:
            return self.input

    def get_qc_trimmed_fasta(self) -> Optional[fasta.FASTA]:
        """Return the trimmed MSA restricted to sequences that passed QC, or None if trimming failed."""
        if not self.success:
            return None

        msa_fasta = self.read_trimmed_msa()
        msa_fasta.keep_only(header_subset=self.trim_qc_seqs)
        return msa_fasta

    def write_qc_trimmed_multiple_alignment(self) -> None:
        """Write the QC'd trimmed MSA to self.qc_mfa_out; nothing is written if trimming failed."""
        msa_fasta = self.get_qc_trimmed_fasta()
        if not msa_fasta:
            return
        fasta.write_new_fasta(fasta_dict=msa_fasta.fasta_dict,
                              fasta_name=self.qc_mfa_out)
        return
info(sys_args): # Write the version of all python deps py_deps = {"biopython": Bio.__version__, + "clipkit": ck_version, "ete3": ete3.__version__, "joblib": joblib.__version__, "numpy": numpy.__version__, @@ -634,33 +637,21 @@ def create(sys_args): n_threads=args.num_threads, intermediates_dir=ts_create.var_output_dir) ## - # Optionally trim with BMGE and create the Phylip multiple alignment file + # Optionally trim the multiple sequence alignment create Phylip files ## dict_for_phy = dict() if args.trim_align: - trimmed_mfa_files = wrapper.filter_multiple_alignments(ts_create.executables, - {ts_create.ref_pkg.refpkg_code: - [ts_create.ref_pkg.f__msa]}, - {ts_create.ref_pkg.refpkg_code: - ts_create.ref_pkg}) - trimmed_mfa_file = trimmed_mfa_files[ts_create.ref_pkg.refpkg_code] - unique_ref_headers = set(ref_seqs.fasta_dict.keys()) - qc_ma_dict, failed_trimmed_msa, summary_str = file_parsers.validate_alignment_trimming(trimmed_mfa_file, - unique_ref_headers) - LOGGER.debug("Number of sequences discarded: " + summary_str + "\n") - if len(qc_ma_dict.keys()) == 0: + trimmer = multiple_alignment.trim_multiple_alignment_clipkit(msa_file=ts_create.ref_pkg.f__msa, + ref_pkg=ts_create.ref_pkg, + min_seq_length=args.min_seq_length, + for_placement=False) + trimmer.summarise_trimming() + if not trimmer.success: # At least one of the reference sequences were discarded and therefore this package is invalid. LOGGER.error("Trimming removed reference sequences. 
This could indicate non-homologous sequences.\n" + "Please improve sequence quality-control and/or rerun without the '--trim_align' flag.\n") sys.exit(13) - elif len(qc_ma_dict.keys()) > 1: - LOGGER.error("Multiple trimmed alignment files are found when only one is expected:\n" + - "\n".join([str(k) + ": " + str(qc_ma_dict[k]) for k in qc_ma_dict])) - sys.exit(13) - # NOTE: only a single trimmed-MSA file in the dictionary - for trimmed_msa_file in qc_ma_dict: - dict_for_phy = qc_ma_dict[trimmed_msa_file] - os.remove(trimmed_msa_file) + dict_for_phy.update(trimmer.get_qc_trimmed_fasta().fasta_dict) else: dict_for_phy.update(ref_seqs.fasta_dict) diff --git a/treesapp/extensions/tree_parsermodule.cpp b/treesapp/extensions/tree_parsermodule.cpp deleted file mode 100644 index 0e0fbe23..00000000 --- a/treesapp/extensions/tree_parsermodule.cpp +++ /dev/null @@ -1,814 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -using namespace std; - -static PyObject *read_the_reference_tree(PyObject *self, PyObject *args); -static PyObject *get_parents_and_children(PyObject *self, PyObject *args); -static PyObject *build_subtrees_newick(PyObject *self, PyObject *args); -static PyObject *lowest_common_ancestor(PyObject *self, PyObject *args); -char *get_node_relationships(char *tree_string); -char *split_tree_string(char *tree_string); - -static char read_the_reference_tree_docstring[] = - "Reads the labelled_tree_file and reformats it for downstream interpretation"; -static char get_parents_and_children_docstring[] = - "Stores the input tree as a binary search tree before recursively finding the children and parent of each node"; -static char build_subtrees_newick_docstring[] = - "Reads the labelled, rooted tree and returns all subtrees in the tree"; -static char lowest_common_ancestor_docstring[] = - "Calculate lowest common ancestor for a set of nodes in a tree"; - -//static PyMethodDef module_methods[] = { -// {"error_out", (PyCFunction)error_out, 
METH_NOARGS, NULL}, -// {NULL, NULL} -//}; - -static PyMethodDef module_methods[] = { - {"_read_the_reference_tree", - read_the_reference_tree, - METH_VARARGS, - read_the_reference_tree_docstring}, - {"_get_parents_and_children", - get_parents_and_children, - METH_VARARGS, - get_parents_and_children_docstring}, - {"_build_subtrees_newick", - build_subtrees_newick, - METH_VARARGS, - build_subtrees_newick_docstring}, - {"_lowest_common_ancestor", - lowest_common_ancestor, - METH_VARARGS, - lowest_common_ancestor_docstring}, - {NULL, NULL, 0, NULL} -}; - -struct module_state { - PyObject *error; -}; - -#if PY_MAJOR_VERSION >= 3 -#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) -#else -#define GETSTATE(m) (&_state) -static struct module_state _state; -#endif - - -/** - A tree node - */ -struct Link { - long key; - long left; - long right; - Link* next; - Link* previous; -}; - -struct TreeNode { - long key; - TreeNode* left; - TreeNode* right; -}; - -struct CharLink { - char* subtree; - CharLink* next; -}; - - -/** - * Inserts a new Link (with key=newKey) at the head of the linked_list. - */ -void prepend_link(Link*& head, long newKey) { - Link * curr = new Link; - curr->key = newKey; - curr->next = head; - head = curr; -} - -/** - * Recursively searches a subtree for a key. 
- */ -bool find(long query, TreeNode *& r) { - if (r == NULL) return r; - if (query == r->key) - return true; - - bool lb = find(query, r->left); - bool rb = find(query, r->right); - if (lb || rb ) - return true; - else - return false; -} - - -TreeNode* create_node(long key, TreeNode* l = NULL, TreeNode* r = NULL ) { - TreeNode* curr = new TreeNode; - curr->key = key; - curr->left = l; - curr->right = r; - return curr; -} - - -void add_subtree(CharLink*& head, char*& new_subtree) { - CharLink * curr = new CharLink; - curr->subtree = new_subtree; - curr->next = head; - head = curr; -} - - -void deleteSubtreeList(CharLink*& head) { - if ( head != NULL ) { - deleteSubtreeList( head->next ); - free(head->subtree); - delete head; - head = NULL; - } -} - - -void deleteList(Link*& head) { - if ( head != NULL ) { - deleteList( head->next ); - delete head; - head = NULL; - } -} - - -/** - * Deletes all nodes in the tree rooted at root and sets root to NULL. - */ -void deleteTree( TreeNode*& root ) { - if ( root != NULL ) { - deleteTree( root->left ); - deleteTree( root->right ); - delete root; - root = NULL; - } -} - - -void print_list(Link* head) { - std::cout << std::endl; - for (Link* curr = head; curr != NULL; curr = curr->next){ - printf("key: %ld\tleft: %ld\tright: %ld", - curr->key, curr->left, curr->right); - if (curr->next != NULL) std::cout << "\n"; - } - std::cout << std::endl; -} - - -/** - * Prints out the tree sideways. 
- */ -void print_tree( TreeNode* root, int d = 0 ) { - if ( root == NULL ) return; - print_tree( root->right, d+1 ); - std::cout << std::setw( 3 * d ) << ""; // output 3 * d spaces - std::cout << root->key << std::endl; - print_tree( root->left, d+1 ); -} - -char const *float_chars = "0.123456789"; -char const *real_number_chars = "-0123456789"; - -/* - Returns 1 if character sub is a substring of super, 0 otherwise - */ -int is_char_substr(char sub, const char * super) { - int x = 0; - while(super[x]) { - if (super[x] == sub) - return 1; - x++; - } - return 0; -} - -/* - Returns the length of the char_array as int - */ -int get_char_array_length(char * char_array) { - int x = 0; - while (char_array[x]) - x++; - return x; -} - -/* - * Append a character array (source) onto another character array (dest) - * start is the position to continue appending on dest - */ -int append_char_array(int start, char * source, char *&dest) { - int i = 0; - while (source[i]) - dest[start++] = source[i++]; - return start; -} - -char reverse_char_array(char * char_array, char *&flipped, int first, int last) { - if (last == -1) - return '\0'; - flipped[first] = reverse_char_array(char_array, flipped, first+1, last-1); - return char_array[last]; -} - -/* - * param comma_separated_string: a character array with commas - * return: a vector of character arrays - */ -std::vector csv_to_list(char * comma_separated_string) { - int i = 0; - int k = 0; - std::vector char_list; - char c_array[10]; - - while (comma_separated_string[i]) { - if (comma_separated_string[i] == ',') { - c_array[k] = '\0'; - char_list.push_back(string (c_array)); - i++; k = 0; - c_array[k] = '\0'; - } - c_array[k] = comma_separated_string[i]; - i++; k++; - } - c_array[k] = '\0'; - char_list.push_back(string (c_array)); - return char_list; -} - - -static PyObject *read_the_reference_tree(PyObject *self, PyObject *args) { - char* reference_tree_file; - if (!PyArg_ParseTuple(args, "s", &reference_tree_file)) { - return 
NULL; - } - - FILE *reference_tree = fopen(reference_tree_file, "r"); - if (reference_tree == 0) { - printf ("The reference tree file %s could not be opened for reading!\n", reference_tree_file); - exit(0); - } - - int _MAX = 1000; - int tree_len = 0; - int count = 2; - int x = 0; - char *tree_string = (char *) malloc ( _MAX * sizeof(char)); - char count_char[10]; - char c = fgetc(reference_tree); - while ( c != EOF ) { - if (tree_len <= _MAX && tree_len >= _MAX - 100) { - _MAX = _MAX + 1000; - tree_string = (char *) realloc(tree_string, _MAX * sizeof(char)); - } - if (c == ')') { - tree_string[tree_len] = c; - tree_len++; - tree_string[tree_len] = '-'; - tree_len++; - - sprintf(count_char, "%d", count); - count++; - x = 0; - while (count_char[x]) { - c = count_char[x]; - tree_string[tree_len] = c; - tree_len++; - x++; - } - c = fgetc(reference_tree); - } - if (c == ':') { - c = fgetc(reference_tree); - // while c is a substring of float_chars, continue reading in characters - while (is_char_substr(c, float_chars) == 1) - c = fgetc(reference_tree); - } - else { - tree_string[tree_len] = c; - tree_len++; - c = fgetc(reference_tree); - } - - } - fclose(reference_tree); - // Now remove the last node - while ( c != ')') { - tree_string[tree_len + 1] = '\0'; - tree_string[tree_len] = ';'; - tree_len--; - c = tree_string[tree_len]; - } - - return Py_BuildValue("s", tree_string); -}; - - -void get_previous_node(char *&parsed_tree_string, int &end, char *&previous) { - char reversed[10]; - char c = parsed_tree_string[end]; - int i = 0; - // Skip through brackets and commas to the end of the previous node - while (is_char_substr(c, real_number_chars) == 0) { - parsed_tree_string[end--] = '\0'; - c = parsed_tree_string[end]; - } - while (is_char_substr(c, real_number_chars) == 1) { - reversed[i] = c; - parsed_tree_string[end--] = '\0'; - c = parsed_tree_string[end]; - i++; - } - reversed[i] = '\0'; - reverse_char_array(reversed, previous, 0, i); - return; -} - - -void 
load_linked_list(char * tree_string, Link *&head) { - char c; - int pos = 0; - int i = 0; - int newKey = -1; - int retrace_pos = 0; - int x; - char curr[10]; - char* right = (char*) malloc(20); - char* left = (char*) malloc(20); - - int tree_len = get_char_array_length(tree_string); - char* parsed_tree_string = (char*) malloc(tree_len); - for (x = 0; x < tree_len; x++) - parsed_tree_string[x] = '\0'; - - while (tree_string[pos]) { - c = tree_string[pos]; - parsed_tree_string[retrace_pos++] = c; - if (c == ')') { - // load the next node as curr - c = tree_string[pos+1]; - i = 0; - // Overwrite curr - for (x = 0; x < 10; x++) - curr[x] = '\0'; - while (is_char_substr(c, real_number_chars) == 1) { - curr[i++] = c; - pos++; - c = tree_string[pos+1]; - } - curr[i] = '\0'; - newKey = atoi(curr); - if (newKey == 0) - newKey = -1; - prepend_link(head, newKey); - - // load the previous 2 nodes as children and remove these from the string - get_previous_node(parsed_tree_string, retrace_pos, right); - head->right = atoi(right); - for (x = 0; x < 10; x++) - right[x] = '\0'; - get_previous_node(parsed_tree_string, retrace_pos, left); - head->left = atoi(left); - for (x = 0; x < 10; x++) - left[x] = '\0'; - - // add the current node to the parsed_tree_string - i = 0; - while (curr[i]) - parsed_tree_string[retrace_pos++] = curr[i++]; - } - pos++; - } - free(right); - free(left); - free(parsed_tree_string); -} - - -TreeNode* load_tree_from_list(Link* head, TreeNode*& root, std::stack& merge) { - TreeNode* previous = NULL; - if (head == NULL) { - return previous; - } - previous = load_tree_from_list(head->next, root, merge); - root = create_node(head->key); - - if (previous) { - // If a key is equal to the previous key, then the previous key is a child - if (head->left == previous->key) - root->left = previous; - if (head->right == previous->key) - root->right = previous; - // If neither children are from the previous node, then create new nodes - if (head->left != previous->key) - 
root->left = create_node(head->left); - if (head->right != previous->key) - root->right = create_node(head->right); - // If a child is equal to a node from a long time ago, in a galaxy far far away... - if (!merge.empty() && head->left == merge.top()->key) { - root->left = merge.top(); - merge.pop(); - } - if (!merge.empty() && head->right == merge.top()->key) { - root->right = merge.top(); - merge.pop(); - } - // If neither the left or right are equal to the previous key, store it as merge - if (head->right != previous->key && head->left != previous->key) { - merge.push(previous); - } - } - else { - root->right = create_node(head->right); - root->left = create_node(head->left); - } - return root; -} - - -int get_children_of_nodes(Link * head, char *&children) { - char * buffer = (char*) malloc (20); - int _MAX = 1000; - children = (char *) malloc (_MAX * sizeof(char)); - int x = 0; - for (Link* curr = head; curr != NULL; curr = curr->next){ - if (x <= _MAX && x >= _MAX - 100) { - _MAX = _MAX + 1000; - children = (char *) realloc (children, _MAX * sizeof(char)); - } - sprintf(buffer, "%ld", curr->key); - x = append_char_array(x, buffer, children); - children[x++] = '='; - sprintf(buffer, "%ld", curr->left); - x = append_char_array(x, buffer, children); - children[x++] = ','; - sprintf(buffer, "%ld", curr->right); - x = append_char_array(x, buffer, children); - - if (curr->next != NULL) children[x++] = ';'; - } - children[x++] = '\n'; - children[x] = '\0'; - - free(buffer); - - return x; -} - - -int get_parents_of_nodes(Link * head, char *&parents) { - char parent_string[100]; - int _MAX = 1000; - parents = (char *) malloc (_MAX * sizeof(char)); - int x; - for (x = 0; x < _MAX; x++) - parents[x] = '\0'; - for (x = 0; x < 100; x++) - parent_string[x] = '\0'; - x = 0; - for (Link* curr = head; curr != NULL; curr = curr->next){ - if (x <= _MAX && x >= _MAX - 100) { - _MAX = _MAX + 1000; - parents = (char *) realloc (parents, _MAX * sizeof(char)); - } - 
sprintf(parent_string, "%ld:%ld,%ld:%ld", curr->left, curr->key, curr->right, curr->key); - - x = append_char_array(x, parent_string, parents); - - if (curr->next != NULL) parents[x++] = ','; - else parents[x++] = '\n'; - } - parents[x] = '\0'; - - return x; -} - - -/* - Find all of the subtrees in the tree - */ -void get_subtree_of_node(TreeNode* root, CharLink*& head) { - char* buffer; - // Check to see if it is an internal node (key < 0) or a leaf - if (root->left == NULL && root->right == NULL) { - buffer = (char*) malloc(100); - for (int x = 0; x < 100; x++) - buffer[x] = '\0'; - sprintf(buffer, "%ld", root->key); - add_subtree(head, buffer); - return; - } - get_subtree_of_node(root->right, head); - CharLink* right_link = head; - get_subtree_of_node(root->left, head); - CharLink* left_link = head; - - // Join the last two subtrees - int sum_length = get_char_array_length(right_link->subtree) + get_char_array_length(left_link->subtree) + 2; - buffer = (char*) malloc(sum_length); - for (int x = 0; x < sum_length; x++) - buffer[x] = '\0'; - sprintf(buffer, "%s %s", right_link->subtree, left_link->subtree); - add_subtree(head, buffer); - return; -} - -/* - Find all of the subtrees in the tree - */ -void get_newick_subtrees(TreeNode* root, CharLink*& head) { - char* buffer; - // Check to see if it is an internal node (key < 0) or a leaf - if (root->left == NULL && root->right == NULL) { - buffer = (char*) malloc(100); - for (int x = 0; x < 100; x++) - buffer[x] = '\0'; - sprintf(buffer, "%ld", root->key); - add_subtree(head, buffer); - return; - } - get_newick_subtrees(root->right, head); - CharLink* right_link = head; - get_newick_subtrees(root->left, head); - CharLink* left_link = head; - - // Join the last two subtrees - int sum_length = get_char_array_length(right_link->subtree) + get_char_array_length(left_link->subtree) + 20; - buffer = (char*) malloc(sum_length); - for (int x = 0; x < sum_length; x++) - buffer[x] = '\0'; - sprintf(buffer, "(%s,%s)%ld", 
right_link->subtree, left_link->subtree, root->key); - add_subtree(head, buffer); - return; -} - - -void get_subtree_of_node_helper(TreeNode* root, char *&subtrees, int &len_subtrees, const char delim) { - CharLink * head = NULL; - if (delim == ',') - get_subtree_of_node(root, head); - else - get_newick_subtrees(root, head); - int _MAX = 10000; - subtrees = (char*) malloc(_MAX); - // Parse the CharLink linked-list - while (head) { - if (len_subtrees <= _MAX && len_subtrees >= _MAX - 5000 ) { - _MAX = _MAX + 10000; - subtrees = (char *) realloc (subtrees, _MAX * sizeof(char)); - } - len_subtrees = append_char_array(len_subtrees, head->subtree, subtrees); - if (head->next) subtrees[len_subtrees++] = delim; - head = head->next; - } - subtrees[len_subtrees] = '\0'; - deleteSubtreeList(head); -} - - -int get_node_relationships(char *tree_string, char *&children, char *&parents, char *&subtrees) { - /* - :param tree_string: tree_info['subtree_of_node'][node] - Function loads the whole tree into a tree struct where each node has a child and a parent - Then the tree is queried for its ONE PARENT and potentially MULTIPLE CHILDREN - */ - - // Step 1: Load the tree - Link * linked_list = NULL; - load_linked_list(tree_string, linked_list); -// print_list(linked_list); - - // Step 2: Traverse the linked-list to get parents and children strings for each node - int len_children = get_children_of_nodes(linked_list, children); -// printf("Children:\n%s\n", children); - int len_parents = get_parents_of_nodes(linked_list, parents); -// printf("Parents:\n%s\n", parents); - - // Step 3: Convert the linked-list to a tree structure - TreeNode* root = NULL; - std::stack merge; - load_tree_from_list(linked_list, root, merge); - if (!merge.empty()) { - std::cerr << "ERROR: Stack not empty after merging subtrees!" 
<< std::endl; - print_list(linked_list); - while (!merge.empty()) { - cout << "Not popped: " << merge.top()->key << endl; - merge.pop(); - } - cout << tree_string << endl; - return 0; - } - -// print_tree(root); - - // Step 4: Traverse the tree to get all subtrees - int len_subtrees = 0; - get_subtree_of_node_helper(root, subtrees, len_subtrees, ','); - - //Step 5: Clean up the tree and linked list - deleteTree(root); - deleteList(linked_list); - - return len_children + len_parents + len_subtrees; -} - - -static PyObject *get_parents_and_children(PyObject *self, PyObject *args) { - char* tree_string; - if (!PyArg_ParseTuple(args, "s", &tree_string)) { - return NULL; - } - - char* children; - char* parents; - char* subtrees; - - int length = get_node_relationships(tree_string, children, parents, subtrees); - if (length == 0) - return Py_BuildValue("s", "$"); - - children = (char *) realloc(children, (length + 3)); - - int c_pos = 0; - while (children[c_pos]){ - c_pos++; - } - - int p_pos = 0; - while (parents[p_pos]) - children[c_pos++] = parents[p_pos++]; - - int t_pos = 0; - while (subtrees[t_pos]) - children[c_pos++] = subtrees[t_pos++]; - children[c_pos] = '\0'; - - free(parents); - free(subtrees); - - return Py_BuildValue("s", children); -} - - -TreeNode* lca_helper(TreeNode* root, std::vector node_names, int& acc, long& ancestor) { - if (root == NULL) { - return root; - } - int x = node_names.size();; - int n_contained = 0; - - lca_helper(root->left, node_names, acc, ancestor); - lca_helper(root->right, node_names, acc, ancestor); - - for (int i = 0; i < x; i++) { - long query = atol(node_names[i].c_str()); - // Search through root's subtree for the key: query - if (find(query, root)) - n_contained++; - } - // If n_contained == x, stop accumulating ancestor - if (n_contained == x && ancestor == 0) { - ancestor = acc; - } - acc++; - return root; -} - - -static PyObject *lowest_common_ancestor(PyObject *self, PyObject *args) { - char* tree_string; - char* 
leaves_strung; - long ancestor = 0; - int acc = 0; - - if (!PyArg_ParseTuple(args, "ss", &tree_string, &leaves_strung)) { - return NULL; - } - - // Get the node numbers to find LCA - std::vector leaves = csv_to_list(leaves_strung); - - // Step 1: Load the tree - Link * linked_list = NULL; - load_linked_list(tree_string, linked_list); - // Step 2: Convert the linked-list to a tree structure - TreeNode* root = NULL; - std::stack merge; - load_tree_from_list(linked_list, root, merge); - // Step 3: lca will return the root node of the LCA node for which all leaves are children - lca_helper(root, leaves, acc, ancestor); - - leaves.clear(); - return Py_BuildValue("i", ancestor); - -} - - -static PyObject *build_subtrees_newick(PyObject *self, PyObject *args) { - /* - Function to parse the rooted, assigned tree and find all subtrees of the inserted node - Algorithm: - 1. Load the tree (load_linked_list and load_tree_from_list) - 2. Recursively build the subtrees from leaves to root in Newick format and load into list (get_newick_subtrees) - 3. Parse the linked list for each subtree, separating them by semicolons - 4. Return string to Python - */ - char* tree_string; - if (!PyArg_ParseTuple(args, "s", &tree_string)) { - return NULL; - } - Link * linked_list = NULL; - load_linked_list(tree_string, linked_list); - - TreeNode* root = NULL; - std::stack merge; - load_tree_from_list(linked_list, root, merge); - if (!merge.empty()) { - std::cerr << "ERROR: Stack not empty after merging subtrees!" 
<< std::endl; - print_list(linked_list); - return 0; - } - char* subtrees; - int len_subtrees = 0; - - get_subtree_of_node_helper(root, subtrees, len_subtrees, ';'); - - deleteTree(root); - deleteList(linked_list); - - return Py_BuildValue("s", subtrees); -} - - -#if PY_MAJOR_VERSION >= 3 - -static int module_traverse(PyObject *m, visitproc visit, void *arg) { - Py_VISIT(GETSTATE(m)->error); - return 0; -} - -static int module_clear(PyObject *m) { - Py_CLEAR(GETSTATE(m)->error); - return 0; -} - -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - "_tree_parser", /* m_name */ - "This module provides an interface for parsing Newick formatted trees using C from within TreeSAPP", /* m_doc */ - sizeof(struct module_state), /* m_size */ - module_methods, /* m_methods */ - NULL, /* m_reload */ - module_traverse, /* m_traverse */ - module_clear, /* m_clear */ - NULL, /* m_free */ -}; - -#define INITERROR return NULL - -PyMODINIT_FUNC PyInit__tree_parser(void) - -#else -#define INITERROR return - -PyMODINIT_FUNC -init_tree_parser(void) -#endif -{ -#if PY_MAJOR_VERSION >= 3 - PyObject *m = PyModule_Create(&module_def); -#else - static char module_docstring[] = - "This module provides an interface for parsing Newick formatted trees using C from within TreeSAPP"; - PyObject *m = Py_InitModule3("_tree_parser", module_methods, module_docstring); -#endif - - if (m == NULL) - INITERROR; - struct module_state *st = GETSTATE(m); - - st->error = PyErr_NewException("_tree_parser.Error", NULL, NULL); - if (st->error == NULL) { - Py_DECREF(m); - INITERROR; - } - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} \ No newline at end of file diff --git a/treesapp/external_command_interface.py b/treesapp/external_command_interface.py index 1b3f11f5..70bc5b5b 100644 --- a/treesapp/external_command_interface.py +++ b/treesapp/external_command_interface.py @@ -43,7 +43,7 @@ def launch_write_command(cmd_list, collect_all=True): return stdout, proc.returncode -def 
run_apply_async_multiprocessing(func, arguments_list: list, num_processes: int, pbar_desc: str, +def run_apply_async_multiprocessing(func, arguments_list, num_processes: int, pbar_desc: str, disable=False) -> list: if len(arguments_list) == 0: return [] @@ -57,7 +57,10 @@ def update(*a): pbar.update() for args in arguments_list: - jobs.append(pool.apply_async(func=func, args=(*args,), callback=update)) + if isinstance(args, list): + jobs.append(pool.apply_async(func=func, args=(*args,), callback=update)) + elif isinstance(args, dict): + jobs.append(pool.apply_async(func=func, kwds=args, callback=update)) pool.close() for job in pbar: diff --git a/treesapp/file_parsers.py b/treesapp/file_parsers.py index e9ee3f96..ae5fb9a4 100644 --- a/treesapp/file_parsers.py +++ b/treesapp/file_parsers.py @@ -545,123 +545,6 @@ def read_stockholm_to_dict(sto_file): return seq_dict -def check_seq_name_integer_compatibility(seq_dict: dict) -> (dict, int): - # Parse the MSA dict and ensure headers are integer-compatible - multi_align = {} - n_msa_refs = 0 - for seq_name, seq in seq_dict.items(): - try: - if int(seq_name) > 0: - n_msa_refs += 1 - except ValueError: - if re.match(r"^_\d+", seq_name): - leaf_num = re.sub("^_", '-', seq_name) - # The section of regular expresion after '_' needs to match denominator and refpkg names - elif re.match(r"^\d+_\w{2,10}$", seq_name): - leaf_num = seq_name.split('_')[0] - else: - return {seq_name: ""}, -1 - if int(leaf_num) > 0: - n_msa_refs += 1 - multi_align[seq_name] = seq - return multi_align, n_msa_refs - - -def validate_alignment_trimming(msa_files: list, unique_ref_headers: set, - queries_mapped=False, min_seq_length=30) -> (dict, list, str): - """ - Parse a list of multiple sequence alignment (MSA) files and determine whether the multiple alignment: - 1. is shorter than the min_seq_length (30 by default) - 2. 
is missing any reference sequences - The number of query sequences discarded - these may have been added by hmmalign or PaPaRa - is returned via a string - - NOTE: Initially designed for sequence records with numeric names (e.g. >488) but accommodates other TreeSAPP formats - - :param msa_files: A list of either Phylip or FASTA formatted MSA files - :param unique_ref_headers: A set of all headers that were in the untrimmed MSA - :param queries_mapped: Boolean indicating whether sequences should be present in addition to reference sequences. - While query sequences _could_ be identified as any that are not in unique_ref_headers, - queries have names that are negative integers for more rapid and scalable identification - :param min_seq_length: Optional minimum unaligned (no '-'s) length a sequence must exceed to be retained - :return: 1. Dictionary indexed by MSA file name mapping to FASTA-dictionaries - 2. A string mapping the number of query sequences removed from each MSA file - 3. A string describing the number of sequences discarded - """ - discarded_seqs_string = "" - successful_multiple_alignments = dict() - failed_multiple_alignments = list() - n_refs = len(unique_ref_headers) - for multi_align_file in msa_files: - filtered_multi_align = dict() - discarded_seqs = list() - num_queries_retained = 0 - n_retained_refs = 0 - f_ext = multi_align_file.split('.')[-1] - - # Read the multiple alignment file - if re.search("phy", f_ext): # File is in Phylip format - seq_dict = read_phylip_to_dict(multi_align_file) - elif re.match("^f", f_ext): # This is meant to match all fasta extensions - seq_dict = fasta.read_fasta_to_dict(multi_align_file) - elif f_ext == "mfa": # This is meant to match a multiple alignment in FASTA format - seq_dict = fasta.read_fasta_to_dict(multi_align_file) - else: - LOGGER.error("Unable to detect file format of " + multi_align_file + ".\n") - sys.exit(13) - - multi_align, n_msa_refs = check_seq_name_integer_compatibility(seq_dict) - if 
n_msa_refs < 0: - LOGGER.error("Unexpected sequence name ('{}') detected in {}.\n" - "".format(multi_align.popitem()[0], multi_align_file)) - sys.exit(13) - if len(multi_align) == 0: - LOGGER.warning("No sequences were read from {}. " - "The untrimmed alignment will be used instead.\n".format(multi_align_file)) - failed_multiple_alignments.append(multi_align_file) - continue - # The numeric identifiers make it easy to maintain order in the Phylip file by a numerical sort - for seq_name in sorted(multi_align, key=lambda x: int(x.split('_')[0])): - seq_dummy = re.sub('-', '', multi_align[seq_name]) - if len(seq_dummy) < min_seq_length: - discarded_seqs.append(seq_name) - else: - filtered_multi_align[seq_name] = multi_align[seq_name] - # The negative integers indicate this is a query sequence - if seq_name[0] == '-': - num_queries_retained += 1 - else: - n_retained_refs += 1 - discarded_seqs_string += "\n\t\t" + multi_align_file + " = " + str(len(discarded_seqs)) - if len(discarded_seqs) == len(multi_align.keys()): - # Throw an error if the final trimmed alignment is shorter than min_seq_length, and therefore empty - LOGGER.warning("Multiple sequence alignment in {} is shorter than minimum sequence length threshold ({})." 
- "\nThe untrimmed MSA will be used instead.\n".format(multi_align_file, min_seq_length)) - failed_multiple_alignments.append(multi_align_file) - elif n_refs > n_msa_refs: - # Testing whether there were more sequences in the untrimmed alignment than the trimmed one - LOGGER.warning("Reference sequences in " + multi_align_file + " were removed during alignment trimming " + - "suggesting either truncated sequences or the initial reference alignment was terrible.\n" + - "The untrimmed alignment will be used instead.\n") - failed_multiple_alignments.append(multi_align_file) - elif n_refs > n_retained_refs: - LOGGER.warning("Reference sequences shorter than the minimum character length ({})" - " in {} were removed after alignment trimming.\n".format(min_seq_length, multi_align_file) + - "The untrimmed alignment will be used instead.\n") - failed_multiple_alignments.append(multi_align_file) - # Ensure that there is at least 1 query sequence retained after trimming the multiple alignment - elif queries_mapped and num_queries_retained == 0: - LOGGER.warning("No query sequences in " + multi_align_file + " were retained after trimming.\n") - else: - successful_multiple_alignments[multi_align_file] = filtered_multi_align - - if multi_align_file in successful_multiple_alignments: - discarded_seqs_string += " (retained)" - else: - discarded_seqs_string += " (removed)" - - return successful_multiple_alignments, failed_multiple_alignments, discarded_seqs_string - - def read_annotation_mapping_file(annot_map_file: str) -> dict: """ Used for reading a file mapping the reference package name to all true positive orthologs in the query input diff --git a/treesapp/hmmer_tbl_parser.py b/treesapp/hmmer_tbl_parser.py index ea745cd9..4d8908ac 100755 --- a/treesapp/hmmer_tbl_parser.py +++ b/treesapp/hmmer_tbl_parser.py @@ -71,6 +71,7 @@ def __init__(self): self.eval = 0.0 # Full-sequence E-value (in the case a sequence alignment is split) self.full_score = 0 self.next_domain = None # The 
next domain aligned by hmmsearch + self.aln_pident = 0.0 def get_info(self): info_string = "Info for query " + str(self.orf) + ":\n" @@ -85,6 +86,15 @@ def get_info(self): info_string += "\tfull score = " + str(self.full_score) + "\n" return info_string + def coord_string(self, sep='_') -> str: + return str(self.start) + sep + str(self.end) + + def sequence_name(self, sep=' ') -> str: + if self.desc != '-': + return self.orf + sep + self.desc + else: + return self.orf + def subsequent_matches(self): if not self.next_domain: return [self] diff --git a/treesapp/multiple_alignment.py b/treesapp/multiple_alignment.py new file mode 100644 index 00000000..d3b4f7b0 --- /dev/null +++ b/treesapp/multiple_alignment.py @@ -0,0 +1,142 @@ +import sys +import os.path +import time +import logging + +from treesapp import logger +from treesapp import refpkg +from treesapp import external_command_interface as eci +from treesapp import clipkit_helper as ckh + +LOGGER = logging.getLogger(logger.logger_name()) + + +def trim_multiple_alignment_clipkit(msa_file: str, ref_pkg: refpkg.ReferencePackage, + min_seq_length: int, for_placement=False, gap_prop=0.8) -> ckh.ClipKitHelper: + # Modes can be one of 'smart-gap', 'kpi', 'kpic', 'gappy', 'kpi-smart-gap', 'kpi-gappy' + trimmer = ckh.ClipKitHelper(fasta_in=msa_file, + output_dir=os.path.dirname(msa_file), + mode="gappy", + gap_prop=gap_prop, + min_len=min_seq_length, + for_placement=for_placement) + trimmer.refpkg_name = ref_pkg.prefix + trimmer.run() + trimmer.compare_original_and_trimmed_multiple_alignments(ref_pkg) + trimmer.validate_alignment_trimming() + trimmer.write_qc_trimmed_multiple_alignment() + return trimmer + + +def summarise_trimming(msa_trimmers: list) -> None: + """Summarises various outcomes of trimming MSAs.""" + refpkg_trimming_stats = {trimmer.refpkg_name: { + "msa_files": 0, + "cols_removed": [], + "seqs_removed": [], + "successes": 0, + } + for trimmer in msa_trimmers} + num_successful_alignments = 0 + + 
LOGGER.debug("Validating trimmed multiple sequence alignment files... ") + for trimmer in msa_trimmers: # type: ckh.ClipKitHelper + # Gather all useful stats for each trimmer instance + refpkg_trimming_stats[trimmer.refpkg_name]["msa_files"] += 1 + if trimmer.success: + refpkg_trimming_stats[trimmer.refpkg_name]["successes"] += 1 + num_successful_alignments += 1 + else: + continue + refpkg_trimming_stats[trimmer.refpkg_name]["cols_removed"].append(trimmer.num_msa_cols - trimmer.num_trim_cols) + refpkg_trimming_stats[trimmer.refpkg_name]["seqs_removed"].append(trimmer.num_msa_seqs - trimmer.num_trim_seqs) + + # Summarise trimming by reference package + for refpkg_name, stats in refpkg_trimming_stats.items(): + trim_summary = "\t\t{} trimming stats:\n".format(refpkg_name) + if stats["msa_files"] == 0: + continue + # To avoid ZeroDivisionError + if stats["successes"] > 0: + avg_cols_removed = round(sum(stats["cols_removed"]) / len(stats["cols_removed"])) + avg_seqs_removed = round(sum(stats["seqs_removed"]) / len(stats["seqs_removed"])) + else: + avg_cols_removed = 0 + avg_seqs_removed = 0 + + trim_summary += "Multiple alignment files = {}\n".format(stats["msa_files"]) + trim_summary += "Files successfully trimmed = {}\n".format(stats["successes"]) + trim_summary += "Average columns removed = {}\n".format(avg_cols_removed) + trim_summary += "Average sequences removed = {}\n".format(avg_seqs_removed) + + LOGGER.debug(trim_summary + "\n") + + LOGGER.debug("done.\n") + + if num_successful_alignments == 0: + LOGGER.error("No quality alignment files to analyze after trimming. Exiting now.\n") + sys.exit(0) # This allows Clade_exclusion_analyzer to continue after exit + return + + +def gather_multiple_alignments(msa_trimmers: list) -> dict: + """ + Creates a dictionary of MSA files indexed by reference package names. + These files are trimmed outputs if trimming was successful, or the original if not. 
+ """ + trimmed_output_files = {} + for trimmer in msa_trimmers: # type: ckh.ClipKitHelper + try: + trimmed_output_files[trimmer.refpkg_name].append(trimmer.get_qc_output_file()) + except KeyError: + trimmed_output_files[trimmer.refpkg_name] = [trimmer.get_qc_output_file()] + return trimmed_output_files + + +def trim_multiple_alignment_farmer(pquery_groups_manifest: list, min_seq_length: int, ref_pkgs: dict, + n_proc=1, for_placement=True, silent=False) -> dict: + """ + Runs ClipKit using the provided lists of the concatenated hmmalign files, and the number of sequences in each file. + + :param pquery_groups_manifest: A list of dictionaries with the keys '', ... + :param min_seq_length: Minimum length for a sequence to be retained in the MSA + :param ref_pkgs: A dictionary of reference package names mapped to ReferencePackage instances + :param n_proc: The number of parallel processes to be launched for alignment trimming + :param for_placement: A flag indicating the MSA contains both reference and query sequences + :param silent: A boolean indicating whether the + :return: A list of files resulting from multiple sequence alignment masking. 
+ """ + start_time = time.time() + task_list = list() + hmm_perc = 1.0 + + for pquery_group in pquery_groups_manifest: + trim_args = {"msa_file": pquery_group["qry_ref_mfa"], + "ref_pkg": ref_pkgs[pquery_group["refpkg_name"]], + "min_seq_length": min_seq_length, + "for_placement": for_placement} + + if pquery_group["gap_tuned"]: + trim_args["gap_prop"] = pquery_group["avg_id"]/100 + if trim_args["min_seq_length"] == 0: + trim_args["min_seq_length"] = int(ref_pkgs[pquery_group["refpkg_name"]].hmm_length() * (hmm_perc/100)) + + task_list.append(trim_args) + + msa_trimmers = eci.run_apply_async_multiprocessing(func=trim_multiple_alignment_clipkit, + arguments_list=task_list, + num_processes=n_proc, + pbar_desc="Multiple alignment trimming", + disable=silent) + + end_time = time.time() + hours, remainder = divmod(end_time - start_time, 3600) + minutes, seconds = divmod(remainder, 60) + LOGGER.debug("\tMultiple alignment trimming time required: " + + ':'.join([str(hours), str(minutes), str(round(seconds, 2))]) + "\n") + + summarise_trimming(msa_trimmers) + # Collect the trimmed (or untrimmed if reference sequences were removed) output files + trimmed_output_files = gather_multiple_alignments(msa_trimmers) + + return trimmed_output_files diff --git a/treesapp/refpkg.py b/treesapp/refpkg.py index 006506b0..c0aebcc2 100644 --- a/treesapp/refpkg.py +++ b/treesapp/refpkg.py @@ -7,6 +7,8 @@ from shutil import copy from datetime import datetime as dt +import Bio.Align +from Bio import Align from packaging import version from ete3 import Tree from pandas import DataFrame @@ -1272,6 +1274,60 @@ def deduplicate_annotation_members(self) -> None: return + @staticmethod + def bio_aligner_helper(pw_aligner: Bio.Align.PairwiseAligner, seqA: str, seqB: str): + aln = pw_aligner.align(seqA, seqB)[0] + return aln + + def blast(self, qseq: str, **kwargs) -> (Align.PairwiseAlignment, float, float): + """Find the percent pairwise identity between a query sequence and its closest match in a 
reference package.""" + aligner = Align.PairwiseAligner(mode="global") + aligner.match_score = kwargs.get('match', 1) + aligner.mismatch_score = kwargs.get('mismatch', 0) + aligner.gap_score = kwargs.get('gap', -10) + aligner.extend_gap_score = kwargs.get('extend_gap', -1) + + def _calculate_identity(sequenceA, sequenceB): + """ + Returns the percentage of identical characters between two sequences. + Assumes the sequences are aligned. + """ + + sa, sb, sl = sequenceA, sequenceB, len(sequenceA) + matches = [sa[i] == sb[i] for i in range(sl)] + seq_id = (100 * sum(matches)) / sl + + gapless_sl = sum([1 for i in range(sl) if (sa[i] != '-' and sb[i] != '-')]) + gap_id = (100 * sum(matches)) / gapless_sl + return (seq_id, gap_id) + + ref_seqs = self.get_fasta() + ref_seqs.unalign() + top_aln = None + + task_list = [] + for sname, sseq in ref_seqs.fasta_dict.items(): + task_list.append({"seqA": sseq, + "seqB": qseq, + "pw_aligner": aligner}) + + results = eci.run_apply_async_multiprocessing(func=self.bio_aligner_helper, + arguments_list=task_list, + num_processes=kwargs.get('n_proc', 1), + pbar_desc="BLAST-ing refpkg", + disable=True) + + for aln in results: + if not top_aln: + top_aln = aln + elif aln.score > top_aln.score: + top_aln = aln + + # Calculate sequence identity + aligned_A , _aln, aligned_B = top_aln.format().split("\n")[:3] + seq_id, g_seq_id = _calculate_identity(aligned_A, aligned_B) + return top_aln, seq_id, g_seq_id + def write_edited_pkl(ref_pkg: ReferencePackage, output_dir: str, overwrite: bool) -> int: if output_dir: diff --git a/treesapp/sub_binaries/BMGE.jar b/treesapp/sub_binaries/BMGE.jar deleted file mode 100755 index 63661ad9..00000000 Binary files a/treesapp/sub_binaries/BMGE.jar and /dev/null differ diff --git a/treesapp/training_utils.py b/treesapp/training_utils.py index 03d54db5..4acc7da7 100644 --- a/treesapp/training_utils.py +++ b/treesapp/training_utils.py @@ -21,10 +21,10 @@ from treesapp import classy from treesapp import phylo_seq 
from treesapp import logger -from treesapp import external_command_interface as eci from treesapp.jplace_utils import jplace_parser, demultiplex_pqueries, calc_pquery_mean_tip_distances from treesapp.entish import map_internal_nodes_leaves from treesapp.refpkg import ReferencePackage +from treesapp import multiple_alignment LOGGER = logging.getLogger(logger.logger_name()) @@ -365,7 +365,7 @@ def generate_pquery_data_for_trainer(ref_pkg: ReferencePackage, taxon: str, fasta.write_new_fasta(taxonomy_filtered_query_seqs, fasta_name=query_fasta_file) ## - # Run hmmalign, BMGE and EPA-NG to map sequences from the taxonomic rank onto the tree + # Run hmmalign, ClipKit and EPA-NG to map sequences from the taxonomic rank onto the tree ## aln_stdout = wrapper.profile_aligner(executables, ce_refpkg.f__msa, ce_refpkg.f__profile, query_fasta_file, query_sto_file) @@ -375,27 +375,21 @@ def generate_pquery_data_for_trainer(ref_pkg: ReferencePackage, taxon: str, LOGGER.debug(str(aln_stdout) + "\n") - trim_command, combined_msa = wrapper.get_msa_trim_command(executables, all_msa, ce_refpkg.molecule) - eci.launch_write_command(trim_command) - intermediate_files += glob(combined_msa + "*") - - # Ensure reference sequences haven't been removed during MSA trimming - msa_dict, failed_msa_files, summary_str = file_parsers.validate_alignment_trimming([combined_msa], - set(ce_fasta.fasta_dict), - True) - nrow, ncolumn = fasta.multiple_alignment_dimensions(mfa_file=combined_msa, - seq_dict=fasta.read_fasta_to_dict(combined_msa)) - LOGGER.debug("Columns = " + str(ncolumn) + "\n") - if combined_msa not in msa_dict.keys(): + trimmer = multiple_alignment.trim_multiple_alignment_clipkit(msa_file=all_msa, + ref_pkg=ref_pkg, + min_seq_length=int(0.1*ref_pkg.hmm_length()), + for_placement=True) + trimmer.summarise_trimming() + if not trimmer.success: LOGGER.debug("Placements for '{}' are being skipped after failing MSA validation.\n".format(taxon)) for old_file in intermediate_files: - 
os.remove(old_file) - intermediate_files.clear() + if os.path.isfile(old_file): + os.remove(old_file) + intermediate_files.clear() return pqueries - LOGGER.debug("Number of sequences discarded: " + summary_str + "\n") # Create the query-only FASTA file required by EPA-ng - fasta.split_combined_ref_query_fasta(combined_msa, query_msa, ref_msa) + fasta.split_combined_ref_query_fasta(trimmer.get_qc_output_file(), query_msa, ref_msa) raxml_files = wrapper.raxml_evolutionary_placement(epa_exe=executables["epa-ng"], refpkg_tree=ce_refpkg.f__tree, diff --git a/treesapp/utilities.py b/treesapp/utilities.py index c645ab17..be8a4702 100644 --- a/treesapp/utilities.py +++ b/treesapp/utilities.py @@ -5,9 +5,12 @@ import shutil from glob import glob from csv import Sniffer +from io import StringIO +from functools import partialmethod from pygtrie import StringTrie import multiprocessing +from tqdm import tqdm from treesapp import external_command_interface as eci from treesapp import logger @@ -15,6 +18,21 @@ LOGGER = logging.getLogger(logger.logger_name()) +class Capturing(list): + def __enter__(self): + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) + self._stdout = sys.stdout + sys.stdout = self._stringio = StringIO() + return self + + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + del self._stringio # free up some memory + sys.stdout = self._stdout + tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) + return + + def base_file_prefix(file_path: str) -> str: return os.path.splitext(os.path.basename(file_path))[0] @@ -192,8 +210,6 @@ def executable_dependency_versions(exe_dict: dict) -> str: versions_dict[exe] = stdout.strip() elif exe == "FastTree": stdout, returncode = eci.launch_write_command([exe_dict[exe], "-expert"]) - elif exe == "BMGE.jar": - stdout, returncode = eci.launch_write_command(["java", "-Xmx10m", "-jar", exe_dict[exe], "-?"]) else: LOGGER.warning("Unknown version command for " + exe + ".\n") continue 
diff --git a/treesapp/wrapper.py b/treesapp/wrapper.py index 94f0519a..6e92ea4b 100644 --- a/treesapp/wrapper.py +++ b/treesapp/wrapper.py @@ -1,6 +1,5 @@ import sys import os -import time import re import glob import logging @@ -314,8 +313,9 @@ def raxml_evolutionary_placement(epa_exe: str, refpkg_tree: str, refpkg_msa: str '-t', refpkg_tree, '-q', query_msa, "--model", refpkg_model, - "--no-pre-mask", + # "--no-pre-mask", "--dyn-heur", str(0.9), + # "--baseball-heur", str(0.9), # "--fix-heur", str(0.2), "--preserve-rooting", "on", "--filter-min-lwr", str(0.01), @@ -339,27 +339,6 @@ def raxml_evolutionary_placement(epa_exe: str, refpkg_tree: str, refpkg_msa: str return epa_files -def trimal_command(executable, mfa_file, trimmed_msa_file): - trim_command = [executable, - '-in', mfa_file, - '-out', trimmed_msa_file, - '-gt', str(0.02)] - return trim_command - - -def bmge_command(executable, mfa_file, trimmed_msa_file, molecule): - if molecule == "prot": - bmge_settings = ["-t", "AA", "-m", "BLOSUM30"] - else: - bmge_settings = ["-t", "DNA", "-m", "DNAPAM100:2"] - trim_command = ["java", "-Xmx512m", "-jar", executable] - trim_command += bmge_settings - trim_command += ["-g", "0.99:0.33"] # Specifying the gap rate per_sequence:per_character - trim_command += ['-i', mfa_file, - '-of', trimmed_msa_file] - return trim_command - - def hmmalign_command(executable, ref_aln, ref_profile, input_fasta, output_multiple_alignment): malign_command = [executable, '--mapali', ref_aln, @@ -787,69 +766,3 @@ def run_odseq(odseq_exe: str, fasta_in: str, outliers_fa: str, num_threads: int) return - -def get_msa_trim_command(executables, mfa_file, molecule, tool="BMGE"): - """ - Trims/masks/filters the multiple sequence alignment using either BMGE or trimAl - - :param executables: A dictionary mapping software to a path of their respective executable - :param mfa_file: Name of a MSA file - :param molecule: prot | dna - :param tool: Name of the software to use for trimming [BMGE|trimAl] - 
Returns file name of the trimmed multiple alignment file in FASTA format - """ - f_ext = mfa_file.split('.')[-1] - if not re.match("mfa|fasta|phy|fa", f_ext): - LOGGER.error("Unsupported file format: '" + f_ext + "'\n") - sys.exit(5) - - trimmed_msa_file = '.'.join(mfa_file.split('.')[:-1]) + '-' + re.escape(tool) + ".fasta" - if tool == "trimAl": - trim_command = trimal_command(executables["trimal"], mfa_file, trimmed_msa_file) - elif tool == "BMGE": - trim_command = bmge_command(executables["BMGE.jar"], mfa_file, trimmed_msa_file, molecule) - else: - LOGGER.error("Unsupported trimming software requested: '" + tool + "'") - sys.exit(5) - - return trim_command, trimmed_msa_file - - -def filter_multiple_alignments(executables, concatenated_mfa_files, refpkg_dict, n_proc=1, tool="BMGE", silent=False): - """ - Runs BMGE using the provided lists of the concatenated hmmalign files, and the number of sequences in each file. - - :param executables: A dictionary mapping software to a path of their respective executable - :param concatenated_mfa_files: A dictionary containing f_contig keys mapping to a FASTA or Phylip sequential file - :param refpkg_dict: A dictionary of ReferencePackage instances indexed by their respective denominators - :param n_proc: The number of parallel processes to be launched for alignment trimming - :param tool: The software to use for alignment trimming - :return: A list of files resulting from BMGE multiple sequence alignment masking. 
- """ - start_time = time.time() - task_list = list() - trimmed_output_files = {} - - for refpkg_code in sorted(concatenated_mfa_files.keys()): - if refpkg_code not in trimmed_output_files: - trimmed_output_files[refpkg_code] = [] - mfa_files = concatenated_mfa_files[refpkg_code] - for concatenated_mfa_file in mfa_files: - trim_command, trimmed_msa_file = get_msa_trim_command(executables, concatenated_mfa_file, - refpkg_dict[refpkg_code].molecule, tool) - trimmed_output_files[refpkg_code].append(trimmed_msa_file) - task_list.append([trim_command]) - - eci.run_apply_async_multiprocessing(func=eci.launch_write_command, - arguments_list=task_list, - num_processes=n_proc, - pbar_desc="Multiple alignment trimming", - disable=silent) - - end_time = time.time() - hours, remainder = divmod(end_time - start_time, 3600) - minutes, seconds = divmod(remainder, 60) - LOGGER.debug("\t" + tool + " time required: " + - ':'.join([str(hours), str(minutes), str(round(seconds, 2))]) + "\n") - return trimmed_output_files -