From 56f70788292a675ea5ee262ca344287d17fab94b Mon Sep 17 00:00:00 2001 From: Manuel Holtgrewe Date: Fri, 5 Jan 2024 16:29:36 +0100 Subject: [PATCH] feat: add "serve" command (#24) (#35) --- Makefile | 3 + chew/cli.py | 44 ++++ chew/common.py | 50 +++++ chew/serve.py | 475 +++++++++++++++++++++++++++++++++++++++++ chew/stats.py | 137 +++++++----- requirements/base.txt | 1 + requirements/dev.txt | 1 + requirements/serve.txt | 3 + setup.py | 4 +- 9 files changed, 668 insertions(+), 50 deletions(-) create mode 100644 chew/serve.py create mode 100644 requirements/serve.txt diff --git a/Makefile b/Makefile index b86ce25..8e86bc7 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ .PHONY: default default: +.PHONY: format +format: black isort + .PHONY: black black: black -l 100 . diff --git a/chew/cli.py b/chew/cli.py index 5701026..d341e12 100644 --- a/chew/cli.py +++ b/chew/cli.py @@ -1,9 +1,17 @@ import typing import click +from logzero import logger from chew import __version__, compare, fingerprint, plot_compare, plot_var_het, stats +try: + import dash # noqa + + have_dash_installed = True +except ImportError: + have_dash_installed = False + @click.group() @click.version_option(__version__) @@ -171,3 +179,39 @@ def cli_plot_var_het( stats_out=stats_out, ) plot_var_het.run(config) + + +if have_dash_installed: + + @cli.command("serve", help="Run report server") # type: ignore[attr-defined] + @click.option( + "--annos-tsv", + default=None, + required=False, + help="Optional TSV file with further annotations", + ) + @click.option("--strip-suffix", default="", help="Suffix to strip from sample names") + @click.argument("cohort_ped") + @click.argument("fingerprints", nargs=-1) + @click.pass_context + def cli_serve( + ctx: click.Context, + annos_tsv: typing.Optional[str], + strip_suffix: str, + cohort_ped: str, + fingerprints: typing.List[str], + ): + if not fingerprints: + logger.warn("No fingerprints given!") + return + + from chew import serve + + config = serve.Config( 
class DiseaseState(enum.Enum):
    """Disease state of a pedigree member.

    The value of each member is the lower-case spelling of its own name, so
    that serialized values (e.g. in tables) say what the member means.
    ``PED_DISEASE_MAP`` translates the PED column codes into these members.
    """

    UNKNOWN = "unknown"
    # The values of the following two members used to be swapped.
    UNAFFECTED = "unaffected"
    AFFECTED = "affected"
def compute_chrom_dosages(container) -> ChromDosage:
    """Compute the per-chromosome fraction of reads for one fingerprint.

    The ``samtools_idxstats`` entry of ``container`` is parsed as
    tab-separated lines; the first field is used as the contig name and the
    third field as its integer count.  Every contig known to
    ``CHROM_LENS_GRCH37`` or ``CHROM_LENS_GRCH38`` contributes its count
    divided by the sum of the counts of *all* listed contigs.

    :param container: mapping-like fingerprint container providing
        ``samtools_idxstats`` (plus whatever ``extract_header`` reads).
    :return: ``ChromDosage`` named after the sample from the header.
    """
    idxstats_text = container["samtools_idxstats"].tolist()

    counts = {}
    for record in idxstats_text.splitlines():
        fields = record.split("\t")
        counts[fields[0]] = int(fields[2])
    total = sum(counts.values())

    dosages = {}
    for contig, count in counts.items():
        if contig not in CHROM_LENS_GRCH37 and contig not in CHROM_LENS_GRCH38:
            continue
        # Map contig names such as "chr1" / "1" / "X" to the attribute
        # names of ``ChromDosage`` ("chr_1", "chr_x", ...).
        if contig.startswith("chr"):
            field_name = contig.replace("chr", "chr_").lower()
        else:
            field_name = f"chr_{contig}".lower()
        dosages[field_name] = count / total

    sample = extract_header(container).sample
    return ChromDosage(sample_name=sample, **dosages)
def unpivot(df, key="key", variable="variable", value="value"):
    """Convert a wide data frame into long (key, variable, value) form.

    Each cell of ``df`` becomes one output row; cells are emitted column by
    column.

    :param df: wide data frame to convert.
    :param key: name of the output column holding the index label of ``df``.
    :param variable: name of the output column holding the column label.
    :param value: name of the output column holding the cell value.
    :return: new ``pd.DataFrame`` with columns ``[key, variable, value]``.
    """
    n_rows, n_cols = df.shape
    row_labels = np.asarray(df.index)
    col_labels = np.asarray(df.columns)

    long_form = {}
    # Column-major ("F") flattening walks each column top to bottom, which
    # lines up with the repeated column labels and tiled row labels below.
    long_form[value] = df.to_numpy().ravel("F")
    long_form[variable] = col_labels.repeat(n_rows)
    long_form[key] = np.tile(row_labels, n_cols)
    return pd.DataFrame(long_form, columns=[key, variable, value])
@app.callback(
    Output("stats-scatter-plot", "children"),
    [
        Input("stats-scatter-select-dim-h", "value"),
        Input("stats-scatter-select-dim-v", "value"),
        Input("stats-scatter-select-dim-color", "value"),
        Input("stats-scatter-select-dim-symbol", "value"),
        Input("stats-scatter-select-marker-size", "value"),
    ],
)
def render_scatter_plot(dim_h, dim_v, color, symbol, marker_size):
    """Dash callback drawing the "X vs. Y" statistics plot.

    A scatter plot is drawn when the horizontal dimension is numeric and a
    strip plot otherwise; ``symbol`` is only passed on to the scatter plot.

    :param dim_h: column of ``df`` shown on the horizontal axis.
    :param dim_v: column of ``df`` shown on the vertical axis.
    :param color: column of ``df`` used for coloring.
    :param symbol: column of ``df`` used for the marker symbol.
    :param marker_size: marker size; converted with ``int()``.
    :return: ``dcc.Graph`` wrapping the created figure.
    """
    horizontal_is_numeric = np.issubdtype(df[dim_h].dtype, np.number)
    shared_kwargs = dict(
        x=dim_h,
        y=dim_v,
        color=color,
        hover_data=["name", "sex"],
        labels={
            dim_h: COLUMN_LABELS.get(dim_h, dim_h),
            dim_v: COLUMN_LABELS.get(dim_v, dim_v),
        },
    )
    if horizontal_is_numeric:
        figure = px.scatter(df, symbol=symbol, **shared_kwargs)
    else:
        figure = px.strip(df, **shared_kwargs)
    figure.update_traces(marker={"size": int(marker_size)})
    return dcc.Graph(figure=figure)
@app.callback(Output("page-content", "children"), [Input("url", "pathname")])
def render_page_content(pathname):
    """Dash callback returning the page body for the current URL path.

    ``/`` shows the statistics table, ``/stats-plots`` and
    ``/chrom-dosage-plots`` show a form of select boxes followed by a
    placeholder that the corresponding plot callback fills in.  Any other
    path yields a "404" message.

    :param pathname: path component of the current location.
    :return: Dash component to display as page content.
    """

    def marker_size_select(field_id):
        # Select box offering the even marker sizes 2..20, preselecting 6.
        sizes = [str(size) for size in range(2, 21, 2)]
        return html.Div(
            [
                dbc.Select(
                    options=[{"label": size, "value": size} for size in sizes],
                    value="6",
                    id=field_id,
                ),
                html.Label(
                    ["size"],
                    htmlFor=field_id,
                ),
            ],
            className="form-floating col-12",
        )

    def plot_page(dim_selects, marker_size_id, plot_id):
        # Form with the given select boxes plus a row holding placeholder
        # bars under the id that the plot callback writes its graph to.
        row_select = html.Form(
            [*dim_selects, marker_size_select(marker_size_id)],
            className="row row-cols-lg-auto g-3 align-items-center",
        )
        plot_target = html.Div(
            [
                dbc.Placeholder(xs=6),
                html.Br(),
                dbc.Placeholder(className="w-75"),
                html.Br(),
                dbc.Placeholder(style={"width": "25%"}),
            ],
            id=plot_id,
            className="col-12",
        )
        row_plot = html.Div([plot_target], className="row")
        return html.Div([row_select, row_plot])

    if pathname == "/":
        display_df = df.rename(columns=COLUMN_LABELS)
        columns = []
        for col in df.columns:
            shown = COLUMN_LABELS.get(col, col)
            columns.append({"name": shown, "id": shown, **COLUM_FORMATS.get(col, {})})
        return dash_table.DataTable(
            display_df.to_dict("records"),
            columns,
            sort_action="native",
            filter_action="native",
            page_action="native",
            page_current=0,
            page_size=20,
        )

    if pathname == "/stats-plots":
        return plot_page(
            [
                build_select_div(
                    "stats-scatter-select-dim-h", "horizontal axis", "sex", anno_dims
                ),
                build_select_div(
                    "stats-scatter-select-dim-v", "vertical axis", "chrx_het_hom", anno_dims
                ),
                build_select_div("stats-scatter-select-dim-color", "color", "sex", anno_dims),
                build_select_div("stats-scatter-select-dim-symbol", "shape", "sex", anno_dims),
            ],
            "stats-scatter-select-marker-size",
            "stats-scatter-plot",
        )

    if pathname == "/chrom-dosage-plots":
        return plot_page(
            [
                build_select_div("chrom-dosage-select-dim-color", "color", "sex", anno_dims),
                build_select_div("chrom-dosage-select-dim-symbol", "shape", "sex", anno_dims),
            ],
            "chrom-dosage-select-marker-size",
            "chrom-dosage-plots",
        )

    # If the user tries to reach a different page, return a 404 message
    return html.Div(
        [
            html.H1("404: Not found", className="text-danger"),
            html.Hr(),
            html.P(f"The pathname {pathname} was not recognised..."),
        ],
        className="p-3 bg-light rounded-3",
    )
def compute_autosomal_aafs(container):
    """Compute the mean squared deviation from 0.5 of heterozygous AAFs.

    Sites are considered heterozygous where row 0 (mask) and row 1 (is alt)
    of ``autosomal_fingerprint`` are set and row 2 (hom. alt) is not.  For
    these sites, the values of ``autosomal_aafs`` are compared against 0.5.

    :param container: mapping-like fingerprint container providing
        ``autosomal_fingerprint`` and ``autosomal_aafs``.
    :return: mean of ``(aaf - 0.5) ** 2`` over the heterozygous sites or
        ``None`` if there is no such site (instead of dividing by zero and
        returning NaN).
    """
    autosomal_fingerprint = container["autosomal_fingerprint"]
    autosomal_mask = autosomal_fingerprint[0]
    autosomal_is_alt = autosomal_fingerprint[1]
    autosomal_hom_alt = autosomal_fingerprint[2]
    autosomal_aafs = container["autosomal_aafs"]
    is_het = autosomal_mask & autosomal_is_alt & ~autosomal_hom_alt
    sqrt_var_het = autosomal_aafs[is_het] - 0.5
    num_hets = sqrt_var_het.shape[0]
    if num_hets == 0:
        # No heterozygous site: the value is undefined, report it as
        # missing just like ``compute_chrx_het_hom`` does.
        return None
    return np.sum(sqrt_var_het * sqrt_var_het) / num_hets
def run(config: Config):
    """Write a TSV file with the statistics of all fingerprint files.

    One header line is written to ``config.output``, followed by one line
    per path in ``config.fingerprints``.  Statistics that are ``None`` are
    written as ``-``.

    :param config: configuration providing ``output`` and ``fingerprints``.
    """
    logger.info("Writing statistics file...")
    with open(config.output, "wt") as outputf:
        # Column names double as attribute names of the sample statistics.
        col_headers = [
            "sample_name",
            "hets",
            "hom_alts",
            "hom_refs",
            "mask_ones",
            "var_het",
            "chrx_het_hom",
            "chrx_frac",
            "chry_frac",
        ]

        print("\t".join(col_headers), file=outputf)
        for container in map(load_fingerprint_all, tqdm(config.fingerprints)):
            stats = compute_sample_stats(container)

            row = []
            for key in col_headers:
                value = getattr(stats, key, None)
                # Previously, values that were not None were looked up but
                # never appended, which left them out of the written row.
                row.append("-" if value is None else value)
            print("\t".join(map(str, row)), file=outputf)