Skip to content

Commit

Permalink
Fixed issue #40; updated parse_influenza_blast_results.py
Browse files Browse the repository at this point in the history
* subtype prediction is now based on the majority H/N prediction across all BLAST hits instead of just the top X matches
* the top hit for H/N can also be a user-specified sequence without subtype information
* top segment matches are now sorted by sample name, segment name, and BLAST bitscore (highest bitscore first)
  • Loading branch information
peterk87 committed Aug 16, 2023
1 parent 7cf9230 commit 7e5ce5f
Showing 1 changed file with 53 additions and 57 deletions.
110 changes: 53 additions & 57 deletions bin/parse_influenza_blast_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

pl.enable_string_cache(True)

influenza_segment = {
SEGMENT_NAMES = {
1: "1_PB2",
2: "2_PB1",
3: "3_PA",
Expand All @@ -35,27 +35,42 @@
8: "8_NS",
}

METADATA_COLUMNS = [
("#Accession", str),
("Release_Date", pl.Categorical),
("Genus", pl.Categorical),
("Length", pl.UInt16),
("Genotype", str),
("Segment", pl.Categorical),
("Publications", str),
("Geo_Location", pl.Categorical),
("Host", pl.Categorical),
("Isolation_Source", pl.Categorical),
("Collection_Date", pl.Categorical),
("GenBank_Title", str),
]

# Column names/types/final report names
blast_cols = [
BLASTN_COLUMNS = [
("qaccver", str),
("saccver", str),
("pident", float),
("length", pl.UInt16),
("mismatch", pl.UInt16),
("gapopen", pl.UInt16),
("qstart", pl.UInt16),
("qend", pl.UInt16),
("sstart", pl.UInt16),
("send", pl.UInt16),
("length", pl.UInt32),
("mismatch", pl.UInt32),
("gapopen", pl.UInt32),
("qstart", pl.UInt32),
("qend", pl.UInt32),
("sstart", pl.UInt32),
("send", pl.UInt32),
("evalue", pl.Float32),
("bitscore", pl.Float32),
("qlen", pl.UInt16),
("slen", pl.UInt16),
("qlen", pl.UInt32),
("slen", pl.UInt32),
("qcovs", pl.Float32),
("stitle", str),
]

blast_results_report_columns = [
BLAST_RESULTS_REPORT_COLUMNS = [
("sample", "Sample"),
("sample_segment", "Sample Genome Segment Number"),
("#Accession", "Reference NCBI Accession"),
Expand Down Expand Up @@ -83,7 +98,7 @@
("Release_Date", "Reference Release Date"),
]

subtype_results_summary_columns = [
SUBTYPE_RESULTS_SUMMARY_COLUMNS = [
"sample",
"Genotype",
"H_top_accession",
Expand All @@ -96,7 +111,7 @@
"N_NCBI_Influenza_DB_proportion_matches",
]

columns_H_summary_results = [
H_COLUMNS = [
"sample",
"Genotype",
"H_top_accession",
Expand All @@ -117,7 +132,7 @@
"H_virus_name",
]

columns_N_summary_results = [
N_COLUMNS = [
"sample",
"Genotype",
"N_top_accession",
Expand All @@ -138,7 +153,7 @@
"N_virus_name",
]

subtype_results_summary_final_names = {
SUBTYPE_RESULTS_SUMMARY_FINAL_NAMES = {
"sample": "Sample",
"Genotype": "Subtype Prediction",
"N_type": "N: type prediction",
Expand Down Expand Up @@ -196,8 +211,8 @@ def parse_blast_result(
blast_result,
has_header=False,
separator="\t",
new_columns=[name for name, coltype in blast_cols],
dtypes=dict(blast_cols),
new_columns=[name for name, coltype in BLASTN_COLUMNS],
dtypes=dict(BLASTN_COLUMNS),
)
.filter(
(pl.col("pident") >= (pident_threshold * 100))
Expand Down Expand Up @@ -242,11 +257,11 @@ def parse_blast_result(
for seg in segments
]
df_top_seg_matches = pl.concat(dfs, how="vertical")
cols = pl.Series([x for x, _ in blast_results_report_columns])
cols = pl.Series([x for x, _ in BLAST_RESULTS_REPORT_COLUMNS])
df_top_seg_matches = df_top_seg_matches.select(pl.col(cols))
subtype_results_summary = {"sample": sample_name}
if not get_top_ref:
df_genotype_genus = df_top_seg_matches.select(pl.col(["Genotype", "Genus"]))
df_genotype_genus = df_merge.select(pl.col(["Genotype", "Genus"]))
# where the genus is not IAV, set the genotype to "Not IAV"
df_genotype_genus = df_genotype_genus.with_columns(
pl.when(pl.col("Genus") == "Alphainfluenzavirus")
Expand Down Expand Up @@ -322,22 +337,12 @@ def find_h_or_n_type(df_merge, seg, is_iav):
logging.info(
f"{h_or_n}{top_type} n={top_type_count}/{total_count} ({top_type_count / total_count:.1%})"
)
df_segment = df_segment.with_columns(
pl.lit(
df_segment["Genotype"]
.str.contains(f".*{reg_h_or_n_type}" + top_type + r".*")
.fill_null(False)
.alias("type_mask")
)
)
df_seg_top_type = df_segment.filter(pl.col("type_mask") == True).drop("type_mask")
top_result: pl.Series = list(df_seg_top_type.head(1).iter_rows(named=True))[0]
else:
top_type = "N/A"
top_type_count = "N/A"
total_count = "N/A"
top_result: pl.Series = list(df_segment.head(1).iter_rows(named=True))[0]

top_result: pl.Series = list(df_segment.head(1).iter_rows(named=True))[0]
results_summary = {
f"{h_or_n}_type": top_type if is_iav else "N/A",
f"{h_or_n}_sample_segment_length": top_result["qlen"],
Expand Down Expand Up @@ -419,23 +424,28 @@ def report(
sample = subtype_results_summary["sample"]
all_subtype_results[sample] = subtype_results_summary
df_all_blast = pd.concat(dfs_blast).rename(
columns=dict(blast_results_report_columns)
columns=dict(BLAST_RESULTS_REPORT_COLUMNS)
)
df_subtype_results = pd.DataFrame(all_subtype_results).transpose()
cols = pd.Series(subtype_results_summary_columns)
cols = pd.Series(SUBTYPE_RESULTS_SUMMARY_COLUMNS)
cols = cols[cols.isin(df_subtype_results.columns)]
df_subtype_predictions = df_subtype_results[cols].rename(
columns=subtype_results_summary_final_names
columns=SUBTYPE_RESULTS_SUMMARY_FINAL_NAMES
)
cols = pd.Series(columns_H_summary_results)
cols = pd.Series(H_COLUMNS)
cols = cols[cols.isin(df_subtype_results.columns)]
df_H = df_subtype_results[cols].rename(columns=subtype_results_summary_final_names)
cols = pd.Series(columns_N_summary_results)
df_H = df_subtype_results[cols].rename(columns=SUBTYPE_RESULTS_SUMMARY_FINAL_NAMES)
cols = pd.Series(N_COLUMNS)
cols = cols[cols.isin(df_subtype_results.columns)]
df_N = df_subtype_results[cols].rename(columns=subtype_results_summary_final_names)
# Add segment name for more informative
df_N = df_subtype_results[cols].rename(columns=SUBTYPE_RESULTS_SUMMARY_FINAL_NAMES)
# Convert segment number to segment name (1 -> "1_PB2")
df_all_blast["Sample Genome Segment Number"] = df_all_blast["Sample Genome Segment Number"]. \
apply(lambda x: influenza_segment[int(x)])
apply(lambda x: SEGMENT_NAMES[int(x)])
# Sort by sample names, segment numbers and bitscore
df_all_blast = df_all_blast.sort_values(
["Sample", "Sample Genome Segment Number", "BLASTN Bitscore"],
ascending=[True, True, False]
)
write_excel(
[
("Subtype Predictions", df_subtype_predictions),
Expand All @@ -447,7 +457,7 @@ def report(
)
else:
df_blast, subtype_results_summary = results[0]
df_blast = df_blast.rename(mapping=dict(blast_results_report_columns))
df_blast = df_blast.rename(mapping=dict(BLAST_RESULTS_REPORT_COLUMNS))
df_ref_id = df_blast.select(
pl.col([
'Sample',
Expand All @@ -465,30 +475,16 @@ def report(
.alias('Reference NCBI Accession')
)
df_ref_id = df_ref_id.with_columns(
pl.col("Sample Genome Segment Number").apply(lambda x: influenza_segment[int(x)])
pl.col("Sample Genome Segment Number").apply(lambda x: SEGMENT_NAMES[int(x)])
.alias("Sample Genome Segment Number"))
df_ref_id.write_csv(sample_name + ".topsegments.csv", separator=",", has_header=True)


def read_refseq_metadata(flu_metadata):
md_cols = [
("#Accession", str),
("Release_Date", pl.Categorical),
("Genus", pl.Categorical),
("Length", pl.UInt16),
("Genotype", str),
("Segment", pl.Categorical),
("Publications", str),
("Geo_Location", pl.Categorical),
("Host", pl.Categorical),
("Isolation_Source", pl.Categorical),
("Collection_Date", pl.Categorical),
("GenBank_Title", str),
]
return pl.read_csv(
flu_metadata,
has_header=True,
dtypes=dict(md_cols),
dtypes=dict(METADATA_COLUMNS),
)


Expand Down

0 comments on commit 7e5ce5f

Please sign in to comment.