From c854a2e28f3209ecda54a1f2a4d2e49120229f71 Mon Sep 17 00:00:00 2001 From: Clemens Brunner Date: Wed, 20 Sep 2023 15:38:06 +0200 Subject: [PATCH] Add example for staging a custom dataset (#190) --- CHANGELOG.md | 1 + README.md | 2 - docs/examples.md | 191 ++++++++++++++++++++++++++++++++ examples/README.md | 11 +- examples/feature_extraction.py | 24 ---- examples/heartbeat_detection.py | 78 ------------- examples/try_ws_gru_mesa.py | 24 ---- mkdocs.yml | 2 + sleepecg/test/test_examples.py | 33 ------ 9 files changed, 196 insertions(+), 170 deletions(-) create mode 100644 docs/examples.md delete mode 100644 examples/feature_extraction.py delete mode 100644 examples/heartbeat_detection.py delete mode 100644 examples/try_ws_gru_mesa.py delete mode 100644 sleepecg/test/test_examples.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cd1d111d..52bae647 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## [0.5.5] - 2023-06-01 ### Changed - Use absolute imports internally ([#170](https://github.com/cbrnr/sleepecg/pull/170) by [Clemens Brunner](https://github.com/cbrnr)) +- Move usage examples to documentation website ([#190](https://github.com/cbrnr/sleepecg/pull/190) by [Clemens Brunner](https://github.com/cbrnr)) ## [0.5.4] - 2023-04-13 ### Changed diff --git a/README.md b/README.md index 0d723807..27b4c010 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,6 @@ fs = 360 # sampling frequency beats = detect_heartbeats(ecg, fs) # indices of detected heartbeats ``` -More examples are available at https://github.com/cbrnr/sleepecg/tree/main/examples. 
- ### Dependencies diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 00000000..4ca1c77e --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,191 @@ +# Examples + +## Heartbeat detection + +```python +import matplotlib.pyplot as plt +import numpy as np + +from sleepecg import compare_heartbeats, detect_heartbeats, read_mitdb + +# download and read data +record = list(read_mitdb(records_pattern="234"))[1] + +# detect heartbeats +detection = detect_heartbeats(record.ecg, record.fs) + +# evaluation and visualization +TP, FP, FN = compare_heartbeats(detection, record.annotation, int(record.fs / 10)) + +t = np.arange(len(record.ecg)) / record.fs + +fig, ax = plt.subplots(3, sharex=True, figsize=(10, 8)) + +ax[0].plot(t, record.ecg, color="k", zorder=1, label="ECG") +ax[0].scatter( + record.annotation / record.fs, + record.ecg[record.annotation], + marker="o", + color="g", + s=50, + zorder=2, + label="annotation", +) +ax[0].set_ylabel("raw signal in mV") + +ax[1].eventplot( + detection / record.fs, + linelength=0.5, + linewidth=0.5, + color="k", + zorder=1, + label="detection", +) +ax[1].scatter( + FN / record.fs, + np.ones_like(FN), + marker="x", + color="r", + s=70, + zorder=2, + label="FN", +) +ax[1].scatter( + FP / record.fs, + np.ones_like(FP), + marker="+", + color="orange", + s=70, + zorder=2, + label="FP", +) +ax[1].set_yticks([]) +ax[1].set_ylabel("heartbeat events") + +ax[2].plot( + detection[1:] / record.fs, + 60 / (np.diff(detection) / record.fs), + label="heartrate in bpm", +) +ax[2].set_ylabel("beats per minute") +ax[2].set_xlabel("time in seconds") + +for ax_ in ax.flat: + ax_.legend(loc="upper right") + ax_.grid(axis="x") + +fig.suptitle( + f"Record ID: {record.id}, lead: {record.lead}\n" + + f"Recall: {len(TP) / (len(TP) + len(FN)):.2%}, " + + f"Precision: {len(TP) / (len(TP) + len(FP)):.2%}", +) + +plt.show() +``` + + +## Sleep staging custom data + +This example requires the `mne` and `tensorflow` packages.
In addition, it uses an example file [`sleep.edf`](https://osf.io/download/mx7av/), which contains ECG data for a whole night. Download and save this file in your working directory before running this example. + +```python +from datetime import datetime, timezone + +from mne.io import read_raw_edf + +import sleepecg + +# load dataset +raw = read_raw_edf("sleep.edf", include="ECG") +raw.set_channel_types({"ECG": "ecg"}) +fs = raw.info["sfreq"] + +# crop dataset (we only want data for the sleep duration) +start = datetime(2023, 3, 1, 23, 0, 0, tzinfo=timezone.utc) +stop = datetime(2023, 3, 2, 6, 0, 0, tzinfo=timezone.utc) +raw.crop((start - raw.info["meas_date"]).total_seconds(), (stop - raw.info["meas_date"]).total_seconds()) + +# get ECG time series as 1D NumPy array +ecg = raw.get_data().squeeze() + +# detect heartbeats +beats = sleepecg.detect_heartbeats(ecg, fs) +sleepecg.plot_ecg(ecg, fs, beats=beats) + +# load SleepECG classifier (requires tensorflow) +clf = sleepecg.load_classifier("wrn-gru-mesa", "SleepECG") + +# predict sleep stages +record = sleepecg.SleepRecord( + sleep_stage_duration=30, + recording_start_time=start, + heartbeat_times=beats / fs, +) + +stages = sleepecg.stage(clf, record, return_mode="prob") + +sleepecg.plot_hypnogram( + record, + stages, + stages_mode=clf.stages_mode, + merge_annotations=True, +) +``` + + +## Feature extraction + +```python +import numpy as np + +from sleepecg import SleepRecord, extract_features + +# generate dummy data +recording_hours = 8 +heartbeat_times = np.cumsum(np.random.uniform(0.5, 1.5, recording_hours * 3600)) +sleep_stages = np.random.randint(1, 6, int(max(heartbeat_times)) // 30) +sleep_stage_duration = 30 + +record = SleepRecord( + sleep_stages=sleep_stages, + sleep_stage_duration=sleep_stage_duration, + heartbeat_times=heartbeat_times, +) + +features, stages, feature_ids = extract_features( + [record], + lookback=240, + lookforward=60, + feature_selection=["hrv-time", "LF_norm", "HF_norm", "LF_HF_ratio"], +) +X =
features[0] +``` + + +## Using a built-in classifier + +```python +import matplotlib.pyplot as plt + +from sleepecg import load_classifier, plot_hypnogram, read_slpdb, stage + +# the model was built with tensorflow 2.7, running on higher versions might create warnings +# but should not influence the results +clf = load_classifier("ws-gru-mesa", "SleepECG") + +# load record +# `ws-gru-mesa` performs poorly for most SLPDB records, but it works well for slp03 +rec = next(read_slpdb("slp03")) + +# predict stages and plot hypnogram +stages_pred = stage(clf, rec, return_mode="prob") + +plot_hypnogram( + rec, + stages_pred, + stages_mode=clf.stages_mode, + merge_annotations=True, +) + +plt.show() +``` diff --git a/examples/README.md b/examples/README.md index 8e38eb23..1ddfacef 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,12 +1,5 @@ # Examples -To run the provided examples, download or clone the [GitHub Repository](https://github.com/cbrnr/sleepecg) and execute the scripts in this directory or its subdirectories. -- Heartbeat detection demo: - ``` - sleepecg/examples> python heartbeat_detection.py - ``` +This folder contains scripts to generate the built-in classifiers as well as to reproduce the benchmark (more info [here](https://github.com/cbrnr/sleepecg/tree/main/examples/benchmark)). -- Benchmark heartbeat detector runtime (more info [here](https://github.com/cbrnr/sleepecg/tree/main/examples/benchmark)): - ``` - sleepecg/examples/benchmark> python benchmark_detectors.py runtime - ``` +Usage examples can be found in the documentation. 
diff --git a/examples/feature_extraction.py b/examples/feature_extraction.py deleted file mode 100644 index da1bf364..00000000 --- a/examples/feature_extraction.py +++ /dev/null @@ -1,24 +0,0 @@ -# %% -import numpy as np - -from sleepecg import SleepRecord, extract_features - -# Generate dummy data while we don't have reader functions for sleep data -recording_hours = 8 -heartbeat_times = np.cumsum(np.random.uniform(0.5, 1.5, recording_hours * 3600)) -sleep_stages = np.random.randint(1, 6, int(max(heartbeat_times)) // 30) -sleep_stage_duration = 30 - -rec = SleepRecord( - sleep_stages=sleep_stages, - sleep_stage_duration=sleep_stage_duration, - heartbeat_times=heartbeat_times, -) - -features, stages, feature_ids = extract_features( - [rec], - lookback=240, - lookforward=60, - feature_selection=["hrv-time", "LF_norm", "HF_norm", "LF_HF_ratio"], -) -X = features[0] diff --git a/examples/heartbeat_detection.py b/examples/heartbeat_detection.py deleted file mode 100644 index 1ed46664..00000000 --- a/examples/heartbeat_detection.py +++ /dev/null @@ -1,78 +0,0 @@ -# %% Imports -import matplotlib.pyplot as plt -import numpy as np - -from sleepecg import compare_heartbeats, detect_heartbeats, read_mitdb - -# %% Download and read data, run detector -record = list(read_mitdb(records_pattern="234"))[1] -detection = detect_heartbeats(record.ecg, record.fs) - - -# %% Evaluation and visualization -TP, FP, FN = compare_heartbeats(detection, record.annotation, int(record.fs / 10)) - -t = np.arange(len(record.ecg)) / record.fs - -fig, ax = plt.subplots(3, sharex=True, figsize=(10, 8)) - -ax[0].plot(t, record.ecg, color="k", zorder=1, label="ECG") -ax[0].scatter( - record.annotation / record.fs, - record.ecg[record.annotation], - marker="o", - color="g", - s=50, - zorder=2, - label="annotation", -) -ax[0].set_ylabel("raw signal in mV") - -ax[1].eventplot( - detection / record.fs, - linelength=0.5, - linewidth=0.5, - color="k", - zorder=1, - label="detection", -) -ax[1].scatter( - FN 
/ record.fs, - np.ones_like(FN), - marker="x", - color="r", - s=70, - zorder=2, - label="FN", -) -ax[1].scatter( - FP / record.fs, - np.ones_like(FP), - marker="+", - color="orange", - s=70, - zorder=2, - label="FP", -) -ax[1].set_yticks([]) -ax[1].set_ylabel("heartbeat events") - -ax[2].plot( - detection[1:] / record.fs, - 60 / (np.diff(detection) / record.fs), - label="heartrate in bpm", -) -ax[2].set_ylabel("beats per minute") -ax[2].set_xlabel("time in seconds") - -for ax_ in ax.flat: - ax_.legend(loc="upper right") - ax_.grid(axis="x") - -fig.suptitle( - f"Record ID: {record.id}, lead: {record.lead}\n" - + f"Recall: {len(TP) / (len(TP) + len(FN)):.2%}, " - + f"Precision: {len(TP) / (len(TP) + len(FP)):.2%}", -) - -plt.show() diff --git a/examples/try_ws_gru_mesa.py b/examples/try_ws_gru_mesa.py deleted file mode 100644 index 61a4262c..00000000 --- a/examples/try_ws_gru_mesa.py +++ /dev/null @@ -1,24 +0,0 @@ -# %% -import matplotlib.pyplot as plt - -from sleepecg import load_classifier, plot_hypnogram, read_slpdb, stage - -# The model was built using tensorflow 2.7, running on higher versions might create warnings -# but should not influence the results. -clf = load_classifier("ws-gru-mesa", "SleepECG") - -# %% Load record -# `ws-gru-mesa` performs poorly for most SLPDB records. It does however work well for slp03. 
-rec = next(read_slpdb("slp03")) - -# %% Predict stages and plot hypnogram -stages_pred = stage(clf, rec, return_mode="prob") - -plot_hypnogram( - rec, - stages_pred, - stages_mode=clf.stages_mode, - merge_annotations=True, -) - -plt.show() diff --git a/mkdocs.yml b/mkdocs.yml index faad3043..dd6792ef 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,6 +11,7 @@ nav: - Classification: classification.md - Plotting: plot.md - Configuration: configuration.md + - Examples: examples.md - API: - Datasets: api/datasets.md - Heartbeat detection: api/heartbeat_detection.md @@ -41,6 +42,7 @@ theme: name: Switch to dark theme features: - toc.integrate + - content.code.copy plugins: - search - mkdocstrings: diff --git a/sleepecg/test/test_examples.py b/sleepecg/test/test_examples.py deleted file mode 100644 index 37e68499..00000000 --- a/sleepecg/test/test_examples.py +++ /dev/null @@ -1,33 +0,0 @@ -# © SleepECG developers -# -# License: BSD (3-clause) - -"""Tests to make sure examples don't crash.""" - -import fnmatch -import runpy -from pathlib import Path - -import matplotlib.pyplot as plt -import pytest - -EXCLUDE = [ - "*/benchmark/*", - "*/classifiers/*", - "*/try_ws_gru_mesa.py", -] - -examples_dir = (Path(__file__).parent / "../../examples").resolve() - -example_files = {str(f) for f in examples_dir.rglob("*.py")} -for pattern in EXCLUDE: - example_files -= set(fnmatch.filter(example_files, pattern)) - - -@pytest.mark.parametrize("script", example_files) -def test_example(script, monkeypatch): - """Run all examples to make sure they don't crash.""" - # Keep matplotlib from showing figures - monkeypatch.setattr(plt, "show", lambda: None) - - runpy.run_path(script)