diff --git a/README.md b/README.md
index 3a120144..f47a837d 100644
--- a/README.md
+++ b/README.md
@@ -154,11 +154,19 @@ and was adapted to python by [Giuseppe Ferraro](mailto:giuseppe.ferraro@isae-sup
2022 Sep 27;22(19):7314. https://doi.org/10.3390/s22197314.
```
-### 6. Real-Time Phase Estimation
+### 6. Phase Estimation
-This code is based on the Matlab implementation from [Michael Rosenblum](http://www.stat.physik.uni-potsdam.de/~mros), and its corresponding paper [1].
+The oscillator code is based on the Matlab implementation from [Michael
+Rosenblum](http://www.stat.physik.uni-potsdam.de/~mros), and its corresponding
+paper [1]. The Endpoint Corrected Hilbert Transform (ECHT) method was adapted
+from [2].
```sql
-[1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation of phase and amplitude with application to neural data. Sci Rep 11, 18037 (2021). https://doi.org/10.1038/s41598-021-97560-5
+[1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation of phase
+ and amplitude with application to neural data. Sci Rep 11, 18037 (2021).
+ https://doi.org/10.1038/s41598-021-97560-5
+[2] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., Latorre, A.,
+ ... & Grossman, N. (2021). Non-invasive suppression of essential tremor via
+ phase-locked disruption of its temporal coherence. Nature communications, 12(1), 363.
```
diff --git a/doc/conf.py b/doc/conf.py
index 37568fef..9b4c3d02 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -26,7 +26,7 @@
# -- Project information -----------------------------------------------------
project = "MEEGkit"
-copyright = "2023, Nicolas Barascud"
+copyright = "2024, Nicolas Barascud"
author = "Nicolas Barascud"
release = meegkit.__version__
version = meegkit.__version__
@@ -63,7 +63,7 @@
"show-inheritance": True,
"exclude-members": "__weakref__"
}
-numpydoc_show_class_members = True
+numpydoc_show_class_members = False
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
@@ -129,3 +129,5 @@
"ignore_pattern": "config.py",
"run_stale_examples": False,
}
+
+suppress_warnings = ["config.cache"]
\ No newline at end of file
diff --git a/doc/index.rst b/doc/index.rst
index fd2fe078..34717b6e 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -46,7 +46,6 @@ Here is a list of the methods and techniques available in ``meegkit``:
~meegkit.tspca
~meegkit.utils
-
Examples gallery
----------------
diff --git a/doc/modules/meegkit.phase.rst b/doc/modules/meegkit.phase.rst
new file mode 100644
index 00000000..4a364f83
--- /dev/null
+++ b/doc/modules/meegkit.phase.rst
@@ -0,0 +1,23 @@
+meegkit.phase
+=============
+
+.. automodule:: meegkit.phase
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ NonResOscillator
+ ResOscillator
+ Device
+ ECHT
+
+ .. rubric:: Functions
+
+ .. autosummary::
+
+ locking_based_phase
+ rk
+ init_coefs
+ one_step_oscillator
+ one_step_integrator
diff --git a/doc/modules/meegkit.utils.rst b/doc/modules/meegkit.utils.rst
index 9797de75..661ee49c 100644
--- a/doc/modules/meegkit.utils.rst
+++ b/doc/modules/meegkit.utils.rst
@@ -6,11 +6,14 @@ meegkit.utils
.. autosummary::
auditory
+ buffer
+ coherence
covariances
denoise
matrix
sig
stats
+ trca
|
@@ -23,6 +26,28 @@ Auditory
.. autosummary::
+|
+
+----
+
+Buffer
+------
+.. automodule:: meegkit.utils.buffer
+
+ .. autosummary::
+
+
+|
+
+----
+
+Coherence
+---------
+.. automodule:: meegkit.utils.coherence
+
+ .. autosummary::
+
+
|
----
diff --git a/doc/sg_execution_times.rst b/doc/sg_execution_times.rst
new file mode 100644
index 00000000..b6f64701
--- /dev/null
+++ b/doc/sg_execution_times.rst
@@ -0,0 +1,73 @@
+
+:orphan:
+
+.. _sphx_glr_sg_execution_times:
+
+
+Computation times
+=================
+**00:29.509** total execution time for 13 files **from all galleries**:
+
+.. container::
+
+ .. raw:: html
+
+
+
+
+
+
+
+ .. list-table::
+ :header-rows: 1
+ :class: table table-striped sg-datatable
+
+ * - Example
+ - Time
+ - Mem (MB)
+ * - :ref:`sphx_glr_auto_examples_example_trca.py` (``../examples/example_trca.py``)
+ - 00:11.533
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_dss_line.py` (``../examples/example_dss_line.py``)
+ - 00:07.115
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_phase_estimation.py` (``../examples/example_phase_estimation.py``)
+ - 00:07.064
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_mcca.py` (``../examples/example_mcca.py``)
+ - 00:01.182
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_mcca_2.py` (``../examples/example_mcca_2.py``)
+ - 00:00.535
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_asr.py` (``../examples/example_asr.py``)
+ - 00:00.512
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_ress.py` (``../examples/example_ress.py``)
+ - 00:00.459
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_detrend.py` (``../examples/example_detrend.py``)
+ - 00:00.259
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_star_dss.py` (``../examples/example_star_dss.py``)
+ - 00:00.246
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_star.py` (``../examples/example_star.py``)
+ - 00:00.197
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_dss.py` (``../examples/example_dss.py``)
+ - 00:00.146
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_echt.py` (``../examples/example_echt.py``)
+ - 00:00.132
+ - 0.0
+ * - :ref:`sphx_glr_auto_examples_example_dering.py` (``../examples/example_dering.py``)
+ - 00:00.129
+ - 0.0
diff --git a/examples/example_echt.ipynb b/examples/example_echt.ipynb
new file mode 100644
index 00000000..156e0807
--- /dev/null
+++ b/examples/example_echt.ipynb
@@ -0,0 +1,97 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Endpoint-corrected Hilbert transform (ECHT) phase estimation\n\nThis example shows how to causally estimate the phase of a signal using\n\nUses `meegkit.phase.ECHT()`.\n\n## References\n.. [1] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., Latorre,\n A., ... & Grossman, N. (2021). Non-invasive suppression of essential tremor\n via phase-locked disruption of its temporal coherence. Nature\n communications, 12(1), 363.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\nimport numpy as np\nfrom scipy.signal import hilbert\n\nfrom meegkit.phase import ECHT\n\nrng = np.random.default_rng(38872)\n\nplt.rcParams[\"axes.grid\"] = True\nplt.rcParams[\"grid.linestyle\"] = \":\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Build data\nFirst, we generate a multi-component signal with amplitude and phase\nmodulations, as described in the paper [1]_.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "f0 = 2\n\nN = 500\nsfreq = 100\ntime = np.linspace(0, N / sfreq, N)\nX = np.cos(2 * np.pi * f0 * time - np.pi / 4)\nphase_true = np.angle(hilbert(X))\nX += rng.normal(0, 0.5, N) # Add noise"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Compute phase and amplitude\nWe compute the Hilbert phase, as well as the phase obtained with the ECHT\nfilter.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "phase_hilbert = np.angle(hilbert(X)) # Hilbert phase\n\n# Compute ECHT-filtered signal\nfilt_BW = f0 / 2\nl_freq = f0 - filt_BW / 2\nh_freq = f0 + filt_BW / 2\necht = ECHT(l_freq, h_freq, sfreq)\n\nXf = echt.fit_transform(X)\nphase_echt = np.angle(Xf)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Visualize signal\nPlot the results\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(3, 1, figsize=(8, 6))\nax[0].plot(time, X)\nax[0].set_xlabel(\"Time (s)\")\nax[0].set_title(\"Test signal\")\nax[0].set_ylabel(\"Amplitude\")\nax[1].psd(X, Fs=sfreq, NFFT=2048*4, noverlap=sfreq)\nax[1].set_ylabel(\"PSD (dB/Hz)\")\nax[1].set_title(\"Test signal's Fourier spectrum\")\nax[2].plot(time, phase_true, label=\"True phase\", ls=\":\")\nax[2].plot(time, phase_echt, label=\"ECHT phase\", lw=.5, alpha=.8)\nax[2].plot(time, phase_hilbert, label=\"Hilbert phase\", lw=.5, alpha=.8)\nax[2].set_title(\"Phase\")\nax[2].set_ylabel(\"Amplitude\")\nax[2].set_xlabel(\"Time (s)\")\nax[2].legend(loc=\"upper right\", fontsize=\"small\")\nplt.tight_layout()\nplt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/example_echt.py b/examples/example_echt.py
new file mode 100644
index 00000000..3ed111cf
--- /dev/null
+++ b/examples/example_echt.py
@@ -0,0 +1,81 @@
+"""
+Endpoint-corrected Hilbert transform (ECHT) phase estimation
+============================================================
+
+This example shows how to causally estimate the phase of a signal using the
+Endpoint-corrected Hilbert transform (ECHT) [1]_.
+
+Uses `meegkit.phase.ECHT()`.
+
+References
+----------
+.. [1] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X., Latorre,
+ A., ... & Grossman, N. (2021). Non-invasive suppression of essential tremor
+ via phase-locked disruption of its temporal coherence. Nature
+ communications, 12(1), 363.
+
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.signal import hilbert
+
+from meegkit.phase import ECHT
+
+rng = np.random.default_rng(38872)
+
+plt.rcParams["axes.grid"] = True
+plt.rcParams["grid.linestyle"] = ":"
+
+###############################################################################
+# Build data
+# -----------------------------------------------------------------------------
+# First, we generate a multi-component signal with amplitude and phase
+# modulations, as described in the paper [1]_.
+f0 = 2
+
+N = 500
+sfreq = 100
+time = np.linspace(0, N / sfreq, N)
+X = np.cos(2 * np.pi * f0 * time - np.pi / 4)
+phase_true = np.angle(hilbert(X))
+X += rng.normal(0, 0.5, N) # Add noise
+
+###############################################################################
+# Compute phase and amplitude
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# We compute the Hilbert phase, as well as the phase obtained with the ECHT
+# filter.
+phase_hilbert = np.angle(hilbert(X)) # Hilbert phase
+
+# Compute ECHT-filtered signal
+filt_BW = f0 / 2
+l_freq = f0 - filt_BW / 2
+h_freq = f0 + filt_BW / 2
+echt = ECHT(l_freq, h_freq, sfreq)
+
+Xf = echt.fit_transform(X)
+phase_echt = np.angle(Xf)
+
+###############################################################################
+# Visualize signal
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Here we plot the original signal, its Fourier spectrum, and the phase obtained
+# with the Hilbert transform and the ECHT filter. The ECHT filter provides a
+# much smoother phase estimate than the Hilbert transform
+fig, ax = plt.subplots(3, 1, figsize=(8, 6))
+ax[0].plot(time, X)
+ax[0].set_xlabel("Time (s)")
+ax[0].set_title("Test signal")
+ax[0].set_ylabel("Amplitude")
+ax[1].psd(X, Fs=sfreq, NFFT=2048*4, noverlap=sfreq)
+ax[1].set_ylabel("PSD (dB/Hz)")
+ax[1].set_title("Test signal's Fourier spectrum")
+ax[2].plot(time, phase_true, label="True phase", ls=":")
+ax[2].plot(time, phase_echt, label="ECHT phase", lw=.5, alpha=.8)
+ax[2].plot(time, phase_hilbert, label="Hilbert phase", lw=.5, alpha=.8)
+ax[2].set_title("Phase")
+ax[2].set_ylabel("Amplitude")
+ax[2].set_xlabel("Time (s)")
+ax[2].legend(loc="upper right", fontsize="small")
+plt.tight_layout()
+plt.show()
diff --git a/examples/example_phase_estimation.ipynb b/examples/example_phase_estimation.ipynb
index 43d8cc99..605a1912 100644
--- a/examples/example_phase_estimation.ipynb
+++ b/examples/example_phase_estimation.ipynb
@@ -40,7 +40,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Vizualize signal\nPlot the test signal's Fourier spectrum\n\n"
+ "### Visualize signal\nPlot the test signal's Fourier spectrum\n\n"
]
},
{
@@ -83,7 +83,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The first row shows the test signal $s$ and its Hilbert amplitude $a_H$ ; one\ncan see that ah does not represent a good envelope for $s$. On the contrary,\nthe Hilbert-based phase estimation yields good results, and therefore we take\nit for the ground truth.\nRows 2-4 show the difference between the Hilbert phase and causally\nestimated phases ($\\phi_L$, $\\phi_N$, $\\phi_R$) are obtained by means of the\nlocking-based technique, non-resonant and resonant oscillator, respectively).\nThese panels demonstrate that the output of the developed causal algorithms\nis very close to the HT-phase. Notice that we show $\\phi_H - \\phi_N$\nmodulo $2\\pi$, since the phase difference is not bounded.\n\n"
+ "The first row shows the test signal $s$ and its Hilbert amplitude\n$a_H$ ; one can see that ah does not represent a good envelope for\n$s$. On the contrary, the Hilbert-based phase estimation yields good\nresults, and therefore we take it for the ground truth. Rows 2-4 show the\ndifference between the Hilbert phase and causally estimated phases\n($\\phi_L$, $\\phi_N$, $\\phi_R$) are obtained by means of the\nlocking-based technique, non-resonant and resonant oscillator, respectively).\nThese panels demonstrate that the output of the developed causal algorithms\nis very close to the HT-phase. Notice that we show $\\phi_H - \\phi_N$\nmodulo :math:`2\\pi, since the phase difference is not bounded.\n\n"
]
},
{
@@ -94,7 +94,7 @@
},
"outputs": [],
"source": [
- "f, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(12, 8))\nax[0, 0].plot(time, s, time, ht_phase, lw=.75)\nax[0, 0].set_ylabel(r\"$s,\\phi_H$\")\nax[0, 0].set_title(\"Signal and its Hilbert phase\")\n\nax[1, 0].plot(time, lb_phi_dif, lw=.75)\nax[1, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[1, 0].set_ylabel(r\"$\\phi_H - \\phi_L$\")\nax[1, 0].set_ylim([-np.pi, np.pi])\nax[1, 0].set_title(\"Phase locking approach\")\n\nax[2, 0].plot(time, nr_phi_dif, lw=.75)\nax[2, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[2, 0].set_ylabel(r\"$\\phi_H - \\phi_N$\")\nax[2, 0].set_ylim([-np.pi, np.pi])\nax[2, 0].set_title(\"Nonresonant oscillator\")\n\nax[3, 0].plot(time, r_phi_dif, lw=.75)\nax[3, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[3, 0].set_ylim([-np.pi, np.pi])\nax[3, 0].set_ylabel(\"$\\phi_H - \\phi_R$\")\nax[3, 0].set_xlabel(\"Time\")\nax[3, 0].set_title(\"Resonant oscillator\")\n\nax[0, 1].plot(time, s, time, ht_ampl, lw=.75)\nax[0, 1].set_ylabel(r\"$s,a_H$\")\nax[0, 1].set_title(\"Signal and its Hilbert amplitude\")\n\nax[1, 1].axis(\"off\")\n\nax[2, 1].plot(time, s, time, nr_ampl, lw=.75)\nax[2, 1].set_ylabel(r\"$s,a_N$\")\nax[2, 1].set_title(\"Amplitudes\")\nax[2, 1].set_title(\"Nonresonant oscillator\")\n\nax[3, 1].plot(time, s, time, r_ampl, lw=.75)\nax[3, 1].set_xlabel(\"Time\")\nax[3, 1].set_ylabel(r\"$s,a_R$\")\nax[3, 1].set_title(\"Resonant oscillator\")\nplt.suptitle(\"Amplitude (right) and phase (left) estimation algorithms\")\nplt.tight_layout()\nplt.show()"
+ "f, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(12, 8))\nax[0, 0].plot(time, s, time, ht_phase, lw=.75)\nax[0, 0].set_ylabel(r\"$s,\\phi_H$\")\nax[0, 0].set_title(\"Signal and its Hilbert phase\")\n\nax[1, 0].plot(time, lb_phi_dif, lw=.75)\nax[1, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[1, 0].set_ylabel(r\"$\\phi_H - \\phi_L$\")\nax[1, 0].set_ylim([-np.pi, np.pi])\nax[1, 0].set_title(\"Phase locking approach\")\n\nax[2, 0].plot(time, nr_phi_dif, lw=.75)\nax[2, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[2, 0].set_ylabel(r\"$\\phi_H - \\phi_N$\")\nax[2, 0].set_ylim([-np.pi, np.pi])\nax[2, 0].set_title(\"Nonresonant oscillator\")\n\nax[3, 0].plot(time, r_phi_dif, lw=.75)\nax[3, 0].axhline(0, color=\"k\", ls=\":\", zorder=-1)\nax[3, 0].set_ylim([-np.pi, np.pi])\nax[3, 0].set_ylabel(r\"$\\phi_H - \\phi_R$\")\nax[3, 0].set_xlabel(\"Time\")\nax[3, 0].set_title(\"Resonant oscillator\")\n\nax[0, 1].plot(time, s, time, ht_ampl, lw=.75)\nax[0, 1].set_ylabel(r\"$s,a_H$\")\nax[0, 1].set_title(\"Signal and its Hilbert amplitude\")\n\nax[1, 1].axis(\"off\")\n\nax[2, 1].plot(time, s, time, nr_ampl, lw=.75)\nax[2, 1].set_ylabel(r\"$s,a_N$\")\nax[2, 1].set_title(\"Amplitudes\")\nax[2, 1].set_title(\"Nonresonant oscillator\")\n\nax[3, 1].plot(time, s, time, r_ampl, lw=.75)\nax[3, 1].set_xlabel(\"Time\")\nax[3, 1].set_ylabel(r\"$s,a_R$\")\nax[3, 1].set_title(\"Resonant oscillator\")\nplt.suptitle(\"Amplitude (right) and phase (left) estimation algorithms\")\nplt.tight_layout()\nplt.show()"
]
}
],
@@ -114,7 +114,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.6"
+ "version": "3.12.2"
}
},
"nbformat": 4,
diff --git a/examples/example_phase_estimation.py b/examples/example_phase_estimation.py
index 8e5e19b7..b7ac9230 100644
--- a/examples/example_phase_estimation.py
+++ b/examples/example_phase_estimation.py
@@ -43,7 +43,7 @@
time = np.arange(npt) * dt
###############################################################################
-# Vizualize signal
+# Visualize signal
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Plot the test signal's Fourier spectrum
f, ax = plt.subplots(2, 1)
@@ -83,16 +83,16 @@
# Here we reproduce figure 1 from the original paper [1]_.
###############################################################################
-# The first row shows the test signal $s$ and its Hilbert amplitude $a_H$ ; one
-# can see that ah does not represent a good envelope for $s$. On the contrary,
-# the Hilbert-based phase estimation yields good results, and therefore we take
-# it for the ground truth.
-# Rows 2-4 show the difference between the Hilbert phase and causally
-# estimated phases ($\phi_L$, $\phi_N$, $\phi_R$) are obtained by means of the
+# The first row shows the test signal :math:`s` and its Hilbert amplitude
+# :math:`a_H` ; one can see that ah does not represent a good envelope for
+# :math:`s`. On the contrary, the Hilbert-based phase estimation yields good
+# results, and therefore we take it for the ground truth. Rows 2-4 show the
+# difference between the Hilbert phase and causally estimated phases
+# (:math:`\phi_L`, :math:`\phi_N`, :math:`\phi_R`) are obtained by means of the
# locking-based technique, non-resonant and resonant oscillator, respectively).
# These panels demonstrate that the output of the developed causal algorithms
-# is very close to the HT-phase. Notice that we show $\phi_H - \phi_N$
-# modulo $2\pi$, since the phase difference is not bounded.
+# is very close to the HT-phase. Notice that we show :math:`\phi_H - \phi_N`
+# modulo :math:`2\pi`, since the phase difference is not bounded.
f, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(12, 8))
ax[0, 0].plot(time, s, time, ht_phase, lw=.75)
ax[0, 0].set_ylabel(r"$s,\phi_H$")
@@ -113,7 +113,7 @@
ax[3, 0].plot(time, r_phi_dif, lw=.75)
ax[3, 0].axhline(0, color="k", ls=":", zorder=-1)
ax[3, 0].set_ylim([-np.pi, np.pi])
-ax[3, 0].set_ylabel("$\phi_H - \phi_R$")
+ax[3, 0].set_ylabel(r"$\phi_H - \phi_R$")
ax[3, 0].set_xlabel("Time")
ax[3, 0].set_title("Resonant oscillator")
diff --git a/meegkit/asr.py b/meegkit/asr.py
index 9d6ad27d..29ce33ca 100755
--- a/meegkit/asr.py
+++ b/meegkit/asr.py
@@ -551,7 +551,7 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method="riemann",
detrend : bool
If True, detrend filtered data (default=False).
method : {'euclid', 'riemann'}
- Metric to compute the covariance matric average.
+ Metric to compute the covariance matrix average.
Returns
-------
diff --git a/meegkit/dss.py b/meegkit/dss.py
index 8105570b..65ce19ee 100644
--- a/meegkit/dss.py
+++ b/meegkit/dss.py
@@ -48,14 +48,11 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12):
Power per component (averaged).
"""
- n_trials = theshapeof(X)[-1]
-
# if demean: # remove weighted mean
# X = demean(X, weights)
# weighted mean over trials (--> bias function for DSS)
xx, ww = mean_over_trials(X, weights)
- ww /= n_trials
# covariance of raw and biased X
c0, nc0 = tscov(X, None, weights)
diff --git a/meegkit/phase.py b/meegkit/phase.py
index 8e3cc87f..39e19b0b 100644
--- a/meegkit/phase.py
+++ b/meegkit/phase.py
@@ -9,11 +9,13 @@
be slow for large input arrays (n_channels >> 10), since an individual
oscillator is instantiated for each channel.
-.. [1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation
- of phase and amplitude with application to neural data. Sci Rep 11, 18037
+.. [1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation of
+ phase and amplitude with application to neural data. Sci Rep 11, 18037
(2021). https://doi.org/10.1038/s41598-021-97560-5
"""
import numpy as np
+from scipy.fftpack import fft, fftshift, ifft, ifftshift, next_fast_len
+from scipy.signal import butter, freqz
from meegkit.utils.buffer import Buffer
@@ -104,13 +106,14 @@ def step(self, sprev, s, snew):
class NonResOscillator:
"""Real-time measurement of phase and amplitude using non-resonant oscillator.
- This estimator relies on the resonance effect. The measuring “device” consists
- of two linear damped oscillators. The oscillators' frequency is much larger
- than the frequency of the signal, i.e., the system is far from resonance.
- We choose the damping parameters to ensure that (i) the phase of the first
- linear oscillator equals that of the input and that (ii) amplitude of the
- second one and the input relate by a known constant multiplicator. The
- technique yields both phase and amplitude of the input signal.
+ This estimator relies on the resonance effect. The measuring “device”
+ consists of two linear damped oscillators. The oscillators' frequency is
+ much larger than the frequency of the signal, i.e., the system is far from
+ resonance. We choose the damping parameters to ensure that (i) the phase of
+ the first linear oscillator equals that of the input and that (ii)
+ amplitude of the second one and the input relate by a known constant
+ multiplicator. The technique yields both phase and amplitude of the input
+ signal.
This estimator includes an automated frequency-tuning algorithm to adjust
to the a priori unknown signal frequency.
@@ -127,24 +130,23 @@ class NonResOscillator:
References
----------
.. [1] Rosenblum, M., Pikovsky, A., Kühn, A.A. et al. Real-time estimation
- of phase and amplitude with application to neural data. Sci Rep 11, 18037
- (2021). https://doi.org/10.1038/s41598-021-97560-5
+ of phase and amplitude with application to neural data. Sci Rep 11,
+ 18037 (2021). https://doi.org/10.1038/s41598-021-97560-5
"""
- def __init__(self, fs=250, nu=1.1):
+ def __init__(self, fs=250, nu=1.1, alpha_a=6.0, alpha_p=0.2, update_factor=5):
# Parameters of the measurement "devices"
self.dt = 1 / fs # Sampling interval
self.nu = nu # Rough estimate of the tremor frequency
self.om0 = 5 * nu # Oscillator frequency (estimation)
- self.alpha_a = 6.0 # Damping parameter for the "amplitude device"
- self.gamma_a = self.alpha_a / 2
- self.alpha_p = 0.2 # Damping parameter for the "phase device"
- self.gamma_p = self.alpha_p / 2
+ self.alpha_a = alpha_a # Damping parameter for the "amplitude device"
+ self.gamma_a = alpha_a / 2
+ self.alpha_p = alpha_p # Damping parameter for the "phase device"
+ self.gamma_p = alpha_p / 2
self.factor = np.sqrt((self.om0 ** 2 - nu ** 2) ** 2 + (self.alpha_a * nu) ** 2)
# Update parameters, and precomputed quantities
- update_factor = 5
self.memory = round(2 * np.pi / self.om0 / self.dt)
self.update_point = 2 * self.memory
self.update_step = round(self.memory / update_factor)
@@ -196,7 +198,7 @@ def transform(self, X):
continue # Skip the first two samples
# Amplitude estimation
- spp, sp, s = self.buffer.view(3)
+ spp, sp, s = self.buffer.view(3)[:, 0]
for ch in range(self.n_channels):
self.adevice[ch].step(spp, sp, s)
@@ -261,11 +263,12 @@ class ResOscillator:
(2021). https://doi.org/10.1038/s41598-021-97560-5
"""
- def __init__(self, fs=1000, nu=4.5, freq_adaptation=True):
+ def __init__(self, fs=1000, nu=4.5, update_factor=5, freq_adaptation=True,
+ assume_centered=False):
# Parameters of the measurement "device"
self.dt = 1 / fs # Sampling interval
- self.om0 = 1.1 # Angular frequency
+ self.om0 = nu # Angular frequency
self.alpha = 0.3 * self.om0
self.gamma = self.alpha / 2
@@ -273,7 +276,7 @@ def __init__(self, fs=1000, nu=4.5, freq_adaptation=True):
nperiods = 1 # Number of previous periods for frequency correction
npt_period = round(2 * np.pi / self.om0 / self.dt) # Number of points per period
self.memory = nperiods * npt_period # M points for frequency correction buffer
- self.update_factor = 5 # Number of frequency updates per period
+ self.update_factor = update_factor # Number of frequency updates per period
self.update_step = round(npt_period / self.update_factor)
self.updatepoint = 2 * self.memory
@@ -287,6 +290,7 @@ def __init__(self, fs=1000, nu=4.5, freq_adaptation=True):
self.buffer = None
self.runav = 0. # Initial guess for the dc-component
self.freq_adaptation = freq_adaptation
+ self.assume_centered = assume_centered
def _set_devices(self, n_channels):
# Set up the phase and amplitude "devices"
@@ -362,7 +366,9 @@ def transform(self, X):
self.osc[ch].init_coefs(om0, self.dt, self.gamma)
# Update running average
- self.runav = np.mean(self.buffer.view(self.memory), axis=0, keepdims=True)
+ if self.assume_centered is False:
+ self.runav = np.mean(self.buffer.view(self.memory), axis=0,
+ keepdims=True)
self.updatepoint += self.update_step # Point for the next update
return phase, ampl
@@ -619,3 +625,167 @@ def one_step_integrator(z, edelmu, mu, dt, spp, sp, s):
C0 = z + d
z = C0 * edelmu - d + b * dt - 2 * c * mu * dt + c * dt ** 2
return z
+
+
+class ECHT:
+ """Endpoint Corrected Hilbert Transform (ECHT).
+
+ See [1]_ for details.
+
+ Parameters
+ ----------
+ X : ndarray, shape=(n_samples, n_channels)
+ Time domain signal.
+ l_freq : float | None
+ Low-cutoff frequency of a bandpass causal filter. If None, the data is
+ only low-passed.
+ h_freq : float | None
+ High-cutoff frequency of a bandpass causal filter. If None, the data is
+ only high-passed.
+ sfreq : float
+ Sampling rate of time domain signal.
+ n_fft : int, optional
+ Length of analytic signal. If None, it defaults to the length of X.
+ filt_order : int, optional
+ Order of the filter. Default is 2.
+
+ Notes
+ -----
+ One common implementation of the Hilbert Transform uses a DFT (aka FFT)
+ as part of its computation. Inherent to the DFT is the assumption that
+ a finite sample of a signal is replicated infinitely in time, effectively
+ abutting the end of a sample with its replicated start. If the start and
+ end of the sample are not continuous with each other, distortions are
+ introduced by the DFT. Echt effectively smooths out this 'discontinuity'
+ by selectively deforming the start of the sample. It is hence most suited
+ for real-time applications in which the point/s of interest is/are the
+ most recent one/s (i.e. last) in the sample window.
+
+ We found that a filter bandwidth (BW=h_freq-l_freq) of up to half the
+ signal's central frequency works well.
+
+ References
+ ----------
+ .. [1] Schreglmann, S. R., Wang, D., Peach, R. L., Li, J., Zhang, X.,
+ Latorre, A., ... & Grossman, N. (2021). Non-invasive suppression of
+ essential tremor via phase-locked disruption of its temporal coherence.
+ Nature communications, 12(1), 363.
+
+ Examples
+ --------
+ >>> f0 = 2
+ >>> filt_BW = f0 / 2
+ >>> N = 1000
+ >>> sfreq = N / (2 * np.pi)
+ >>> t = np.arange(-2 * np.pi, 0, 1 / sfreq)
+ >>> X = np.cos(2 * np.pi * f0 * t - np.pi / 4)
+ >>> l_freq = f0 - filt_BW / 2
+ >>> h_freq = f0 + filt_BW / 2
+ >>> Xf = echt(X, l_freq, h_freq, sfreq)
+ """
+
+ def __init__(self, l_freq, h_freq, sfreq, n_fft=None, filt_order=2):
+ self.l_freq = l_freq
+ self.h_freq = h_freq
+ self.sfreq = sfreq
+ self.n_fft = n_fft
+ self.filt_order = filt_order
+
+ # attributes
+ self.h_ = None
+ self.coef_ = None
+
+ def fit(self, X, y=None):
+ """Fit the ECHT transform to the input signal.
+
+ Parameters
+ ----------
+ X : ndarray, shape=(n_samples, n_channels)
+ The input signal to be transformed.
+
+ """
+ if self.n_fft is None:
+ self.n_fft = next_fast_len(X.shape[0])
+
+ # Set the amplitude of the negative components of the FFT to zero and
+ # multiply the amplitudes of the positive components, apart from the
+ # zero-frequency component (DC) and Nyquist components, by 2.
+ #
+ # If the signal has an even number of elements n, the frequency components
+ # are:
+ # - n/2-1 negative elements,
+ # - one DC element,
+ # - n/2-1 positive elements and
+ # - one Nyquist element (in order).
+ #
+ # If the signal has an odd number of elements n, the frequency components
+ # are:
+ # - (n-1)/2 negative elements,
+ # - one DC element,
+ # - (n-1)/2 positive elements (in order).
+ # - no positive element corresponding to the Nyquist frequency.
+ self.h_ = np.zeros(self.n_fft)
+ self.h_[0] = 1
+ self.h_[1:(self.n_fft // 2) + 1] = 2
+ if self.n_fft % 2 == 0:
+ self.h_[self.n_fft // 2] = 1
+
+ # The frequency response vector is computed using freqz from the filter's
+ # impulse response function, computed by butter function, and the user
+ # defined low-cutoff frequency, high-cutoff frequency and sampling rate.
+ Wn = [self.l_freq / (self.sfreq / 2), self.h_freq / (self.sfreq / 2)]
+ b, a = butter(self.filt_order, Wn, btype="band")
+ T = 1 / self.sfreq * self.n_fft
+ filt_freq = np.ceil(np.arange(-self.n_fft / 2, self.n_fft / 2) / T)
+
+ self.coef_ = freqz(b, a, filt_freq, fs=self.sfreq)[1]
+ self.coef_ = self.coef_[:, None]
+
+ return self
+
+ def transform(self, X):
+ """Apply the ECHT transform to the input signal.
+
+ Parameters
+ ----------
+ X : ndarray, shape=(n_samples, n_channels)
+ The input signal to be transformed.
+
+ Returns
+ -------
+ Xf : ndarray, shape=(n_samples, n_channels)
+ The transformed signal (complex-valued).
+
+ """
+ if not np.isrealobj(X):
+ X = np.real(X)
+
+ # if not fitted
+ if self.h_ is None or self.coef_ is None:
+ self.fit(X)
+
+ if X.ndim == 1:
+ X = X[:, np.newaxis]
+
+ # Compute the FFT of the signal.
+ Xf = fft(X, self.n_fft, axis=0)
+
+ # In contrast to :meth:`scipy.signal.hilbert()`, the code then
+ # multiplies the array by a frequency response vector of a causal
+ # bandpass filter.
+ Xf = Xf * self.h_[:, None]
+
+ # The array is arranged, using fft_shift function, so that the zero-frequency
+ # component is at the center of the array, before the multiplication, and
+ # rearranged back so that the zero-frequency component is at the left of the
+ # array using ifft_shift(). Finally, the IFFT is computed.
+ Xf = fftshift(Xf)
+ Xf = Xf * self.coef_
+ Xf = ifftshift(Xf)
+ Xf = ifft(Xf, axis=0)
+
+ return Xf
+
+ def fit_transform(self, X, y=None):
+ """Fit the ECHT transform to the input signal and transform it."""
+ return self.fit(X).transform(X)
\ No newline at end of file
diff --git a/meegkit/tspca.py b/meegkit/tspca.py
index 677aea3b..95f98ec5 100644
--- a/meegkit/tspca.py
+++ b/meegkit/tspca.py
@@ -34,7 +34,7 @@ def tspca(X, shifts=None, keep=None, threshold=None, weights=None,
weights : array
Sample weights.
demean : bool
- If True, Epochs are centered before comuting PCA (default=0).
+ If True, Epochs are centered before computing PCA (default=0).
Returns
-------
diff --git a/meegkit/utils/__init__.py b/meegkit/utils/__init__.py
index 28b1cc6b..bccc3b75 100644
--- a/meegkit/utils/__init__.py
+++ b/meegkit/utils/__init__.py
@@ -1,6 +1,16 @@
"""Utility functions."""
from .auditory import AuditoryFilterbank, GammatoneFilterbank, erb2hz, erbspace, hz2erb
from .base import mldivide, mrdivide
+from .coherence import (
+ cross_coherence,
+ plot_polycoherence,
+ plot_polycoherence_1d,
+ plot_signal,
+ polycoherence_0d,
+ polycoherence_1d,
+ polycoherence_1d_sum,
+ polycoherence_2d,
+)
from .covariances import (
block_covariance,
convmtx,
diff --git a/meegkit/utils/asr.py b/meegkit/utils/asr.py
index dd0fa6b5..cab47ca9 100755
--- a/meegkit/utils/asr.py
+++ b/meegkit/utils/asr.py
@@ -93,7 +93,7 @@ def fit_eeg_distribution(X, min_clean_fraction=0.25, max_dropout_fraction=0.1,
# we can generally skip the tail below the lower quantile
lower_min = np.min(quants)
# maximum width is the fit interval if all data is clean
- max_width = np.diff(quants)
+ max_width = np.diff(quants)[0]
# minimum width of the fit interval, as fraction of data
min_width = min_clean_fraction * max_width
diff --git a/meegkit/utils/buffer.py b/meegkit/utils/buffer.py
index 1cb2733f..d91614e4 100644
--- a/meegkit/utils/buffer.py
+++ b/meegkit/utils/buffer.py
@@ -22,14 +22,9 @@ class Buffer:
The number of channels of the buffer.
counter : int
The number of samples in the buffer.
- head : int
- The index of the most recent sample in the buffer.
- tail : int
- The index of the most recent read sample in the buffer.
_data : ndarray, shape (size, n_channels)
Data buffer.
-
"""
def __init__(self, size, n_channels=1):
@@ -87,6 +82,8 @@ def push(self, X):
self._head += n_samples
self.counter += n_samples
+ return self
+
def get_new_samples(self, n_samples=None):
"""Consume n_samples."""
if n_samples is None:
diff --git a/meegkit/utils/coherence.py b/meegkit/utils/coherence.py
new file mode 100644
index 00000000..0d13ab98
--- /dev/null
+++ b/meegkit/utils/coherence.py
@@ -0,0 +1,424 @@
+"""Signal coherence tools.
+
+Compute 2D, 1D, and 0D bicoherence, polycoherence, bispectrum, and polyspectrum.
+
+Bicoherence is a measure of the degree of phase coupling between different
+frequency components in a signal. It's essentially a normalized form of the
+bispectrum, which itself is a Fourier transform of the third-order cumulant of
+a time series:
+
+- 2D bicoherence is the most common form, where one looks at a
+ two-dimensional representation of the phase coupling between different
+ frequencies. It's a function of two frequency variables.
+- 1D bicoherence would imply a slice or a specific condition in the 2D
+ bicoherence, reducing it to a function of a single frequency variable. It
+ simplifies the analysis by looking at the relationship between a particular
+ frequency and its harmonics or other relationships
+- 0D bicoherence would imply a single value representing some average or
+ overall measure of phase coupling in the signal. It's a highly condensed
+ summary, which might represent the average bicoherence over all frequency
+ pairs, for instance.
+
+
+"""
+from collections.abc import Iterable
+
+import matplotlib.pyplot as plt
+import numpy as np
+from numpy.fft import rfft, rfftfreq
+from scipy.fftpack import next_fast_len
+from scipy.signal import spectrogram
+
+
+def cross_coherence(x1, x2, sfreq, norm=2, **kwargs):
+ """Compute the bispectral cross-coherence between two signals of same length.
+
+ Code adapted from [2]_.
+
+ Parameters
+ ----------
+ x1: array-like, shape=([n_channels, ]n_samples)
+ First signal.
+ x2: array-like, shape=([n_channels, ]n_samples)
+ Second signal.
+ sfreq: float
+ Sampling sfreq.
+ norm: int | None
+ Norm (default=2). If None, return bispectrum.
+
+ Returns
+ -------
+ f1: array-like
+ Frequency axis.
+ B: array-like
+ Bicoherence between s1 and s2.
+
+ References
+ ----------
+ .. [1] http://wiki.fusenet.eu/wiki/Bicoherence
+ .. [2] https://stackoverflow.com/a/36725871
+
+ """
+ f1, _, S1 = compute_spectrogram(x1, sfreq, **kwargs)
+ _, _, S2 = compute_spectrogram(x2, sfreq, **kwargs)
+
+ # compute the bicoherence
+ ind = np.arange(f1.size // 2)
+ indsum = ind[:, None] + ind[None, :]
+
+ P1 = S1[..., ind, None]
+ P2 = S2[..., None, ind]
+ P12 = S1[..., indsum]
+
+ B = np.mean(P1 * P2 * np.conj(P12), axis=-3)
+
+ if norm is not None: # Bispectrum -> Bicoherence
+ B = norm_spectrum(B, P1, P2, P12, time_axis=-3)
+
+ return f1[ind], B
+
+
+def polycoherence_0d(X, sfreq, freqs, norm=2, synthetic=None, **kwargs):
+ """Polycoherence between freqs and sum of freqs.
+
+ Parameters
+ ----------
+ X: ndarray, shape=(n_channels, n_samples)
+ Input data.
+ sfreq: float
+ Sampling rate.
+ freqs: list[float]
+ Fixed frequencies.
+ norm: int | None
+ Norm (default=2).
+ synthetic: tuple(float, float, float)
+ Used for synthetic signal for some frequencies (freq, amplitude,
+ phase), freq must coincide with the first fixed frequency.
+ **kwargs: dict
+ Additional parameters passed to scipy.signal.spectrogram. Important
+ parameters are nperseg, noverlap, nfft.
+
+ Returns
+ -------
+ B: ndarray, shape=(n_channels,)
+ Polycoherence
+
+ """
+ assert isinstance(freqs, Iterable), "freqs must be a list"
+ freq, t, S = compute_spectrogram(X, sfreq, **kwargs)
+
+ ind = _freq_ind(freq, freqs)
+ indsum = _freq_ind(freq, np.sum(freqs))
+
+ Pi = _product_other_freqs(S, ind, synthetic, t)
+ Psum = S[..., indsum]
+
+ B = np.mean(Pi * np.conj(Psum), axis=-1)
+
+ if norm is not None:
+ # Bispectrum -> Bicoherence
+ B = norm_spectrum(B, Pi, 1., Psum, time_axis=-1)
+
+ return B
+
+
+def polycoherence_1d(X, sfreq, f2, norm=2, synthetic=None, **kwargs):
+ """1D polycoherence as a function of f1 and at least one fixed frequency f2.
+
+ Parameters
+ ----------
+ X: ndarray
+ Input data.
+ sfreq: float
+ Sampling rate
+ f2: list[float]
+ Fixed frequencies.
+ norm: int | None
+ Norm (default=2).
+ synthetic: tuple(float, float, float)
+ Used for synthetic signal for some frequencies (freq, amplitude,
+ phase), freq must coincide with the first fixed frequency.
+ **kwargs:
+ Additional parameters passed to scipy.signal.spectrogram. Important
+ parameters are `nperseg`, `noverlap`, `nfft`.
+
+ Returns
+ -------
+ freq: ndarray, shape=(n_freqs_f1,)
+ Frequencies
+ B: ndarray, shape=(n_channels, n_freqs_f1)
+ 1D polycoherence.
+
+ """
+ assert isinstance(f2, Iterable), "f2 must be a list"
+
+ f1, t, S = compute_spectrogram(X, sfreq, **kwargs, axis=-1)
+ S = np.require(S, "complex64")
+ ind2 = _freq_ind(f1, f2)
+ ind1 = np.arange(len(f1) - sum(ind2))
+ indsum = ind1 + sum(ind2)
+
+ P1 = S[..., ind1]
+ Pother = _product_other_freqs(S, ind2, synthetic, t)[..., None]
+ Psum = S[..., indsum]
+
+ B = np.nanmean(P1 * Pother * np.conj(Psum), axis=-2)
+
+ if norm is not None:
+ B = norm_spectrum(B, P1, Pother, Psum, time_axis=-2)
+
+ return f1[ind1], B
+
+
+def polycoherence_1d_sum(X, sfreq, fsum, *ofreqs, norm=2, synthetic=None, **kwargs):
+ """1D polycoherence with fixed frequency sum fsum as a function of f1.
+
+ Returns polycoherence for frequencies ranging from 0 up to fsum.
+
+ Parameters
+ ----------
+ X: ndarray
+ Input data.
+ sfreq: float
+ Sampling rate.
+ fsum : float
+ Fixed frequency sum.
+ ofreqs: list[float]
+ Fixed frequencies.
+ norm: int or None
+ If 2 - return polycoherence, n0 = n1 = n2 = 2 (default)
+ synthetic: tuple(float, float, float) | None
+ Used for synthetic signal for some frequencies (freq, amplitude,
+ phase), freq must coincide with the first fixed frequency.
+
+ Returns
+ -------
+ freq: ndarray, shape=(n_freqs,)
+ Frequencies.
+ B: ndarray, shape=(n_channels, n_freqs)
+ Polycoherence for f1+f2=fsum.
+
+ """
+ freq, t, S = compute_spectrogram(X, sfreq, **kwargs)
+
+ indsum = _freq_ind(freq, fsum)
+ ind1 = np.arange(np.searchsorted(freq, fsum - np.sum(ofreqs)))
+ ind3 = _freq_ind(freq, ofreqs)
+ ind2 = indsum - ind1 - sum(ind3)
+
+ P1 = S[..., ind1]
+ P2 = S[..., ind2]
+ Pother = _product_other_freqs(S, ind3, synthetic, t)[..., None]
+ Psum = S[..., [indsum]]
+
+ B = np.mean(P1 * P2 * Pother * np.conj(Psum), axis=-2)
+
+ if norm is not None:
+ B = norm_spectrum(B, P1, P2 * Pother, Psum, time_axis=-2)
+
+ return freq[ind1], B
+
+
+def polycoherence_2d(X, sfreq, ofreqs=None, norm=2, flim1=None, flim2=None,
+ synthetic=None, **kwargs):
+ """2D polycoherence between freqs and their sum as a function of f1 and f2.
+
+ 2D bicoherence is the most common form, where one looks at a
+ two-dimensional representation of the phase coupling between different
+ frequencies. It is a function of two frequency variables.
+
+ 2D polycoherence as a function of f1 and f2, ofreqs are additional fixed
+ frequencies.
+
+ Parameters
+ ----------
+ X: ndarray
+ Input data.
+ sfreq: float
+ Sampling rate.
+ ofreqs: list[float]
+ Fixed frequencies.
+ norm: int or None
+ If 2 - return polycoherence (default), else return polyspectrum.
+ flim1: tuple | None
+ Frequency limits for f1. If None, it is set to (0, nyquist / 2)
+ flim2: tuple | None
+ Frequency limits for f2.
+
+ Returns
+ -------
+ freq1: ndarray, shape=(n_freqs_f1,)
+ Frequencies for f1.
+ freq2: ndarray, shape=(n_freqs_f2,)
+ Frequencies for f2.
+ B: ndarray, shape=([n_chans, ]n_freqs_f1, n_freqs_f2)
+ Polycoherence.
+
+ """
+ freq, t, S = compute_spectrogram(X, sfreq, **kwargs)
+
+ if ofreqs is None:
+ ofreqs = []
+ if flim1 is None:
+ flim1 = (0, (np.max(freq) - np.sum(ofreqs)) / 2)
+ if flim2 is None:
+ flim2 = (0, (np.max(freq) - np.sum(ofreqs)) / 2)
+
+ # indices ranges matching flim1 and flim2
+ ind1 = np.arange(*np.searchsorted(freq, flim1))
+ ind2 = np.arange(*np.searchsorted(freq, flim2))
+ ind3 = _freq_ind(freq, ofreqs)
+ indsum = ind1[:, None] + ind2[None, :] + sum(ind3)
+
+ Pother = _product_other_freqs(S, ind3, synthetic, t)[..., None, None]
+
+ P1 = S[..., ind1, None]
+ P2 = S[..., None, ind2] * Pother
+ P12 = S[..., indsum]
+
+ # Average over time to get the bispectrum
+ B = np.nanmean(P1 * P2 * np.conj(P12), axis=-3)
+
+ if norm is not None: # Bispectrum -> Bicoherence
+ B = norm_spectrum(B, P1, P2, P12, time_axis=-3)
+
+ return freq[ind1], freq[ind2], B
+
+
+def compute_spectrogram(X, sfreq, **kwargs):
+ """Compute the complex spectrogram of X.
+
+ Simple wrapper around scipy.signal.spectrogram.
+
+ Returns
+ -------
+ f: ndarray, shape=(n_freqs,)
+ Positive frequencies.
+ t: ndarray, shape=(n_timesteps,)
+ Time axis.
+ S: ndarray, shape=(n_chans, n_timesteps, n_freqs)
+ Complex spectrogram (one-sided).
+ """
+ N = X.shape[-1]
+ kwargs.setdefault("nperseg", N // 2)
+ kwargs.setdefault("noverlap", N // 5)
+ kwargs.setdefault("nfft", next_fast_len(N))
+
+ f, t, S = spectrogram(X, fs=sfreq, mode="complex", return_onesided=False, **kwargs)
+ S = S[:len(f) // 2] # only positive frequencies
+ f = f[:len(f) // 2] # only positive frequencies
+ S = np.require(S, "complex64")
+ S = np.swapaxes(S, -1, -2) # transpose (f, t) -> (t, f)
+ return f, t, S
+
+
+def norm_spectrum(spec, P1, P2, P12, time_axis=-2):
+ """Compute bicoherence from bispectrum.
+
+ For formula see [1]_.
+
+ Parameters
+ ----------
+ spec: ndarray, shape=(n_chans, n_freqs)
+ Polyspectrum.
+ P1: array-like, shape=(n_chans, n_times, n_freqs_f1)
+ Spectrum evaluated at f1.
+ P2: array-like, shape=(n_chans, n_times, n_freqs_f2)
+ Spectrum evaluated at f2.
+ P12: array-like, shape=(n_chans, n_times, n_freqs)
+ Spectrum evaluated at f1 + f2.
+ time_axis: int
+ Time axis.
+
+ Returns
+ -------
+ coh: ndarray, shape=(n_chans,)
+ Polycoherence.
+
+ References
+ ----------
+ .. [1] https://en.wikipedia.org/wiki/Bicoherence
+
+ """
+ coh = np.abs(spec) ** 2
+ norm = np.nanmean(np.abs(P1 * P2) ** 2, axis=time_axis)
+ norm *= np.nanmean(np.abs(np.conj(P12)) ** 2, axis=time_axis)
+ coh /= norm
+ coh **= 0.5
+ return coh
+
+
+def plot_polycoherence(freq1, freq2, bicoh, ax=None):
+ """Plot polycoherence."""
+ df1 = freq1[1] - freq1[0] # resolution
+ df2 = freq2[1] - freq2[0]
+ freq1 = np.append(freq1, freq1[-1] + df1) - 0.5 * df1
+ freq2 = np.append(freq2, freq2[-1] + df2) - 0.5 * df2
+
+ if ax is None:
+ f, ax = plt.subplots()
+
+ ax.pcolormesh(freq2, freq1, np.abs(bicoh))
+ ax.set_xlabel("freq (Hz)")
+ ax.set_ylabel("freq (Hz)")
+ # ax.colorbar()
+ return ax
+
+def plot_polycoherence_1d(freq, coh):
+ """Plot polycoherence for fixed frequencies."""
+ plt.figure()
+ plt.plot(freq, coh)
+ plt.xlabel("freq (Hz)")
+
+
+def plot_signal(t, signal, ax=None):
+ """Plot signal and spectrum."""
+ if ax is None:
+ f, ax = plt.subplots(2, 1)
+
+ ax[0].plot(t, signal)
+ ax[0].set_xlabel("time (s)")
+
+ ndata = len(signal)
+ nfft = next_fast_len(ndata)
+ freq = rfftfreq(nfft, t[1] - t[0])
+ spec = rfft(signal, nfft) * 2 / ndata
+ ax[1].plot(freq, np.abs(spec))
+ ax[1].set_xlabel("freq (Hz)")
+ return ax
+
+def _freq_ind(freq, f0):
+ """Find the index of the frequency closest to f0."""
+ if isinstance(f0, Iterable):
+ return [np.argmin(np.abs(freq - f)) for f in f0]
+ else:
+ return np.argmin(np.abs(freq - f0))
+
+
+def _product_other_freqs(spec, indices, synthetic=None, t=None):
+ """Product of all frequencies."""
+ if synthetic is None:
+ synthetic = ()
+
+ p1 = synthetic_signal(t, synthetic)
+ p2 = np.prod(spec[..., indices[len(synthetic):]], axis=-1)
+ return p1 * p2
+
+
+def synthetic_signal(t, synthetic):
+ """Create a synthetic complex signal spectrum.
+
+ Parameters
+ ----------
+ t: array-like
+ Time.
+ synthetic: list[tuple(float, float, float)]
+ List of tuples with (freq, amplitude, phase).
+
+ Returns
+ -------
+ Complex signal.
+
+ """
+ return np.prod([amp * np.exp(2j * np.pi * freq * t + phase)
+ for (freq, amp, phase) in synthetic], axis=0)
\ No newline at end of file
diff --git a/meegkit/utils/covariances.py b/meegkit/utils/covariances.py
index b4b8b744..36b107fa 100644
--- a/meegkit/utils/covariances.py
+++ b/meegkit/utils/covariances.py
@@ -33,10 +33,10 @@ def block_covariance(data, window=128, overlap=0.5, padding=True, estimator="cov
Block covariance.
"""
- from pyriemann.utils.covariance import _check_est
+ from pyriemann.utils.covariance import check_function, cov_est_functions
assert 0 <= overlap < 1, "overlap must be < 1"
- est = _check_est(estimator)
+ est = check_function(estimator, cov_est_functions)
cov = []
n_chans, n_samples = data.shape
if padding: # pad data with zeros
@@ -212,18 +212,18 @@ def tscov(X, shifts=None, weights=None, assume_centered=True):
shifts, n_shifts = _check_shifts(shifts)
if not assume_centered:
- X = X - X.mean(0, keepdims=1)
+ X = X - X.mean(0, keepdims=True)
if weights.any(): # weights
X = np.einsum("ijk,ilk->ijk", X, weights) # element-wise mult
- tw = np.sum(weights[:])
+ tw = np.sum(weights[:]) - 1
else: # no weights
N = 0
if len(shifts[shifts < 0]):
N -= np.min(shifts)
if len(shifts[shifts >= 0]):
N += np.max(shifts)
- tw = (n_chans * n_shifts - N) * n_trials
+ tw = n_trials * (n_times - N - 1)
C = np.zeros((n_chans * n_shifts, n_chans * n_shifts))
for trial in range(n_trials):
diff --git a/meegkit/utils/denoise.py b/meegkit/utils/denoise.py
index 4317eabc..c896a998 100644
--- a/meegkit/utils/denoise.py
+++ b/meegkit/utils/denoise.py
@@ -75,7 +75,7 @@ def mean_over_trials(X, weights=None):
if not weights.any():
y = np.mean(X, 2)
- tw = np.ones((n_samples, n_chans, 1)) * n_trials
+ tw = np.ones((n_samples, n_chans, 1))
else:
m, n, o = theshapeof(weights)
if m != n_samples:
diff --git a/meegkit/utils/matrix.py b/meegkit/utils/matrix.py
index 4c4c733e..5b73746b 100644
--- a/meegkit/utils/matrix.py
+++ b/meegkit/utils/matrix.py
@@ -637,7 +637,7 @@ def _check_shifts(shifts, allow_floats=False):
"""Check shifts."""
types = (int, np.int_)
if allow_floats:
- types += (float, np.float_)
+ types += (float, np.float64)
if not isinstance(shifts, (np.ndarray, list, type(None)) + types):
raise AttributeError("shifts should be a list, an array or an int")
if isinstance(shifts, (list, ) + types):
diff --git a/meegkit/utils/sig.py b/meegkit/utils/sig.py
index 82473f8d..4451407d 100644
--- a/meegkit/utils/sig.py
+++ b/meegkit/utils/sig.py
@@ -276,18 +276,17 @@ def spectral_envelope(x, sfreq, lowpass=32):
return y[len(a):-len(b)]
-def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, show=False):
+def gaussfilt(data, sfreq, f, fwhm, n_harm=1, shift=0, return_empvals=False, show=False):
"""Narrow-band filter via frequency-domain Gaussian.
- Empirical frequency and FWHM depend on the sampling rate and the
- number of time points, and may thus be slightly different from
- the requested values.
+ Empirical frequency and FWHM depend on the sampling rate and the number of
+ time points, and may thus be slightly different from the requested values.
Parameters
----------
data : ndarray
EEG data, shape=(n_samples, n_channels[, ...])
- srate : int
+ sfreq : int
Sampling rate in Hz.
f : float
Break frequency of filter.
@@ -318,7 +317,7 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, sho
assert (fwhm >= 0), "FWHM must be greater than 0"
# frequencies
- hz = np.fft.fftfreq(data.shape[0], 1. / srate)
+ hz = np.fft.fftfreq(data.shape[0], 1. / sfreq)
empVals = np.zeros((2,))
# compute empirical frequency and standard deviation
diff --git a/meegkit/utils/testing.py b/meegkit/utils/testing.py
index e715028e..86f7ad97 100644
--- a/meegkit/utils/testing.py
+++ b/meegkit/utils/testing.py
@@ -6,7 +6,7 @@
def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20,
- n_bad_chans=1, SNR=.1, fline=1, t0=None, show=False):
+ n_bad_chans=1, SNR=.1, fline=0.1, t0=None, show=False):
"""Create synthetic data.
Parameters
@@ -24,7 +24,7 @@ def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20,
t0 : int
Onset sample of artifact.
fline : float
- Normalized frequency of artifact (freq/samplerate), (default=1).
+ Normalized frequency of artifact (freq/samplerate), (default=0.1).
Returns
-------
diff --git a/pyproject.toml b/pyproject.toml
index 0301fb52..df868ce2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,18 +66,18 @@ write-changes = false
# Linter configuration #
##################################
[tool.ruff]
-select = ["D", "E", "F", "B", "Q", "NPY", "I", "ICN", "UP"]
+lint.select = ["D", "E", "F", "B", "Q", "NPY", "I", "ICN", "UP"]
line-length = 90
target-version = "py310"
-ignore-init-module-imports = true
-ignore = ["E731", "B006", "B028", "UP038", "D100", "D105", "D212"]
+lint.ignore-init-module-imports = true
+lint.ignore = ["E731", "B006", "B028", "UP038", "D100", "D105", "D212"]
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"test_*.py" = ["D101", "D102", "D103", "D"]
"example_*.py" = ["D205", "D400", "D212"]
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
convention = "numpy"
##################################
diff --git a/tests/test_coherence.py b/tests/test_coherence.py
new file mode 100644
index 00000000..5ad2ee91
--- /dev/null
+++ b/tests/test_coherence.py
@@ -0,0 +1,152 @@
+"""Test bicoherence functions."""
+import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+from scipy.fftpack import next_fast_len
+
+from meegkit.utils.coherence import (
+ cross_coherence,
+ plot_polycoherence,
+ plot_polycoherence_1d,
+ plot_signal,
+ polycoherence_0d,
+ polycoherence_1d,
+ polycoherence_1d_sum,
+ polycoherence_2d,
+)
+
+
+@pytest.mark.parametrize("norm", [None, 2])
+def test_coherence(norm, show=True):
+ rng = np.random.default_rng(54326)
+
+ # create signal with 5Hz and 8Hz components and noise, as well as a
+ # 5Hz-8Hz interaction
+ N = 10001
+ kw = dict(nperseg=N // 4, noverlap=N // 20, nfft=next_fast_len(N // 2))
+ t = np.linspace(0, 100, N)
+ fs = 1 / (t[1] - t[0])
+ s1 = np.cos(2 * np.pi * 5 * t + 0.3) # 5Hz
+ s2 = 3 * np.cos(2 * np.pi * 8 * t + 0.5) # 8Hz
+ noise = 5 * rng.normal(0, 1, N)
+ signal = s1 + s2 + noise + 0.5 * s1 * s2
+
+ # bicoherence
+ # ---------------------------------------------r----------------------------
+
+ # 0D local bicoherence (fixed frequencies)
+ p5_8 = polycoherence_0d(signal, fs, [5, 8], **kw)
+ p5_6 = polycoherence_0d(signal, fs, [5, 6], **kw)
+ p5_6_8 = polycoherence_0d(signal, fs, [5, 6, 8], **kw)
+ if norm is not None:
+ print(f"bicoherence for f1=5Hz, f2=8Hz: {p5_8:.2f}")
+ print(f"bicoherence for f1=5Hz, f2=6Hz: {p5_6:.2f}")
+ print(f"bicoherence for f1=5Hz, f2=6Hz, f3=7Hz: {p5_6_8:.2f}")
+ assert p5_8 > 0.85 > p5_6 > p5_6_8 # 5Hz and 7Hz are coherent, 5Hz and 6Hz not
+
+ # 1D bicoherence (fixed f2)
+ freqs, coh1d = polycoherence_1d(signal, fs, [5], **kw)
+
+ # 1D bicoherence with sum (fixed f1+f2)
+ freqs1dsum, coh1dsum = polycoherence_1d_sum(signal, fs, 13, **kw)
+
+ # 2D bicoherence, span all frequencies
+ # assert peaks at intersection of 5Hz and 8Hz, and 8-5=3Hz
+ freqs1, freqs2, coh2d = polycoherence_2d(signal, fs, **kw)
+
+ if norm is not None:
+ assert np.max(coh2d) > 0.85
+ assert np.abs(coh2d[freqs1 == 5, freqs2 == 8]) > 0.85
+ assert np.abs(coh2d[freqs1 == 5, freqs2 == 3]) > 0.85
+
+
+ if show:
+ # Plot signal
+ plot_signal(t, signal)
+ plt.suptitle("signal and spectrum for bicoherence tests")
+
+ # Plot bicoherence
+ plot_polycoherence(freqs1, freqs2, coh2d)
+ plt.suptitle("bicoherence")
+
+ # Plot bicoherence 1D
+ plot_polycoherence_1d(freqs, coh1d)
+ plt.suptitle("bicoherence for f2=5Hz (column, expected 3Hz, 8Hz)")
+
+ # Plot bicoherence for f1+f2=13Hz
+ plot_polycoherence_1d(freqs1dsum, coh1dsum)
+ plt.suptitle("bicoherence for f1+f2=13Hz (diagonal, expected 5Hz, 8Hz)")
+ plt.show()
+
+
+ # bicoherence with synthetic signal
+ # -------------------------------------------------------------------------
+ s3 = 4 * np.cos(2 * np.pi * 1 * t + 0.1)
+ s5 = 0.4 * np.cos(2 * np.pi * 0.2 * t + 1)
+ signal = s2 + s3 + s5 - 0.5 * s2 * s3 * s5 + noise
+
+ synthetic = ((0.2, 10, 1), )
+ p02_1_8 = polycoherence_0d(signal, fs, [0.2, 1, 8], synthetic=None,
+ norm=norm, **kw)
+ p02_1_8s = polycoherence_0d(signal, fs, [0.2, 1, 8], synthetic=synthetic,
+ norm=norm, **kw)
+ if norm is not None:
+ print(f"coherence for f1=0.02Hz, f2=1Hz, f3=7Hz: {p02_1_8:.2f}")
+ print(f"coherence for f1=0.02Hz (synthetic), f2=1Hz, f3=7Hz: {p02_1_8s:.2f}")
+ assert p02_1_8s > 0.9 > p02_1_8
+
+ result = polycoherence_2d(signal, fs, [0.02], synthetic=synthetic, norm=norm, **kw)
+
+ if show:
+ plot_signal(t, signal)
+ plt.suptitle("signal and spectrum for tricoherence with synthetic signals")
+ plt.tight_layout()
+
+ plot_polycoherence(*result)
+ plt.suptitle("tricoherence with f3=0.2Hz (synthetic)")
+ plt.show()
+
+ # cross-coherence
+ # -------------------------------------------------------------------------
+ s1 = np.cos(2 * np.pi * 5 * t + 0.3) # 5Hz
+ s2 = 3 * np.cos(2 * np.pi * 8 * t + 0.5) # 8Hz
+ signal1 = s1 + 4 * rng.normal(0, 1, N)
+ signal2 = s2 + s1 + 5 * rng.normal(0, 1, N)
+ f, Cxy = cross_coherence(signal1, signal2, fs, norm=None, **kw)
+
+ if show:
+ plot_polycoherence(f, f, Cxy)
+ plt.suptitle("cross-coherence between 5Hz and 5+8Hz signals")
+ plt.tight_layout()
+ plt.show()
+
+@pytest.mark.parametrize("shape", [(1001,), (3, 1001), (17, 3, 1001)])
+def test_coherence_shapes(shape):
+ """Test coherence functions with 1D, 2D or 3D input data."""
+ rng = np.random.default_rng(54326)
+
+ # create signal with 5Hz and 8Hz components and noise, as well as a
+ # 5Hz-8Hz interaction
+ N = shape[-1]
+ t = np.linspace(0, 100, N)
+ fs = 1 / (t[1] - t[0])
+ s1 = np.cos(2 * np.pi * 5 * t + 0.3) # 5Hz
+ s2 = 3 * np.cos(2 * np.pi * 8 * t + 0.5) # 8Hz
+ noise = 5 * rng.normal(0, 1, shape)
+ signal = np.broadcast_to(s1 + s2, shape) + noise
+
+ assert signal.shape == shape
+
+ B = polycoherence_0d(signal, fs, [5, 8])
+ assert B.shape == shape[:-1]
+ f, B = polycoherence_1d(signal, fs, [8])
+ assert B.shape == shape[:-1] + f.shape
+ f, B = polycoherence_1d_sum(signal, fs, 13)
+ assert B.shape == shape[:-1] + f.shape
+ f1, f2, B = polycoherence_2d(signal, fs)
+ assert B.shape == shape[:-1] + f1.shape + f2.shape
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
+ # test_coherence(2, False)
\ No newline at end of file
diff --git a/tests/test_cov.py b/tests/test_cov.py
index ce69b470..ff24e496 100644
--- a/tests/test_cov.py
+++ b/tests/test_cov.py
@@ -12,7 +12,7 @@ def test_tscov():
# Compare 0-lag case with numpy.cov()
c1, n1 = tscov(x, [0])
- c2 = np.cov(x, bias=True)
+ c2 = np.cov(x, bias=False)
assert_almost_equal(c1 / n1, c2)
# Compare 0-lag case with numpy.cov()
@@ -89,6 +89,7 @@ def test_convmtx():
)
if __name__ == "__main__":
- # import pytest
- # pytest.main([__file__])
- test_convmtx()
+ import pytest
+ pytest.main([__file__])
+ # test_tscov()
+ # test_convmtx()
diff --git a/tests/test_dss.py b/tests/test_dss.py
index 5b354a66..9abdeacc 100644
--- a/tests/test_dss.py
+++ b/tests/test_dss.py
@@ -47,7 +47,7 @@ def test_dss1(show=True):
n_samples = 300
data, source = create_line_data(n_samples=n_samples, fline=.01)
- todss, _, pwr0, pwr1 = dss.dss1(data, weights=None, )
+ todss, _, pwr0, pwr1 = dss.dss1(data, weights=None)
z = fold(np.dot(unfold(data), todss), epoch_size=n_samples)
best_comp = np.mean(z[:, 0, :], -1)
@@ -102,6 +102,7 @@ def _plot(x):
ax[1].set_title("after")
plt.show()
+
# 2D case, n_outputs == 1
out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep)
_plot(out)
@@ -122,7 +123,7 @@ def _plot(x):
artifact = artifact ** 3
s = x + 10 * artifact
out, _ = dss.dss_line(s, fline, sr, nremove=1)
-
+ plt.close("all")
def test_dss_line_iter():
"""Test line noise removal."""
@@ -171,6 +172,7 @@ def _plot(before, after):
x, _ = create_line_data(n_samples, n_chans=n_chans, n_trials=2,
noise_dim=10, SNR=2, fline=fline / sr)
out, _ = dss.dss_line_iter(x, fline, sr, show=False)
+ plt.close("all")
def profile_dss_line(nkeep):
diff --git a/tests/test_filters.py b/tests/test_filters.py
index 93f42341..8531b90c 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -43,15 +43,15 @@ def test_noisy_signal(show=True):
ax[0, 0].set_title("Signal and its Hilbert phase")
ax[1, 0].plot(time, gtp, lw=.75, label="Ground truth")
- ax[1, 0].plot(time, nrp, lw=.75, label=r"$\phi_N$")
+ ax[1, 0].plot(time, nrp, lw=.75, label=r"$\\phi_N$")
ax[1, 0].set_ylabel(r"$\phi_N$")
ax[1, 0].set_ylim([-np.pi, np.pi])
ax[1, 0].set_title("Nonresonant oscillator")
ax[2, 0].plot(time, gtp, lw=.75, label="Ground truth")
- ax[2, 0].plot(time, rp, lw=.75, label=r"$\phi_N$")
+ ax[2, 0].plot(time, rp, lw=.75, label=r"$\\phi_N$")
ax[2, 0].set_ylim([-np.pi, np.pi])
- ax[2, 0].set_ylabel("$\phi_H - \phi_R$")
+ ax[2, 0].set_ylabel(r"$\phi_H - \phi_R$")
ax[2, 0].set_xlabel("Time")
ax[2, 0].set_title("Resonant oscillator")
@@ -132,7 +132,7 @@ def test_all_alg(show=False):
ax[3, 0].plot(time, r_phi_dif, lw=.75)
ax[3, 0].axhline(0, color="k", ls=":", zorder=-1)
ax[3, 0].set_ylim([-np.pi, np.pi])
- ax[3, 0].set_ylabel("$\phi_H - \phi_R$")
+ ax[3, 0].set_ylabel(r"$\phi_H - \phi_R$")
ax[3, 0].set_xlabel("Time")
ax[3, 0].set_title("Resonant oscillator")
@@ -216,5 +216,5 @@ def generate_noisy_signal(npt=40000, fs=100, noise=0.1):
if __name__ == "__main__":
# Run the model_data_all_alg function
- test_all_alg()
- # test_noisy_signal()
\ No newline at end of file
+ test_all_alg(True)
+ # test_noisy_signal(True)
\ No newline at end of file
diff --git a/tests/test_signal.py b/tests/test_signal.py
index 9e9688ec..c0e654c3 100644
--- a/tests/test_signal.py
+++ b/tests/test_signal.py
@@ -1,7 +1,9 @@
"""Test signal utils."""
+import matplotlib.pyplot as plt
import numpy as np
-from scipy.signal import butter, freqz, lfilter
+from scipy.signal import butter, freqz, hilbert, lfilter
+from meegkit.phase import ECHT
from meegkit.utils.sig import stmcb, teager_kaiser
rng = np.random.default_rng(9)
@@ -51,6 +53,92 @@ def test_stcmb(show=True):
np.testing.assert_allclose(h, hh, rtol=2) # equal to 2%
+
+def test_echt(show=False):
+ """Test Endpoint-corrected Hilbert transform (ECHT) phase estimation."""
+ rng = np.random.default_rng(38872)
+
+ # Build data
+ # -------------------------------------------------------------------------
+ # First, we generate a multi-component signal with amplitude and phase
+ # modulations, as described in the paper [1]_.
+ f0 = 2
+ filt_BW = f0 / 2
+ N = 500
+ sfreq = 200
+ time = np.linspace(0, N / sfreq, N)
+ X = np.cos(2 * np.pi * f0 * time - np.pi / 4)
+ phase_true = np.angle(hilbert(X))
+ X += rng.normal(0, 0.5, N) # Add noise
+
+ # Compute phase and amplitude
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # We compute the Hilbert phase, as well as the phase obtained with the ECHT
+ # filter.
+ # phase_hilbert = np.angle(hilbert(X)) # Hilbert phase
+
+ # Compute ECHT-filtered signal
+ l_freq = f0 - filt_BW / 2
+ h_freq = f0 + filt_BW / 2
+ echt = ECHT(l_freq, h_freq, sfreq)
+
+ Xf = echt.fit_transform(X)
+ phase_echt = np.angle(Xf).squeeze()
+ # phase_true = np.roll(phase_true, 1)
+ if show:
+ fig, ax = plt.subplots(3, 1, figsize=(8, 5))
+ ax[0].plot(time, X)
+ ax[0].set_xlabel("Time (s)")
+ ax[0].set_title("Test signal")
+ ax[0].set_ylabel("Amplitude")
+
+ ax[1].plot(time, phase_true, label="True phase", ls=":")
+ ax[1].plot(time, phase_echt, label="ECHT phase", lw=.5, alpha=.8)
+ ax[1].set_title("Phase")
+ ax[1].set_ylabel("Amplitude")
+ ax[1].set_xlabel("Time (s)")
+ ax[1].legend(loc="upper right", fontsize="small")
+
+ ax[2].plot(time, np.unwrap(phase_true - phase_echt.squeeze()),
+ label="Phase error")
+ ax[2].set_title("Phase error")
+ ax[2].set_ylabel("Error")
+
+ plt.tight_layout()
+ plt.show()
+
+ mae = (np.abs(np.unwrap(phase_true - phase_echt.squeeze())) > np.pi / 6).sum() / N
+ assert mae < 0.1, mae
+
+def test_echt_nd():
+ """Test ECHT with ndim > 1."""
+ rng = np.random.default_rng(38872)
+
+ # Build data
+ # -------------------------------------------------------------------------
+ # First, we generate a multi-component signal with amplitude and phase
+ # modulations, as described in the paper [1]_.
+ f0 = 2
+ filt_BW = f0 / 2
+ N = 500
+ sfreq = 200
+ time = np.linspace(0, N / sfreq, N)
+ X = np.cos(2 * np.pi * f0 * time - np.pi / 4)[:, None]
+ X = X + rng.normal(0, 0.5, (N, 2)) # Add noise
+
+ l_freq = f0 - filt_BW / 2
+ h_freq = f0 + filt_BW / 2
+ echt = ECHT(l_freq, h_freq, sfreq)
+
+ Xf = echt.fit_transform(X)
+ phase_echt = np.angle(Xf)
+
+ assert phase_echt.shape == (N, 2)
+
+
if __name__ == "__main__":
- test_teager_kaiser()
- test_stcmb()
+ import pytest
+ pytest.main([__file__])
+ # test_teager_kaiser()
+ # test_stcmb()
+ # test_echt()
\ No newline at end of file