Commit

fix pull_top_ref_id.nf
peterk87 committed Jul 11, 2023
1 parent f7acaba commit dedcfb5
Showing 2 changed files with 52 additions and 35 deletions.
84 changes: 51 additions & 33 deletions bin/parse_influenza_blast_results.py
@@ -11,13 +11,13 @@
 import logging
 import re
 from collections import defaultdict
+from typing import Dict, List, Optional, Tuple
 
 import click
 import numpy as np
 import pandas as pd
 import polars as pl
 from rich.logging import RichHandler
-from typing import Dict, List, Optional, Tuple
 
 LOG_FORMAT = "%(asctime)s %(levelname)s: %(message)s [in %(filename)s:%(lineno)d]"
 logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
@@ -297,7 +297,8 @@ def find_h_or_n_type(df_merge, seg, is_iav):
     type_counts = df_segment["Genotype"].value_counts(sort=True)
     type_counts = type_counts.filter(~pl.col("Genotype").is_null())
     reg_h_or_n_type = "[Hh]" if h_or_n == "H" else "[Nn]"
-    df_type_counts = type_counts.with_columns(pl.lit(type_counts["Genotype"].str.extract(reg_h_or_n_type + r"(\d+)").alias(type_name)))
+    df_type_counts = type_counts.with_columns(
+        pl.lit(type_counts["Genotype"].str.extract(reg_h_or_n_type + r"(\d+)").alias(type_name)))
     df_type_counts = df_type_counts.filter(~pl.col(type_name).is_null())
     logging.debug(f"{df_type_counts}")
     type_to_count = defaultdict(int)
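
For context on the reformatted polars call above: str.extract pulls the first regex capture group out of each Genotype string, so "H3N2" yields "3" for the H type. A minimal, runnable sketch with hypothetical values (not the real metadata):

import polars as pl

# Hypothetical subtype counts; the pattern mirrors reg_h_or_n_type + r"(\d+)" for H types.
type_counts = pl.DataFrame({"Genotype": ["H3N2", "H1N1", "h5n1", None], "counts": [4, 2, 1, 3]})
df_type_counts = type_counts.with_columns(
    type_counts["Genotype"].str.extract(r"[Hh](\d+)").alias("H_type")
)
# Drop rows where no H type could be extracted, as the function above does.
print(df_type_counts.filter(~pl.col("H_type").is_null()))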
@@ -369,47 +370,31 @@ def report(
     get_top_ref,
     sample_name
 ):
-    from rich.traceback import install
-    install(show_locals=True, width=120, word_wrap=True)
-    logging.basicConfig(
-        format="%(message)s",
-        datefmt="[%Y-%m-%d %X]",
-        level=logging.DEBUG,
-        handlers=[RichHandler(rich_tracebacks=True, tracebacks_show_locals=True)],
-    )
+    init_logging()
 
     logging.info(f'Parsing Influenza metadata file "{flu_metadata}"')
 
-    md_cols = [
-        ("#Accession", str),
-        ("Release_Date", pl.Categorical),
-        ("Genus", pl.Categorical),
-        ("Length", pl.UInt16),
-        ("Genotype", str),
-        ("Segment", pl.Categorical),
-        ("Publications", str),
-        ("Geo_Location", pl.Categorical),
-        ("Host", pl.Categorical),
-        ("Isolation_Source", pl.Categorical),
-        ("Collection_Date", pl.Categorical),
-        ("GenBank_Title", str),
-    ]
-    df_md = pl.read_csv(
-        flu_metadata,
-        has_header=True,
-        dtypes=dict(md_cols),
-    )
+    df_md = read_refseq_metadata(flu_metadata)
 
     unique_subtypes = df_md.select("Genotype").unique()
     unique_subtypes = unique_subtypes.filter(~pl.col("Genotype").is_null())
     logging.info(
-        f"Parsed Influenza metadata file into DataFrame with n={df_md.shape[0]} rows and n={df_md.shape[1]} columns. There are {len(unique_subtypes)} unique subtypes. "
+        f"Parsed Influenza metadata file into DataFrame with n={df_md.shape[0]} rows and n={df_md.shape[1]} columns. "
+        f"There are {len(unique_subtypes)} unique subtypes."
     )
     regex_subtype_pattern = r"\((H\d+N\d+|" + "|".join(list(unique_subtypes["Genotype"])) + r")\)"
     results = [
-        parse_blast_result(blast_result, df_md, regex_subtype_pattern, get_top_ref, top=top,
-                           pident_threshold=pident_threshold,
-                           min_aln_length=min_aln_length) for blast_result in blast_results]
+        parse_blast_result(
+            blast_result,
+            df_md,
+            regex_subtype_pattern,
+            get_top_ref,
+            top=top,
+            pident_threshold=pident_threshold,
+            min_aln_length=min_aln_length
+        )
+        for blast_result in blast_results
+    ]
 
     if not get_top_ref:
         dfs_blast = []
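
To illustrate the regex_subtype_pattern built above: it is an alternation of the generic H#N# form plus every subtype present in the metadata, wrapped in literal parentheses so it matches the subtype suffix of a GenBank title. A small sketch with hypothetical values:

import re

# Hypothetical stand-in for unique_subtypes["Genotype"].
unique_genotypes = ["H3N2", "H5N1", "mixed"]
regex_subtype_pattern = r"\((H\d+N\d+|" + "|".join(unique_genotypes) + r")\)"
print(regex_subtype_pattern)  # \((H\d+N\d+|H3N2|H5N1|mixed)\)

# Hypothetical GenBank-style title; the inner "(H1N1)" is captured by group 1.
title = "Influenza A virus (A/duck/Alberta/35/76(H1N1)) segment 1"
m = re.search(regex_subtype_pattern, title)
print(m.group(1) if m else None)  # H1N1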
@@ -474,6 +459,39 @@ def report(
     df_ref_id.write_csv(sample_name + ".topsegments.csv", separator=",", has_header=True)
 
 
+def read_refseq_metadata(flu_metadata):
+    md_cols = [
+        ("#Accession", str),
+        ("Release_Date", pl.Categorical),
+        ("Genus", pl.Categorical),
+        ("Length", pl.UInt16),
+        ("Genotype", str),
+        ("Segment", pl.Categorical),
+        ("Publications", str),
+        ("Geo_Location", pl.Categorical),
+        ("Host", pl.Categorical),
+        ("Isolation_Source", pl.Categorical),
+        ("Collection_Date", pl.Categorical),
+        ("GenBank_Title", str),
+    ]
+    return pl.read_csv(
+        flu_metadata,
+        has_header=True,
+        dtypes=dict(md_cols),
+    )
+
+
+def init_logging():
+    from rich.traceback import install
+    install(show_locals=True, width=120, word_wrap=True)
+    logging.basicConfig(
+        format="%(message)s",
+        datefmt="[%Y-%m-%d %X]",
+        level=logging.DEBUG,
+        handlers=[RichHandler(rich_tracebacks=True, tracebacks_show_locals=True)],
+    )
+
+
 def get_col_widths(df, index=False):
     """Calculate column widths based on column headers and contents"""
     if index:
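A side note on the extracted read_refseq_metadata helper: dict(md_cols) converts the (column, dtype) pairs into the mapping that pl.read_csv takes via its dtypes argument, forcing compact dtypes such as pl.Categorical and pl.UInt16 instead of inferred ones. A minimal sketch with a toy in-memory CSV (hypothetical rows, not the real genomeset file):

import io

import polars as pl

# Two of the twelve columns, with a made-up accession and length.
md_cols = [("#Accession", str), ("Length", pl.UInt16)]
csv_file = io.BytesIO(b"#Accession,Length\nTEST-0001,2341\n")
df = pl.read_csv(csv_file, has_header=True, dtypes=dict(md_cols))
print(df.schema)  # {'#Accession': Utf8, 'Length': UInt16}
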
3 changes: 1 addition & 2 deletions modules/local/pull_top_ref_id.nf
@@ -1,6 +1,6 @@
 process PULL_TOP_REF_ID {
     tag "$meta.id"
-    label 'process_medium'
+    label 'process_low'
 
     conda (params.enable_conda ? 'conda-forge::python=3.10 conda-forge::biopython=1.80 conda-forge::openpyxl=3.1.0 conda-forge::pandas=1.5.3 conda-forge::rich=12.6.0 conda-forge::typer=0.7.0 conda-forge::xlsxwriter=3.0.8 conda-forge::polars=0.17.9 conda-forge::pyarrow=11.0.0' : null)
     if (workflow.containerEngine == 'singularity' && !params.singularity_pull_docker_container) {
@@ -20,7 +20,6 @@ process PULL_TOP_REF_ID {
     script:
     """
     parse_influenza_blast_results.py \\
-    --threads ${task.cpus} \\
     --flu-metadata $genomeset \\
     --get-top-ref True \\
     --top 1 \\
