diff --git a/README.md b/README.md index 3a120144..f47a837d 100644 --- a/README.md +++ b/README.md @@ -154,11 +154,19 @@ and was adapted to python by [Giuseppe Ferraro](mailto:giuseppe.ferraro@isae-sup 2022 Sep 27;22(19):7314. https://doi.org/10.3390/s22197314. ``` -### 6. Real-Time Phase Estimation +### 6. Phase Estimation -This code is based on the Matlab implementation from [Michael Rosenblum](http://www.stat.physik.uni-potsdam.de/~mros), and its corresponding paper [1]. +The oscillator code is based on the Matlab implementation from [Michael +Rosenblum](http://www.stat.physik.uni-potsdam.de/~mros), and its corresponding +paper [1]. The Endpoint Corrected Hilbert Transform (ECHT) method was adapted +from [2]. ```sql -[1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation of phase and amplitude with application to neural data. Sci Rep 11, 18037 (2021). https://doi.org/10.1038/s41598-021-97560-5 +[1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation of phase + and amplitude with application to neural data. Sci Rep 11, 18037 (2021). + https://doi.org/10.1038/s41598-021-97560-5 +[2] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., Latorre, A., + ... & Grossman, N. (2021). Non-invasive suppression of essential tremor via + phase-locked disruption of its temporal coherence. Nature communications, 12(1), 363. ``` diff --git a/doc/conf.py b/doc/conf.py index 37568fef..9b4c3d02 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -26,7 +26,7 @@ # -- Project information ----------------------------------------------------- project = "MEEGkit" -copyright = "2023, Nicolas Barascud" +copyright = "2024, Nicolas Barascud" author = "Nicolas Barascud" release = meegkit.__version__ version = meegkit.__version__ @@ -63,7 +63,7 @@ "show-inheritance": True, "exclude-members": "__weakref__" } -numpydoc_show_class_members = True +numpydoc_show_class_members = False # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: @@ -129,3 +129,5 @@ "ignore_pattern": "config.py", "run_stale_examples": False, } + +suppress_warnings = ["config.cache"] \ No newline at end of file diff --git a/doc/index.rst b/doc/index.rst index fd2fe078..34717b6e 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -46,7 +46,6 @@ Here is a list of the methods and techniques available in ``meegkit``: ~meegkit.tspca ~meegkit.utils - Examples gallery ---------------- diff --git a/doc/modules/meegkit.phase.rst b/doc/modules/meegkit.phase.rst new file mode 100644 index 00000000..4a364f83 --- /dev/null +++ b/doc/modules/meegkit.phase.rst @@ -0,0 +1,23 @@ +meegkit.phase +============= + +.. automodule:: meegkit.phase + + .. rubric:: Classes + + .. autosummary:: + + NonResOscillator + ResOscillator + Device + ECHT + + .. rubric:: Functions + + .. autosummary:: + + locking_based_phase + rk + init_coefs + one_step_oscillator + one_step_integrator diff --git a/doc/modules/meegkit.utils.rst b/doc/modules/meegkit.utils.rst index 9797de75..661ee49c 100644 --- a/doc/modules/meegkit.utils.rst +++ b/doc/modules/meegkit.utils.rst @@ -6,11 +6,14 @@ meegkit.utils .. autosummary:: auditory + buffer + coherence covariances denoise matrix sig stats + trca | @@ -23,6 +26,28 @@ Auditory .. autosummary:: +| + +---- + +Buffer +------ +.. automodule:: meegkit.utils.buffer + + .. autosummary:: + + +| + +---- + +Coherence +--------- +.. automodule:: meegkit.utils.coherence + + .. autosummary:: + + | ---- diff --git a/doc/sg_execution_times.rst b/doc/sg_execution_times.rst new file mode 100644 index 00000000..b6f64701 --- /dev/null +++ b/doc/sg_execution_times.rst @@ -0,0 +1,73 @@ + +:orphan: + +.. _sphx_glr_sg_execution_times: + + +Computation times +================= +**00:29.509** total execution time for 13 files **from all galleries**: + +.. container:: + + .. raw:: html + + + + + + + + .. list-table:: + :header-rows: 1 + :class: table table-striped sg-datatable + + * - Example + - Time + - Mem (MB) + * - :ref:`sphx_glr_auto_examples_example_trca.py` (``../examples/example_trca.py``) + - 00:11.533 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_dss_line.py` (``../examples/example_dss_line.py``) + - 00:07.115 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_phase_estimation.py` (``../examples/example_phase_estimation.py``) + - 00:07.064 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_mcca.py` (``../examples/example_mcca.py``) + - 00:01.182 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_mcca_2.py` (``../examples/example_mcca_2.py``) + - 00:00.535 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_asr.py` (``../examples/example_asr.py``) + - 00:00.512 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_ress.py` (``../examples/example_ress.py``) + - 00:00.459 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_detrend.py` (``../examples/example_detrend.py``) + - 00:00.259 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_star_dss.py` (``../examples/example_star_dss.py``) + - 00:00.246 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_star.py` (``../examples/example_star.py``) + - 00:00.197 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_dss.py` (``../examples/example_dss.py``) + - 00:00.146 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_echt.py` (``../examples/example_echt.py``) + - 00:00.132 + - 0.0 + * - :ref:`sphx_glr_auto_examples_example_dering.py` (``../examples/example_dering.py``) + - 00:00.129 + - 0.0 diff --git a/examples/example_echt.ipynb b/examples/example_echt.ipynb new file mode 100644 index 00000000..156e0807 --- /dev/null +++ b/examples/example_echt.ipynb @@ -0,0 +1,97 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Endpoint-corrected Hilbert transform (ECHT) phase estimation\n\nThis example shows how to causally estimate the phase of a signal using\n\nUses `meegkit.phase.ECHT()`.\n\n## References\n.. [1] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., Latorre,\n A., ... & Grossman, N. (2021). Non-invasive suppression of essential tremor\n via phase-locked disruption of its temporal coherence. Nature\n communications, 12(1), 363.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\nimport numpy as np\nfrom scipy.signal import hilbert\n\nfrom meegkit.phase import ECHT\n\nrng = np.random.default_rng(38872)\n\nplt.rcParams[\"axes.grid\"] = True\nplt.rcParams[\"grid.linestyle\"] = \":\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build data\nFirst, we generate a multi-component signal with amplitude and phase\nmodulations, as described in the paper [1]_.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "f0 = 2\n\nN = 500\nsfreq = 100\ntime = np.linspace(0, N / sfreq, N)\nX = np.cos(2 * np.pi * f0 * time - np.pi / 4)\nphase_true = np.angle(hilbert(X))\nX += rng.normal(0, 0.5, N) # Add noise" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compute phase and amplitude\nWe compute the Hilbert phase, as well as the phase obtained with the ECHT\nfilter.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "phase_hilbert = np.angle(hilbert(X)) # Hilbert phase\n\n# Compute ECHT-filtered signal\nfilt_BW = f0 / 2\nl_freq = f0 - filt_BW / 2\nh_freq = f0 + filt_BW / 2\necht = ECHT(l_freq, h_freq, sfreq)\n\nXf = echt.fit_transform(X)\nphase_echt = np.angle(Xf)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize signal\nPlot the results\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 1, figsize=(8, 6))\nax[0].plot(time, X)\nax[0].set_xlabel(\"Time (s)\")\nax[0].set_title(\"Test signal\")\nax[0].set_ylabel(\"Amplitude\")\nax[1].psd(X, Fs=sfreq, NFFT=2048*4, noverlap=sfreq)\nax[1].set_ylabel(\"PSD (dB/Hz)\")\nax[1].set_title(\"Test signal's Fourier spectrum\")\nax[2].plot(time, phase_true, label=\"True phase\", ls=\":\")\nax[2].plot(time, phase_echt, label=\"ECHT phase\", lw=.5, alpha=.8)\nax[2].plot(time, phase_hilbert, label=\"Hilbert phase\", lw=.5, alpha=.8)\nax[2].set_title(\"Phase\")\nax[2].set_ylabel(\"Amplitude\")\nax[2].set_xlabel(\"Time (s)\")\nax[2].legend(loc=\"upper right\", fontsize=\"small\")\nplt.tight_layout()\nplt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/example_echt.py b/examples/example_echt.py new file mode 100644 index 00000000..3ed111cf --- /dev/null +++ b/examples/example_echt.py @@ -0,0 +1,81 @@ +""" +Endpoint-corrected Hilbert transform (ECHT) phase estimation +============================================================ + +This example shows how to causally estimate the phase of a signal using the +Endpoint-corrected Hilbert transform (ECHT) [1]_. + +Uses `meegkit.phase.ECHT()`. + +References +---------- +.. [1] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., Latorre, + A., ... & Grossman, N. (2021). Non-invasive suppression of essential tremor + via phase-locked disruption of its temporal coherence. Nature + communications, 12(1), 363. + +""" +import matplotlib.pyplot as plt +import numpy as np +from scipy.signal import hilbert + +from meegkit.phase import ECHT + +rng = np.random.default_rng(38872) + +plt.rcParams["axes.grid"] = True +plt.rcParams["grid.linestyle"] = ":" + +############################################################################### +# Build data +# ----------------------------------------------------------------------------- +# First, we generate a multi-component signal with amplitude and phase +# modulations, as described in the paper [1]_. +f0 = 2 + +N = 500 +sfreq = 100 +time = np.linspace(0, N / sfreq, N) +X = np.cos(2 * np.pi * f0 * time - np.pi / 4) +phase_true = np.angle(hilbert(X)) +X += rng.normal(0, 0.5, N) # Add noise + +############################################################################### +# Compute phase and amplitude +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# We compute the Hilbert phase, as well as the phase obtained with the ECHT +# filter. +phase_hilbert = np.angle(hilbert(X)) # Hilbert phase + +# Compute ECHT-filtered signal +filt_BW = f0 / 2 +l_freq = f0 - filt_BW / 2 +h_freq = f0 + filt_BW / 2 +echt = ECHT(l_freq, h_freq, sfreq) + +Xf = echt.fit_transform(X) +phase_echt = np.angle(Xf) + +############################################################################### +# Visualize signal +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Here we plot the original signal, its Fourier spectrum, and the phase obtained +# with the Hilbert transform and the ECHT filter. The ECHT filter provides a +# much smoother phase estimate than the Hilbert transform +fig, ax = plt.subplots(3, 1, figsize=(8, 6)) +ax[0].plot(time, X) +ax[0].set_xlabel("Time (s)") +ax[0].set_title("Test signal") +ax[0].set_ylabel("Amplitude") +ax[1].psd(X, Fs=sfreq, NFFT=2048*4, noverlap=sfreq) +ax[1].set_ylabel("PSD (dB/Hz)") +ax[1].set_title("Test signal's Fourier spectrum") +ax[2].plot(time, phase_true, label="True phase", ls=":") +ax[2].plot(time, phase_echt, label="ECHT phase", lw=.5, alpha=.8) +ax[2].plot(time, phase_hilbert, label="Hilbert phase", lw=.5, alpha=.8) +ax[2].set_title("Phase") +ax[2].set_ylabel("Amplitude") +ax[2].set_xlabel("Time (s)") +ax[2].legend(loc="upper right", fontsize="small") +plt.tight_layout() +plt.show() diff --git a/examples/example_phase_estimation.ipynb b/examples/example_phase_estimation.ipynb index 43d8cc99..605a1912 100644 --- a/examples/example_phase_estimation.ipynb +++ b/examples/example_phase_estimation.ipynb @@ -40,7 +40,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Vizualize signal\nPlot the test signal's Fourier spectrum\n\n" + "### Visualize signal\nPlot the test signal's Fourier spectrum\n\n" ] }, { @@ -83,7 +83,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The first row shows the test signal $s$ and its Hilbert amplitude $a_H$ ; one\ncan see that ah does not represent a good envelope for $s$. On the contrary,\nthe Hilbert-based phase estimation yields good results, and therefore we take\nit for the ground truth.\nRows 2-4 show the difference between the Hilbert phase and causally\nestimated phases ($\\phi_L$, $\\phi_N$, $\\phi_R$) are obtained by means of the\nlocking-based technique, non-resonant and resonant oscillator, respectively).\nThese panels demonstrate that the output of the developed causal algorithms\nis very close to the HT-phase. Notice that we show $\\phi_H - \\phi_N$\nmodulo $2\\pi$, since the phase difference is not bounded.\n\n" + "The first row shows the test signal $s$ and its Hilbert amplitude\n$a_H$ ; one can see that ah does not represent a good envelope for\n$s$. On the contrary, the Hilbert-based phase estimation yields good\nresults, and therefore we take it for the ground truth. Rows 2-4 show the\ndifference between the Hilbert phase and causally estimated phases\n($\\phi_L$, $\\phi_N$, $\\phi_R$) are obtained by means of the\nlocking-based technique, non-resonant and resonant oscillator, respectively).\nThese panels demonstrate that the output of the developed causal algorithms\nis very close to the HT-phase. Notice that we show $\\phi_H - \\phi_N$\nmodulo :math:`2\\pi, since the phase difference is not bounded.\n\n" ] }, { @@ -94,7 +94,7 @@ }, "outputs": [], "source": [ - "f, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(12, 8))\nax[0, 0].plot(time, s, time, ht_phase, lw=.75)\nax[0, 0].set_ylabel(r\"$s,\\phi_H$\")\nax[0, 0].set_title(\"Signal and its Hilbert phase\")\n\nax[1, 0].plot(time, lb_phi_dif, lw=.75)\nax[1, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[1, 0].set_ylabel(r\"$\\phi_H - \\phi_L$\")\nax[1, 0].set_ylim([-np.pi, np.pi])\nax[1, 0].set_title(\"Phase locking approach\")\n\nax[2, 0].plot(time, nr_phi_dif, lw=.75)\nax[2, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[2, 0].set_ylabel(r\"$\\phi_H - \\phi_N$\")\nax[2, 0].set_ylim([-np.pi, np.pi])\nax[2, 0].set_title(\"Nonresonant oscillator\")\n\nax[3, 0].plot(time, r_phi_dif, lw=.75)\nax[3, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[3, 0].set_ylim([-np.pi, np.pi])\nax[3, 0].set_ylabel(\"$\\phi_H - \\phi_R$\")\nax[3, 0].set_xlabel(\"Time\")\nax[3, 0].set_title(\"Resonant oscillator\")\n\nax[0, 1].plot(time, s, time, ht_ampl, lw=.75)\nax[0, 1].set_ylabel(r\"$s,a_H$\")\nax[0, 1].set_title(\"Signal and its Hilbert amplitude\")\n\nax[1, 1].axis(\"off\")\n\nax[2, 1].plot(time, s, time, nr_ampl, lw=.75)\nax[2, 1].set_ylabel(r\"$s,a_N$\")\nax[2, 1].set_title(\"Amplitudes\")\nax[2, 1].set_title(\"Nonresonant oscillator\")\n\nax[3, 1].plot(time, s, time, r_ampl, lw=.75)\nax[3, 1].set_xlabel(\"Time\")\nax[3, 1].set_ylabel(r\"$s,a_R$\")\nax[3, 1].set_title(\"Resonant oscillator\")\nplt.suptitle(\"Amplitude (right) and phase (left) estimation algorithms\")\nplt.tight_layout()\nplt.show()" + "f, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(12, 8))\nax[0, 0].plot(time, s, time, ht_phase, lw=.75)\nax[0, 0].set_ylabel(r\"$s,\\phi_H$\")\nax[0, 0].set_title(\"Signal and its Hilbert phase\")\n\nax[1, 0].plot(time, lb_phi_dif, lw=.75)\nax[1, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[1, 0].set_ylabel(r\"$\\phi_H - \\phi_L$\")\nax[1, 0].set_ylim([-np.pi, np.pi])\nax[1, 0].set_title(\"Phase locking approach\")\n\nax[2, 0].plot(time, nr_phi_dif, lw=.75)\nax[2, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[2, 0].set_ylabel(r\"$\\phi_H - \\phi_N$\")\nax[2, 0].set_ylim([-np.pi, np.pi])\nax[2, 0].set_title(\"Nonresonant oscillator\")\n\nax[3, 0].plot(time, r_phi_dif, lw=.75)\nax[3, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[3, 0].set_ylim([-np.pi, np.pi])\nax[3, 0].set_ylabel(r\"$\\phi_H - \\phi_R$\")\nax[3, 0].set_xlabel(\"Time\")\nax[3, 0].set_title(\"Resonant oscillator\")\n\nax[0, 1].plot(time, s, time, ht_ampl, lw=.75)\nax[0, 1].set_ylabel(r\"$s,a_H$\")\nax[0, 1].set_title(\"Signal and its Hilbert amplitude\")\n\nax[1, 1].axis(\"off\")\n\nax[2, 1].plot(time, s, time, nr_ampl, lw=.75)\nax[2, 1].set_ylabel(r\"$s,a_N$\")\nax[2, 1].set_title(\"Amplitudes\")\nax[2, 1].set_title(\"Nonresonant oscillator\")\n\nax[3, 1].plot(time, s, time, r_ampl, lw=.75)\nax[3, 1].set_xlabel(\"Time\")\nax[3, 1].set_ylabel(r\"$s,a_R$\")\nax[3, 1].set_title(\"Resonant oscillator\")\nplt.suptitle(\"Amplitude (right) and phase (left) estimation algorithms\")\nplt.tight_layout()\nplt.show()" ] } ], @@ -114,7 +114,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/examples/example_phase_estimation.py b/examples/example_phase_estimation.py index 8e5e19b7..b7ac9230 100644 --- a/examples/example_phase_estimation.py +++ b/examples/example_phase_estimation.py @@ -43,7 +43,7 @@ time = np.arange(npt) * dt ############################################################################### -# Vizualize signal +# Visualize signal # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Plot the test signal's Fourier spectrum f, ax = plt.subplots(2, 1) @@ -83,16 +83,16 @@ # Here we reproduce figure 1 from the original paper [1]_. ############################################################################### -# The first row shows the test signal $s$ and its Hilbert amplitude $a_H$ ; one -# can see that ah does not represent a good envelope for $s$. On the contrary, -# the Hilbert-based phase estimation yields good results, and therefore we take -# it for the ground truth. -# Rows 2-4 show the difference between the Hilbert phase and causally -# estimated phases ($\phi_L$, $\phi_N$, $\phi_R$) are obtained by means of the +# The first row shows the test signal :math:`s` and its Hilbert amplitude +# :math:`a_H` ; one can see that ah does not represent a good envelope for +# :math:`s`. On the contrary, the Hilbert-based phase estimation yields good +# results, and therefore we take it for the ground truth. Rows 2-4 show the +# difference between the Hilbert phase and causally estimated phases +# (:math:`\phi_L`, :math:`\phi_N`, :math:`\phi_R`) are obtained by means of the # locking-based technique, non-resonant and resonant oscillator, respectively). # These panels demonstrate that the output of the developed causal algorithms -# is very close to the HT-phase. Notice that we show $\phi_H - \phi_N$ -# modulo $2\pi$, since the phase difference is not bounded. +# is very close to the HT-phase. Notice that we show :math:`\phi_H - \phi_N` +# modulo :math:`2\pi`, since the phase difference is not bounded. f, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(12, 8)) ax[0, 0].plot(time, s, time, ht_phase, lw=.75) ax[0, 0].set_ylabel(r"$s,\phi_H$") @@ -113,7 +113,7 @@ ax[3, 0].plot(time, r_phi_dif, lw=.75) ax[3, 0].axhline(0, color="k", ls=":", zorder=-1) ax[3, 0].set_ylim([-np.pi, np.pi]) -ax[3, 0].set_ylabel("$\phi_H - \phi_R$") +ax[3, 0].set_ylabel(r"$\phi_H - \phi_R$") ax[3, 0].set_xlabel("Time") ax[3, 0].set_title("Resonant oscillator") diff --git a/meegkit/asr.py b/meegkit/asr.py index 9d6ad27d..29ce33ca 100755 --- a/meegkit/asr.py +++ b/meegkit/asr.py @@ -551,7 +551,7 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method="riemann", detrend : bool If True, detrend filtered data (default=False). method : {'euclid', 'riemann'} - Metric to compute the covariance matric average. + Metric to compute the covariance matrix average. Returns ------- diff --git a/meegkit/dss.py b/meegkit/dss.py index 8105570b..65ce19ee 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -48,14 +48,11 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12): Power per component (averaged). """ - n_trials = theshapeof(X)[-1] - # if demean: # remove weighted mean # X = demean(X, weights) # weighted mean over trials (--> bias function for DSS) xx, ww = mean_over_trials(X, weights) - ww /= n_trials # covariance of raw and biased X c0, nc0 = tscov(X, None, weights) diff --git a/meegkit/phase.py b/meegkit/phase.py index 8e3cc87f..39e19b0b 100644 --- a/meegkit/phase.py +++ b/meegkit/phase.py @@ -9,11 +9,13 @@ be slow for large input arrays (n_channels >> 10), since an individual oscillator is instantiated for each channel. -.. [1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation - of phase and amplitude with application to neural data. Sci Rep 11, 18037 +.. [1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation of + phase and amplitude with application to neural data. Sci Rep 11, 18037 (2021). https://doi.org/10.1038/s41598-021-97560-5 """ import numpy as np +from scipy.fftpack import fft, fftshift, ifft, ifftshift, next_fast_len +from scipy.signal import butter, freqz from meegkit.utils.buffer import Buffer @@ -104,13 +106,14 @@ def step(self, sprev, s, snew): class NonResOscillator: """Real-time measurement of phase and amplitude using non-resonant oscillator. - This estimator relies on the resonance effect. The measuring “device” consists - of two linear damped oscillators. The oscillators' frequency is much larger - than the frequency of the signal, i.e., the system is far from resonance. - We choose the damping parameters to ensure that (i) the phase of the first - linear oscillator equals that of the input and that (ii) amplitude of the - second one and the input relate by a known constant multiplicator. The - technique yields both phase and amplitude of the input signal. + This estimator relies on the resonance effect. The measuring “device” + consists of two linear damped oscillators. The oscillators' frequency is + much larger than the frequency of the signal, i.e., the system is far from + resonance. We choose the damping parameters to ensure that (i) the phase of + the first linear oscillator equals that of the input and that (ii) + amplitude of the second one and the input relate by a known constant + multiplicator. The technique yields both phase and amplitude of the input + signal. This estimator includes an automated frequency-tuning algorithm to adjust to the a priori unknown signal frequency. @@ -127,24 +130,23 @@ class NonResOscillator: References ---------- .. [1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation - of phase and amplitude with application to neural data. Sci Rep 11, 18037 - (2021). https://doi.org/10.1038/s41598-021-97560-5 + of phase and amplitude with application to neural data. Sci Rep 11, + 18037 (2021). https://doi.org/10.1038/s41598-021-97560-5 """ - def __init__(self, fs=250, nu=1.1): + def __init__(self, fs=250, nu=1.1, alpha_a=6.0, alpha_p=0.2, update_factor=5): # Parameters of the measurement "devices" self.dt = 1 / fs # Sampling interval self.nu = nu # Rough estimate of the tremor frequency self.om0 = 5 * nu # Oscillator frequency (estimation) - self.alpha_a = 6.0 # Damping parameter for the "amplitude device" - self.gamma_a = self.alpha_a / 2 - self.alpha_p = 0.2 # Damping parameter for the "phase device" - self.gamma_p = self.alpha_p / 2 + self.alpha_a = alpha_a # Damping parameter for the "amplitude device" + self.gamma_a = alpha_a / 2 + self.alpha_p = alpha_p # Damping parameter for the "phase device" + self.gamma_p = alpha_p / 2 self.factor = np.sqrt((self.om0 ** 2 - nu ** 2) ** 2 + (self.alpha_a * nu) ** 2) # Update parameters, and precomputed quantities - update_factor = 5 self.memory = round(2 * np.pi / self.om0 / self.dt) self.update_point = 2 * self.memory self.update_step = round(self.memory / update_factor) @@ -196,7 +198,7 @@ def transform(self, X): continue # Skip the first two samples # Amplitude estimation - spp, sp, s = self.buffer.view(3) + spp, sp, s = self.buffer.view(3)[:, 0] for ch in range(self.n_channels): self.adevice[ch].step(spp, sp, s) @@ -261,11 +263,12 @@ class ResOscillator: (2021). https://doi.org/10.1038/s41598-021-97560-5 """ - def __init__(self, fs=1000, nu=4.5, freq_adaptation=True): + def __init__(self, fs=1000, nu=4.5, update_factor=5, freq_adaptation=True, + assume_centered=False): # Parameters of the measurement "device" self.dt = 1 / fs # Sampling interval - self.om0 = 1.1 # Angular frequency + self.om0 = nu # Angular frequency self.alpha = 0.3 * self.om0 self.gamma = self.alpha / 2 @@ -273,7 +276,7 @@ def __init__(self, fs=1000, nu=4.5, freq_adaptation=True): nperiods = 1 # Number of previous periods for frequency correction npt_period = round(2 * np.pi / self.om0 / self.dt) # Number of points per period self.memory = nperiods * npt_period # M points for frequency correction buffer - self.update_factor = 5 # Number of frequency updates per period + self.update_factor = update_factor # Number of frequency updates per period self.update_step = round(npt_period / self.update_factor) self.updatepoint = 2 * self.memory @@ -287,6 +290,7 @@ def __init__(self, fs=1000, nu=4.5, freq_adaptation=True): self.buffer = None self.runav = 0. # Initial guess for the dc-component self.freq_adaptation = freq_adaptation + self.assume_centered = assume_centered def _set_devices(self, n_channels): # Set up the phase and amplitude "devices" @@ -362,7 +366,9 @@ def transform(self, X): self.osc[ch].init_coefs(om0, self.dt, self.gamma) # Update running average - self.runav = np.mean(self.buffer.view(self.memory), axis=0, keepdims=True) + if self.assume_centered is False: + self.runav = np.mean(self.buffer.view(self.memory), axis=0, + keepdims=True) self.updatepoint += self.update_step # Point for the next update return phase, ampl @@ -619,3 +625,167 @@ def one_step_integrator(z, edelmu, mu, dt, spp, sp, s): C0 = z + d z = C0 * edelmu - d + b * dt - 2 * c * mu * dt + c * dt ** 2 return z + + +class ECHT: + """Endpoint Corrected Hilbert Transform (ECHT). + + See [1]_ for details. + + Parameters + ---------- + X : ndarray, shape=(n_samples, n_channels) + Time domain signal. + l_freq : float | None + Low-cutoff frequency of a bandpass causal filter. If None, the data is + only low-passed. + h_freq : float | None + High-cutoff frequency of a bandpass causal filter. If None, the data is + only high-passed. + sfreq : float + Sampling rate of time domain signal. + n_fft : int, optional + Length of analytic signal. If None, it defaults to the length of X. + filt_order : int, optional + Order of the filter. Default is 2. + + Notes + ----- + One common implementation of the Hilbert Transform uses a DFT (aka FFT) + as part of its computation. Inherent to the DFT is the assumption that + a finite sample of a signal is replicated infinitely in time, effectively + abutting the end of a sample with its replicated start. If the start and + end of the sample are not continuous with each other, distortions are + introduced by the DFT. Echt effectively smooths out this 'discontinuity' + by selectively deforming the start of the sample. It is hence most suited + for real-time applications in which the point/s of interest is/are the + most recent one/s (i.e. last) in the sample window. + + We found that a filter bandwidth (BW=h_freq-l_freq) of up to half the + signal's central frequency works well. + + References + ---------- + .. [1] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., + Latorre, A., ... & Grossman, N. (2021). Non-invasive suppression of + essential tremor via phase-locked disruption of its temporal coherence. + Nature communications, 12(1), 363. + + Examples + -------- + >>> f0 = 2 + >>> filt_BW = f0 / 2 + >>> N = 1000 + >>> sfreq = N / (2 * np.pi) + >>> t = np.arange(-2 * np.pi, 0, 1 / sfreq) + >>> X = np.cos(2 * np.pi * f0 * t - np.pi / 4) + >>> l_freq = f0 - filt_BW / 2 + >>> h_freq = f0 + filt_BW / 2 + >>> Xf = echt(X, l_freq, h_freq, sfreq) + """ + + def __init__(self, l_freq, h_freq, sfreq, n_fft=None, filt_order=2): + self.l_freq = l_freq + self.h_freq = h_freq + self.sfreq = sfreq + self.n_fft = n_fft + self.filt_order = filt_order + + # attributes + self.h_ = None + self.coef_ = None + + def fit(self, X, y=None): + """Fit the ECHT transform to the input signal. + + Parameters + ---------- + X : ndarray, shape=(n_samples, n_channels) + The input signal to be transformed. + + """ + if self.n_fft is None: + self.n_fft = next_fast_len(X.shape[0]) + + # Set the amplitude of the negative components of the FFT to zero and + # multiply the amplitudes of the positive components, apart from the + # zero-frequency component (DC) and Nyquist components, by 2. + # + # If the signal has an even number of elements n, the frequency components + # are: + # - n/2-1 negative elements, + # - one DC element, + # - n/2-1 positive elements and + # - one Nyquist element (in order). + # + # If the signal has an odd number of elements n, the frequency components + # are: + # - (n-1)/2 negative elements, + # - one DC element, + # - (n-1)/2 positive elements (in order). + # - no positive element corresponding to the Nyquist frequency. + self.h_ = np.zeros(self.n_fft) + self.h_[0] = 1 + self.h_[1:(self.n_fft // 2) + 1] = 2 + if self.n_fft % 2 == 0: + self.h_[self.n_fft // 2] = 1 + + # The frequency response vector is computed using freqz from the filter's + # impulse response function, computed by butter function, and the user + # defined low-cutoff frequency, high-cutoff frequency and sampling rate. + Wn = [self.l_freq / (self.sfreq / 2), self.h_freq / (self.sfreq / 2)] + b, a = butter(self.filt_order, Wn, btype="band") + T = 1 / self.sfreq * self.n_fft + filt_freq = np.ceil(np.arange(-self.n_fft / 2, self.n_fft / 2) / T) + + self.coef_ = freqz(b, a, filt_freq, fs=self.sfreq)[1] + self.coef_ = self.coef_[:, None] + + return self + + def transform(self, X): + """Apply the ECHT transform to the input signal. + + Parameters + ---------- + X : ndarray, shape=(n_samples, n_channels) + The input signal to be transformed. + + Returns + ------- + Xf : ndarray, shape=(n_samples, n_channels) + The transformed signal (complex-valued). + + """ + if not np.isrealobj(X): + X = np.real(X) + + # if not fitted + if self.h_ is None or self.coef_ is None: + self.fit(X) + + if X.ndim == 1: + X = X[:, np.newaxis] + + # Compute the FFT of the signal. + Xf = fft(X, self.n_fft, axis=0) + + # In contrast to :meth:`scipy.signal.hilbert()`, the code then + # multiplies the array by a frequency response vector of a causal + # bandpass filter. + Xf = Xf * self.h_[:, None] + + # The array is arranged, using fft_shift function, so that the zero-frequency + # component is at the center of the array, before the multiplication, and + # rearranged back so that the zero-frequency component is at the left of the + # array using ifft_shift(). Finally, the IFFT is computed. + Xf = fftshift(Xf) + Xf = Xf * self.coef_ + Xf = ifftshift(Xf) + Xf = ifft(Xf, axis=0) + + return Xf + + def fit_transform(self, X, y=None): + """Fit the ECHT transform to the input signal and transform it.""" + return self.fit(X).transform(X) \ No newline at end of file diff --git a/meegkit/tspca.py b/meegkit/tspca.py index 677aea3b..95f98ec5 100644 --- a/meegkit/tspca.py +++ b/meegkit/tspca.py @@ -34,7 +34,7 @@ def tspca(X, shifts=None, keep=None, threshold=None, weights=None, weights : array Sample weights. demean : bool - If True, Epochs are centered before comuting PCA (default=0). + If True, Epochs are centered before computing PCA (default=0). Returns ------- diff --git a/meegkit/utils/__init__.py b/meegkit/utils/__init__.py index 28b1cc6b..bccc3b75 100644 --- a/meegkit/utils/__init__.py +++ b/meegkit/utils/__init__.py @@ -1,6 +1,16 @@ """Utility functions.""" from .auditory import AuditoryFilterbank, GammatoneFilterbank, erb2hz, erbspace, hz2erb from .base import mldivide, mrdivide +from .coherence import ( + cross_coherence, + plot_polycoherence, + plot_polycoherence_1d, + plot_signal, + polycoherence_0d, + polycoherence_1d, + polycoherence_1d_sum, + polycoherence_2d, +) from .covariances import ( block_covariance, convmtx, diff --git a/meegkit/utils/asr.py b/meegkit/utils/asr.py index dd0fa6b5..cab47ca9 100755 --- a/meegkit/utils/asr.py +++ b/meegkit/utils/asr.py @@ -93,7 +93,7 @@ def fit_eeg_distribution(X, min_clean_fraction=0.25, max_dropout_fraction=0.1, # we can generally skip the tail below the lower quantile lower_min = np.min(quants) # maximum width is the fit interval if all data is clean - max_width = np.diff(quants) + max_width = np.diff(quants)[0] # minimum width of the fit interval, as fraction of data min_width = min_clean_fraction * max_width diff --git a/meegkit/utils/buffer.py b/meegkit/utils/buffer.py index 1cb2733f..d91614e4 100644 --- a/meegkit/utils/buffer.py +++ b/meegkit/utils/buffer.py @@ -22,14 +22,9 @@ class Buffer: The number of channels of the buffer. counter : int The number of samples in the buffer. - head : int - The index of the most recent sample in the buffer. - tail : int - The index of the most recent read sample in the buffer. _data : ndarray, shape (size, n_channels) Data buffer. - """ def __init__(self, size, n_channels=1): @@ -87,6 +82,8 @@ def push(self, X): self._head += n_samples self.counter += n_samples + return self + def get_new_samples(self, n_samples=None): """Consume n_samples.""" if n_samples is None: diff --git a/meegkit/utils/coherence.py b/meegkit/utils/coherence.py new file mode 100644 index 00000000..0d13ab98 --- /dev/null +++ b/meegkit/utils/coherence.py @@ -0,0 +1,424 @@ +"""Signal coherence tools. + +Compute 2D, 1D, and 0D bicoherence, polycoherence, bispectrum, and polyspectrum. + +Bicoherence is a measure of the degree of phase coupling between different +frequency components in a signal. It's essentially a normalized form of the +bispectrum, which itself is a Fourier transform of the third-order cumulant of +a time series: + +- 2D bicoherence is the most common form, where one looks at a + two-dimensional representation of the phase coupling between different + frequencies. It's a function of two frequency variables. +- 1D bicoherence would imply a slice or a specific condition in the 2D + bicoherence, reducing it to a function of a single frequency variable. It + simplifies the analysis by looking at the relationship between a particular + frequency and its harmonics or other relationships +- 0D bicoherence would imply a single value representing some average or + overall measure of phase coupling in the signal. It's a highly condensed + summary, which might represent the average bicoherence over all frequency + pairs, for instance. + + +""" +from collections.abc import Iterable + +import matplotlib.pyplot as plt +import numpy as np +from numpy.fft import rfft, rfftfreq +from scipy.fftpack import next_fast_len +from scipy.signal import spectrogram + + +def cross_coherence(x1, x2, sfreq, norm=2, **kwargs): + """Compute the bispectral cross-coherence between two signals of same length. + + Code adapted from [2]_. + + Parameters + ---------- + x1: array-like, shape=([n_channels, ]n_samples) + First signal. + x2: array-like, shape=([n_channels, ]n_samples) + Second signal. + sfreq: float + Sampling sfreq. + norm: int | None + Norm (default=2). If None, return bispectrum. + + Returns + ------- + f1: array-like + Frequency axis. + B: array-like + Bicoherence between s1 and s2. + + References + ---------- + .. [1] http://wiki.fusenet.eu/wiki/Bicoherence + .. [2] https://stackoverflow.com/a/36725871 + + """ + f1, _, S1 = compute_spectrogram(x1, sfreq, **kwargs) + _, _, S2 = compute_spectrogram(x2, sfreq, **kwargs) + + # compute the bicoherence + ind = np.arange(f1.size // 2) + indsum = ind[:, None] + ind[None, :] + + P1 = S1[..., ind, None] + P2 = S2[..., None, ind] + P12 = S1[..., indsum] + + B = np.mean(P1 * P2 * np.conj(P12), axis=-3) + + if norm is not None: # Bispectrum -> Bicoherence + B = norm_spectrum(B, P1, P2, P12, time_axis=-3) + + return f1[ind], B + + +def polycoherence_0d(X, sfreq, freqs, norm=2, synthetic=None, **kwargs): + """Polycoherence between freqs and sum of freqs. + + Parameters + ---------- + X: ndarray, shape=(n_channels, n_samples) + Input data. + sfreq: float + Sampling rate. + freqs: list[float] + Fixed frequencies. + norm: int | None + Norm (default=2). + synthetic: tuple(float, float, float) + Used for synthetic signal for some frequencies (freq, amplitude, + phase), freq must coincide with the first fixed frequency. + **kwargs: dict + Additional parameters passed to scipy.signal.spectrogram. Important + parameters are nperseg, noverlap, nfft. + + Returns + ------- + B: ndarray, shape=(n_channels,) + Polycoherence + + """ + assert isinstance(freqs, Iterable), "freqs must be a list" + freq, t, S = compute_spectrogram(X, sfreq, **kwargs) + + ind = _freq_ind(freq, freqs) + indsum = _freq_ind(freq, np.sum(freqs)) + + Pi = _product_other_freqs(S, ind, synthetic, t) + Psum = S[..., indsum] + + B = np.mean(Pi * np.conj(Psum), axis=-1) + + if norm is not None: + # Bispectrum -> Bicoherence + B = norm_spectrum(B, Pi, 1., Psum, time_axis=-1) + + return B + + +def polycoherence_1d(X, sfreq, f2, norm=2, synthetic=None, **kwargs): + """1D polycoherence as a function of f1 and at least one fixed frequency f2. + + Parameters + ---------- + X: ndarray + Input data. + sfreq: float + Sampling rate + f2: list[float] + Fixed frequencies. + norm: int | None + Norm (default=2). + synthetic: tuple(float, float, float) + Used for synthetic signal for some frequencies (freq, amplitude, + phase), freq must coincide with the first fixed frequency. + **kwargs: + Additional parameters passed to scipy.signal.spectrogram. Important + parameters are `nperseg`, `noverlap`, `nfft`. + + Returns + ------- + freq: ndarray, shape=(n_freqs_f1,) + Frequencies + B: ndarray, shape=(n_channels, n_freqs_f1) + 1D polycoherence. + + """ + assert isinstance(f2, Iterable), "f2 must be a list" + + f1, t, S = compute_spectrogram(X, sfreq, **kwargs, axis=-1) + S = np.require(S, "complex64") + ind2 = _freq_ind(f1, f2) + ind1 = np.arange(len(f1) - sum(ind2)) + indsum = ind1 + sum(ind2) + + P1 = S[..., ind1] + Pother = _product_other_freqs(S, ind2, synthetic, t)[..., None] + Psum = S[..., indsum] + + B = np.nanmean(P1 * Pother * np.conj(Psum), axis=-2) + + if norm is not None: + B = norm_spectrum(B, P1, Pother, Psum, time_axis=-2) + + return f1[ind1], B + + +def polycoherence_1d_sum(X, sfreq, fsum, *ofreqs, norm=2, synthetic=None, **kwargs): + """1D polycoherence with fixed frequency sum fsum as a function of f1. + + Returns polycoherence for frequencies ranging from 0 up to fsum. + + Parameters + ---------- + X: ndarray + Input data. + sfreq: float + Sampling rate. + fsum : float + Fixed frequency sum. + ofreqs: list[float] + Fixed frequencies. + norm: int or None + If 2 - return polycoherence, n0 = n1 = n2 = 2 (default) + synthetic: tuple(float, float, float) | None + Used for synthetic signal for some frequencies (freq, amplitude, + phase), freq must coincide with the first fixed frequency. + + Returns + ------- + freq: ndarray, shape=(n_freqs,) + Frequencies. + B: ndarray, shape=(n_channels, n_freqs) + Polycoherence for f1+f2=fsum. + + """ + freq, t, S = compute_spectrogram(X, sfreq, **kwargs) + + indsum = _freq_ind(freq, fsum) + ind1 = np.arange(np.searchsorted(freq, fsum - np.sum(ofreqs))) + ind3 = _freq_ind(freq, ofreqs) + ind2 = indsum - ind1 - sum(ind3) + + P1 = S[..., ind1] + P2 = S[..., ind2] + Pother = _product_other_freqs(S, ind3, synthetic, t)[..., None] + Psum = S[..., [indsum]] + + B = np.mean(P1 * P2 * Pother * np.conj(Psum), axis=-2) + + if norm is not None: + B = norm_spectrum(B, P1, P2 * Pother, Psum, time_axis=-2) + + return freq[ind1], B + + +def polycoherence_2d(X, sfreq, ofreqs=None, norm=2, flim1=None, flim2=None, + synthetic=None, **kwargs): + """2D polycoherence between freqs and their sum as a function of f1 and f2. + + 2D bicoherence is the most common form, where one looks at a + two-dimensional representation of the phase coupling between different + frequencies. It is a function of two frequency variables. + + 2D polycoherence as a function of f1 and f2, ofreqs are additional fixed + frequencies. + + Parameters + ---------- + X: ndarray + Input data. + sfreq: float + Sampling rate. + ofreqs: list[float] + Fixed frequencies. + norm: int or None + If 2 - return polycoherence (default), else return polyspectrum. + flim1: tuple | None + Frequency limits for f1. If None, it is set to (0, nyquist / 2) + flim2: tuple | None + Frequency limits for f2. + + Returns + ------- + freq1: ndarray, shape=(n_freqs_f1,) + Frequencies for f1. + freq2: ndarray, shape=(n_freqs_f2,) + Frequencies for f2. + B: ndarray, shape=([n_chans, ]n_freqs_f1, n_freqs_f2) + Polycoherence. + + """ + freq, t, S = compute_spectrogram(X, sfreq, **kwargs) + + if ofreqs is None: + ofreqs = [] + if flim1 is None: + flim1 = (0, (np.max(freq) - np.sum(ofreqs)) / 2) + if flim2 is None: + flim2 = (0, (np.max(freq) - np.sum(ofreqs)) / 2) + + # indices ranges matching flim1 and flim2 + ind1 = np.arange(*np.searchsorted(freq, flim1)) + ind2 = np.arange(*np.searchsorted(freq, flim2)) + ind3 = _freq_ind(freq, ofreqs) + indsum = ind1[:, None] + ind2[None, :] + sum(ind3) + + Pother = _product_other_freqs(S, ind3, synthetic, t)[..., None, None] + + P1 = S[..., ind1, None] + P2 = S[..., None, ind2] * Pother + P12 = S[..., indsum] + + # Average over time to get the bispectrum + B = np.nanmean(P1 * P2 * np.conj(P12), axis=-3) + + if norm is not None: # Bispectrum -> Bicoherence + B = norm_spectrum(B, P1, P2, P12, time_axis=-3) + + return freq[ind1], freq[ind2], B + + +def compute_spectrogram(X, sfreq, **kwargs): + """Compute the complex spectrogram of X. + + Simple wrapper around scipy.signal.spectrogram. + + Returns + ------- + f: ndarray, shape=(n_freqs,) + Positive frequencies. + t: ndarray, shape=(n_timesteps,) + Time axis. + S: ndarray, shape=(n_chans, n_timesteps, n_freqs) + Complex spectrogram (one-sided). + """ + N = X.shape[-1] + kwargs.setdefault("nperseg", N // 2) + kwargs.setdefault("noverlap", N // 5) + kwargs.setdefault("nfft", next_fast_len(N)) + + f, t, S = spectrogram(X, fs=sfreq, mode="complex", return_onesided=False, **kwargs) + S = S[:len(f) // 2] # only positive frequencies + f = f[:len(f) // 2] # only positive frequencies + S = np.require(S, "complex64") + S = np.swapaxes(S, -1, -2) # transpose (f, t) -> (t, f) + return f, t, S + + +def norm_spectrum(spec, P1, P2, P12, time_axis=-2): + """Compute bicoherence from bispectrum. + + For formula see [1]_. + + Parameters + ---------- + spec: ndarray, shape=(n_chans, n_freqs) + Polyspectrum. + P1: array-like, shape=(n_chans, n_times, n_freqs_f1) + Spectrum evaluated at f1. + P2: array-like, shape=(n_chans, n_times, n_freqs_f2) + Spectrum evaluated at f2. + P12: array-like, shape=(n_chans, n_times, n_freqs) + Spectrum evaluated at f1 + f2. + time_axis: int + Time axis. + + Returns + ------- + coh: ndarray, shape=(n_chans,) + Polycoherence. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Bicoherence + + """ + coh = np.abs(spec) ** 2 + norm = np.nanmean(np.abs(P1 * P2) ** 2, axis=time_axis) + norm *= np.nanmean(np.abs(np.conj(P12)) ** 2, axis=time_axis) + coh /= norm + coh **= 0.5 + return coh + + +def plot_polycoherence(freq1, freq2, bicoh, ax=None): + """Plot polycoherence.""" + df1 = freq1[1] - freq1[0] # resolution + df2 = freq2[1] - freq2[0] + freq1 = np.append(freq1, freq1[-1] + df1) - 0.5 * df1 + freq2 = np.append(freq2, freq2[-1] + df2) - 0.5 * df2 + + if ax is None: + f, ax = plt.subplots() + + ax.pcolormesh(freq2, freq1, np.abs(bicoh)) + ax.set_xlabel("freq (Hz)") + ax.set_ylabel("freq (Hz)") + # ax.colorbar() + return ax + +def plot_polycoherence_1d(freq, coh): + """Plot polycoherence for fixed frequencies.""" + plt.figure() + plt.plot(freq, coh) + plt.xlabel("freq (Hz)") + + +def plot_signal(t, signal, ax=None): + """Plot signal and spectrum.""" + if ax is None: + f, ax = plt.subplots(2, 1) + + ax[0].plot(t, signal) + ax[0].set_xlabel("time (s)") + + ndata = len(signal) + nfft = next_fast_len(ndata) + freq = rfftfreq(nfft, t[1] - t[0]) + spec = rfft(signal, nfft) * 2 / ndata + ax[1].plot(freq, np.abs(spec)) + ax[1].set_xlabel("freq (Hz)") + return ax + +def _freq_ind(freq, f0): + """Find the index of the frequency closest to f0.""" + if isinstance(f0, Iterable): + return [np.argmin(np.abs(freq - f)) for f in f0] + else: + return np.argmin(np.abs(freq - f0)) + + +def _product_other_freqs(spec, indices, synthetic=None, t=None): + """Product of all frequencies.""" + if synthetic is None: + synthetic = () + + p1 = synthetic_signal(t, synthetic) + p2 = np.prod(spec[..., indices[len(synthetic):]], axis=-1) + return p1 * p2 + + +def synthetic_signal(t, synthetic): + """Create a synthetic complex signal spectrum. + + Parameters + ---------- + t: array-like + Time. + synthetic: list[tuple(float, float, float)] + List of tuples with (freq, amplitude, phase). + + Returns + ------- + Complex signal. + + """ + return np.prod([amp * np.exp(2j * np.pi * freq * t + phase) + for (freq, amp, phase) in synthetic], axis=0) \ No newline at end of file diff --git a/meegkit/utils/covariances.py b/meegkit/utils/covariances.py index b4b8b744..36b107fa 100644 --- a/meegkit/utils/covariances.py +++ b/meegkit/utils/covariances.py @@ -33,10 +33,10 @@ def block_covariance(data, window=128, overlap=0.5, padding=True, estimator="cov Block covariance. """ - from pyriemann.utils.covariance import _check_est + from pyriemann.utils.covariance import check_function, cov_est_functions assert 0 <= overlap < 1, "overlap must be < 1" - est = _check_est(estimator) + est = check_function(estimator, cov_est_functions) cov = [] n_chans, n_samples = data.shape if padding: # pad data with zeros @@ -212,18 +212,18 @@ def tscov(X, shifts=None, weights=None, assume_centered=True): shifts, n_shifts = _check_shifts(shifts) if not assume_centered: - X = X - X.mean(0, keepdims=1) + X = X - X.mean(0, keepdims=True) if weights.any(): # weights X = np.einsum("ijk,ilk->ijk", X, weights) # element-wise mult - tw = np.sum(weights[:]) + tw = np.sum(weights[:]) - 1 else: # no weights N = 0 if len(shifts[shifts < 0]): N -= np.min(shifts) if len(shifts[shifts >= 0]): N += np.max(shifts) - tw = (n_chans * n_shifts - N) * n_trials + tw = n_trials * (n_times - N - 1) C = np.zeros((n_chans * n_shifts, n_chans * n_shifts)) for trial in range(n_trials): diff --git a/meegkit/utils/denoise.py b/meegkit/utils/denoise.py index 4317eabc..c896a998 100644 --- a/meegkit/utils/denoise.py +++ b/meegkit/utils/denoise.py @@ -75,7 +75,7 @@ def mean_over_trials(X, weights=None): if not weights.any(): y = np.mean(X, 2) - tw = np.ones((n_samples, n_chans, 1)) * n_trials + tw = np.ones((n_samples, n_chans, 1)) else: m, n, o = theshapeof(weights) if m != n_samples: diff --git a/meegkit/utils/matrix.py b/meegkit/utils/matrix.py index 4c4c733e..5b73746b 100644 --- a/meegkit/utils/matrix.py +++ b/meegkit/utils/matrix.py @@ -637,7 +637,7 @@ def _check_shifts(shifts, allow_floats=False): """Check shifts.""" types = (int, np.int_) if allow_floats: - types += (float, np.float_) + types += (float, np.float64) if not isinstance(shifts, (np.ndarray, list, type(None)) + types): raise AttributeError("shifts should be a list, an array or an int") if isinstance(shifts, (list, ) + types): diff --git a/meegkit/utils/sig.py b/meegkit/utils/sig.py index 82473f8d..4451407d 100644 --- a/meegkit/utils/sig.py +++ b/meegkit/utils/sig.py @@ -276,18 +276,17 @@ def spectral_envelope(x, sfreq, lowpass=32): return y[len(a):-len(b)] -def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, show=False): +def gaussfilt(data, sfreq, f, fwhm, n_harm=1, shift=0, return_empvals=False, show=False): """Narrow-band filter via frequency-domain Gaussian. - Empirical frequency and FWHM depend on the sampling rate and the - number of time points, and may thus be slightly different from - the requested values. + Empirical frequency and FWHM depend on the sampling rate and the number of + time points, and may thus be slightly different from the requested values. Parameters ---------- data : ndarray EEG data, shape=(n_samples, n_channels[, ...]) - srate : int + sfreq : int Sampling rate in Hz. f : float Break frequency of filter. @@ -318,7 +317,7 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, sho assert (fwhm >= 0), "FWHM must be greater than 0" # frequencies - hz = np.fft.fftfreq(data.shape[0], 1. / srate) + hz = np.fft.fftfreq(data.shape[0], 1. / sfreq) empVals = np.zeros((2,)) # compute empirical frequency and standard deviation diff --git a/meegkit/utils/testing.py b/meegkit/utils/testing.py index e715028e..86f7ad97 100644 --- a/meegkit/utils/testing.py +++ b/meegkit/utils/testing.py @@ -6,7 +6,7 @@ def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20, - n_bad_chans=1, SNR=.1, fline=1, t0=None, show=False): + n_bad_chans=1, SNR=.1, fline=0.1, t0=None, show=False): """Create synthetic data. Parameters @@ -24,7 +24,7 @@ def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20, t0 : int Onset sample of artifact. fline : float - Normalized frequency of artifact (freq/samplerate), (default=1). + Normalized frequency of artifact (freq/samplerate), (default=0.1). Returns ------- diff --git a/pyproject.toml b/pyproject.toml index 0301fb52..df868ce2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,18 +66,18 @@ write-changes = false # Linter configuration # ################################## [tool.ruff] -select = ["D", "E", "F", "B", "Q", "NPY", "I", "ICN", "UP"] +lint.select = ["D", "E", "F", "B", "Q", "NPY", "I", "ICN", "UP"] line-length = 90 target-version = "py310" -ignore-init-module-imports = true -ignore = ["E731", "B006", "B028", "UP038", "D100", "D105", "D212"] +lint.ignore-init-module-imports = true +lint.ignore = ["E731", "B006", "B028", "UP038", "D100", "D105", "D212"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] "test_*.py" = ["D101", "D102", "D103", "D"] "example_*.py" = ["D205", "D400", "D212"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "numpy" ################################## diff --git a/tests/test_coherence.py b/tests/test_coherence.py new file mode 100644 index 00000000..5ad2ee91 --- /dev/null +++ b/tests/test_coherence.py @@ -0,0 +1,152 @@ +"""Test bicoherence functions.""" +import matplotlib.pyplot as plt +import numpy as np +import pytest +from scipy.fftpack import next_fast_len + +from meegkit.utils.coherence import ( + cross_coherence, + plot_polycoherence, + plot_polycoherence_1d, + plot_signal, + polycoherence_0d, + polycoherence_1d, + polycoherence_1d_sum, + polycoherence_2d, +) + + +@pytest.mark.parametrize("norm", [None, 2]) +def test_coherence(norm, show=True): + rng = np.random.default_rng(54326) + + # create signal with 5Hz and 8Hz components and noise, as well as a + # 5Hz-8Hz interaction + N = 10001 + kw = dict(nperseg=N // 4, noverlap=N // 20, nfft=next_fast_len(N // 2)) + t = np.linspace(0, 100, N) + fs = 1 / (t[1] - t[0]) + s1 = np.cos(2 * np.pi * 5 * t + 0.3) # 5Hz + s2 = 3 * np.cos(2 * np.pi * 8 * t + 0.5) # 8Hz + noise = 5 * rng.normal(0, 1, N) + signal = s1 + s2 + noise + 0.5 * s1 * s2 + + # bicoherence + # ---------------------------------------------r---------------------------- + + # 0D local bicoherence (fixed frequencies) + p5_8 = polycoherence_0d(signal, fs, [5, 8], **kw) + p5_6 = polycoherence_0d(signal, fs, [5, 6], **kw) + p5_6_8 = polycoherence_0d(signal, fs, [5, 6, 8], **kw) + if norm is not None: + print(f"bicoherence for f1=5Hz, f2=8Hz: {p5_8:.2f}") + print(f"bicoherence for f1=5Hz, f2=6Hz: {p5_6:.2f}") + print(f"bicoherence for f1=5Hz, f2=6Hz, f3=7Hz: {p5_6_8:.2f}") + assert p5_8 > 0.85 > p5_6 > p5_6_8 # 5Hz and 7Hz are coherent, 5Hz and 6Hz not + + # 1D bicoherence (fixed f2) + freqs, coh1d = polycoherence_1d(signal, fs, [5], **kw) + + # 1D bicoherence with sum (fixed f1+f2) + freqs1dsum, coh1dsum = polycoherence_1d_sum(signal, fs, 13, **kw) + + # 2D bicoherence, span all frequencies + # assert peaks at intersection of 5Hz and 8Hz, and 8-5=3Hz + freqs1, freqs2, coh2d = polycoherence_2d(signal, fs, **kw) + + if norm is not None: + assert np.max(coh2d) > 0.85 + assert np.abs(coh2d[freqs1 == 5, freqs2 == 8]) > 0.85 + assert np.abs(coh2d[freqs1 == 5, freqs2 == 3]) > 0.85 + + + if show: + # Plot signal + plot_signal(t, signal) + plt.suptitle("signal and spectrum for bicoherence tests") + + # Plot bicoherence + plot_polycoherence(freqs1, freqs2, coh2d) + plt.suptitle("bicoherence") + + # Plot bicoherence 1D + plot_polycoherence_1d(freqs, coh1d) + plt.suptitle("bicoherence for f2=5Hz (column, expected 3Hz, 8Hz)") + + # Plot bicoherence for f1+f2=13Hz + plot_polycoherence_1d(freqs1dsum, coh1dsum) + plt.suptitle("bicoherence for f1+f2=13Hz (diagonal, expected 5Hz, 8Hz)") + plt.show() + + + # bicoherence with synthetic signal + # ------------------------------------------------------------------------- + s3 = 4 * np.cos(2 * np.pi * 1 * t + 0.1) + s5 = 0.4 * np.cos(2 * np.pi * 0.2 * t + 1) + signal = s2 + s3 + s5 - 0.5 * s2 * s3 * s5 + noise + + synthetic = ((0.2, 10, 1), ) + p02_1_8 = polycoherence_0d(signal, fs, [0.2, 1, 8], synthetic=None, + norm=norm, **kw) + p02_1_8s = polycoherence_0d(signal, fs, [0.2, 1, 8], synthetic=synthetic, + norm=norm, **kw) + if norm is not None: + print(f"coherence for f1=0.02Hz, f2=1Hz, f3=7Hz: {p02_1_8:.2f}") + print(f"coherence for f1=0.02Hz (synthetic), f2=1Hz, f3=7Hz: {p02_1_8s:.2f}") + assert p02_1_8s > 0.9 > p02_1_8 + + result = polycoherence_2d(signal, fs, [0.02], synthetic=synthetic, norm=norm, **kw) + + if show: + plot_signal(t, signal) + plt.suptitle("signal and spectrum for tricoherence with synthetic signals") + plt.tight_layout() + + plot_polycoherence(*result) + plt.suptitle("tricoherence with f3=0.2Hz (synthetic)") + plt.show() + + # cross-coherence + # ------------------------------------------------------------------------- + s1 = np.cos(2 * np.pi * 5 * t + 0.3) # 5Hz + s2 = 3 * np.cos(2 * np.pi * 8 * t + 0.5) # 8Hz + signal1 = s1 + 4 * rng.normal(0, 1, N) + signal2 = s2 + s1 + 5 * rng.normal(0, 1, N) + f, Cxy = cross_coherence(signal1, signal2, fs, norm=None, **kw) + + if show: + plot_polycoherence(f, f, Cxy) + plt.suptitle("cross-coherence between 5Hz and 5+8Hz signals") + plt.tight_layout() + plt.show() + +@pytest.mark.parametrize("shape", [(1001,), (3, 1001), (17, 3, 1001)]) +def test_coherence_shapes(shape): + """Test coherence functions with 1D, 2D or 3D input data.""" + rng = np.random.default_rng(54326) + + # create signal with 5Hz and 8Hz components and noise, as well as a + # 5Hz-8Hz interaction + N = shape[-1] + t = np.linspace(0, 100, N) + fs = 1 / (t[1] - t[0]) + s1 = np.cos(2 * np.pi * 5 * t + 0.3) # 5Hz + s2 = 3 * np.cos(2 * np.pi * 8 * t + 0.5) # 8Hz + noise = 5 * rng.normal(0, 1, shape) + signal = np.broadcast_to(s1 + s2, shape) + noise + + assert signal.shape == shape + + B = polycoherence_0d(signal, fs, [5, 8]) + assert B.shape == shape[:-1] + f, B = polycoherence_1d(signal, fs, [8]) + assert B.shape == shape[:-1] + f.shape + f, B = polycoherence_1d_sum(signal, fs, 13) + assert B.shape == shape[:-1] + f.shape + f1, f2, B = polycoherence_2d(signal, fs) + assert B.shape == shape[:-1] + f1.shape + f2.shape + + +if __name__ == "__main__": + pytest.main([__file__]) + # test_coherence(2, False) \ No newline at end of file diff --git a/tests/test_cov.py b/tests/test_cov.py index ce69b470..ff24e496 100644 --- a/tests/test_cov.py +++ b/tests/test_cov.py @@ -12,7 +12,7 @@ def test_tscov(): # Compare 0-lag case with numpy.cov() c1, n1 = tscov(x, [0]) - c2 = np.cov(x, bias=True) + c2 = np.cov(x, bias=False) assert_almost_equal(c1 / n1, c2) # Compare 0-lag case with numpy.cov() @@ -89,6 +89,7 @@ def test_convmtx(): ) if __name__ == "__main__": - # import pytest - # pytest.main([__file__]) - test_convmtx() + import pytest + pytest.main([__file__]) + # test_tscov() + # test_convmtx() diff --git a/tests/test_dss.py b/tests/test_dss.py index 5b354a66..9abdeacc 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -47,7 +47,7 @@ def test_dss1(show=True): n_samples = 300 data, source = create_line_data(n_samples=n_samples, fline=.01) - todss, _, pwr0, pwr1 = dss.dss1(data, weights=None, ) + todss, _, pwr0, pwr1 = dss.dss1(data, weights=None) z = fold(np.dot(unfold(data), todss), epoch_size=n_samples) best_comp = np.mean(z[:, 0, :], -1) @@ -102,6 +102,7 @@ def _plot(x): ax[1].set_title("after") plt.show() + # 2D case, n_outputs == 1 out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep) _plot(out) @@ -122,7 +123,7 @@ def _plot(x): artifact = artifact ** 3 s = x + 10 * artifact out, _ = dss.dss_line(s, fline, sr, nremove=1) - + plt.close("all") def test_dss_line_iter(): """Test line noise removal.""" @@ -171,6 +172,7 @@ def _plot(before, after): x, _ = create_line_data(n_samples, n_chans=n_chans, n_trials=2, noise_dim=10, SNR=2, fline=fline / sr) out, _ = dss.dss_line_iter(x, fline, sr, show=False) + plt.close("all") def profile_dss_line(nkeep): diff --git a/tests/test_filters.py b/tests/test_filters.py index 93f42341..8531b90c 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -43,15 +43,15 @@ def test_noisy_signal(show=True): ax[0, 0].set_title("Signal and its Hilbert phase") ax[1, 0].plot(time, gtp, lw=.75, label="Ground truth") - ax[1, 0].plot(time, nrp, lw=.75, label=r"$\phi_N$") + ax[1, 0].plot(time, nrp, lw=.75, label=r"$\\phi_N$") ax[1, 0].set_ylabel(r"$\phi_N$") ax[1, 0].set_ylim([-np.pi, np.pi]) ax[1, 0].set_title("Nonresonant oscillator") ax[2, 0].plot(time, gtp, lw=.75, label="Ground truth") - ax[2, 0].plot(time, rp, lw=.75, label=r"$\phi_N$") + ax[2, 0].plot(time, rp, lw=.75, label=r"$\\phi_N$") ax[2, 0].set_ylim([-np.pi, np.pi]) - ax[2, 0].set_ylabel("$\phi_H - \phi_R$") + ax[2, 0].set_ylabel(r"$\phi_H - \phi_R$") ax[2, 0].set_xlabel("Time") ax[2, 0].set_title("Resonant oscillator") @@ -132,7 +132,7 @@ def test_all_alg(show=False): ax[3, 0].plot(time, r_phi_dif, lw=.75) ax[3, 0].axhline(0, color="k", ls=":", zorder=-1) ax[3, 0].set_ylim([-np.pi, np.pi]) - ax[3, 0].set_ylabel("$\phi_H - \phi_R$") + ax[3, 0].set_ylabel(r"$\phi_H - \phi_R$") ax[3, 0].set_xlabel("Time") ax[3, 0].set_title("Resonant oscillator") @@ -216,5 +216,5 @@ def generate_noisy_signal(npt=40000, fs=100, noise=0.1): if __name__ == "__main__": # Run the model_data_all_alg function - test_all_alg() - # test_noisy_signal() \ No newline at end of file + test_all_alg(True) + # test_noisy_signal(True) \ No newline at end of file diff --git a/tests/test_signal.py b/tests/test_signal.py index 9e9688ec..c0e654c3 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,7 +1,9 @@ """Test signal utils.""" +import matplotlib.pyplot as plt import numpy as np -from scipy.signal import butter, freqz, lfilter +from scipy.signal import butter, freqz, hilbert, lfilter +from meegkit.phase import ECHT from meegkit.utils.sig import stmcb, teager_kaiser rng = np.random.default_rng(9) @@ -51,6 +53,92 @@ def test_stcmb(show=True): np.testing.assert_allclose(h, hh, rtol=2) # equal to 2% + +def test_echt(show=False): + """Test Endpoint-corrected Hilbert transform (ECHT) phase estimation.""" + rng = np.random.default_rng(38872) + + # Build data + # ------------------------------------------------------------------------- + # First, we generate a multi-component signal with amplitude and phase + # modulations, as described in the paper [1]_. + f0 = 2 + filt_BW = f0 / 2 + N = 500 + sfreq = 200 + time = np.linspace(0, N / sfreq, N) + X = np.cos(2 * np.pi * f0 * time - np.pi / 4) + phase_true = np.angle(hilbert(X)) + X += rng.normal(0, 0.5, N) # Add noise + + # Compute phase and amplitude + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We compute the Hilbert phase, as well as the phase obtained with the ECHT + # filter. + # phase_hilbert = np.angle(hilbert(X)) # Hilbert phase + + # Compute ECHT-filtered signal + l_freq = f0 - filt_BW / 2 + h_freq = f0 + filt_BW / 2 + echt = ECHT(l_freq, h_freq, sfreq) + + Xf = echt.fit_transform(X) + phase_echt = np.angle(Xf).squeeze() + # phase_true = np.roll(phase_true, 1) + if show: + fig, ax = plt.subplots(3, 1, figsize=(8, 5)) + ax[0].plot(time, X) + ax[0].set_xlabel("Time (s)") + ax[0].set_title("Test signal") + ax[0].set_ylabel("Amplitude") + + ax[1].plot(time, phase_true, label="True phase", ls=":") + ax[1].plot(time, phase_echt, label="ECHT phase", lw=.5, alpha=.8) + ax[1].set_title("Phase") + ax[1].set_ylabel("Amplitude") + ax[1].set_xlabel("Time (s)") + ax[1].legend(loc="upper right", fontsize="small") + + ax[2].plot(time, np.unwrap(phase_true - phase_echt.squeeze()), + label="Phase error") + ax[2].set_title("Phase error") + ax[2].set_ylabel("Error") + + plt.tight_layout() + plt.show() + + mae = (np.abs(np.unwrap(phase_true - phase_echt.squeeze())) > np.pi / 6).sum() / N + assert mae < 0.1, mae + +def test_echt_nd(): + """Test ECHT with ndim > 1.""" + rng = np.random.default_rng(38872) + + # Build data + # ------------------------------------------------------------------------- + # First, we generate a multi-component signal with amplitude and phase + # modulations, as described in the paper [1]_. + f0 = 2 + filt_BW = f0 / 2 + N = 500 + sfreq = 200 + time = np.linspace(0, N / sfreq, N) + X = np.cos(2 * np.pi * f0 * time - np.pi / 4)[:, None] + X = X + rng.normal(0, 0.5, (N, 2)) # Add noise + + l_freq = f0 - filt_BW / 2 + h_freq = f0 + filt_BW / 2 + echt = ECHT(l_freq, h_freq, sfreq) + + Xf = echt.fit_transform(X) + phase_echt = np.angle(Xf) + + assert phase_echt.shape == (N, 2) + + if __name__ == "__main__": - test_teager_kaiser() - test_stcmb() + import pytest + pytest.main([__file__]) + # test_teager_kaiser() + # test_stcmb() + # test_echt() \ No newline at end of file