Skip to content

Commit

Permalink
fix offset and simplify template handling (#1)
Browse files Browse the repository at this point in the history
* fix offset and simplify template handling

* bump version
  • Loading branch information
lucasplagwitz authored Sep 12, 2024
1 parent 5037233 commit 48f8185
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 265 deletions.
27 changes: 17 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,28 @@ normalizer = rlign.Rlign(scale_method='hrc')
# Input shape has to be (samples, channels, len)
ecg_aligned = normalizer.transform(ecg)

# You can update some configurations later on
template_ = rlign.Template(template_bpm=80)
normalizer.update_configuration(template=template_)
# You can set different configuration like median_beat-averaging or the template_bpm
normalizer = rlign.Rlign(scale_method='hrc', median_beat=True, template_bpm=80)

ecg_aligned_80hz = normalizer.transform(ecg)
```

### Configurations

* `sampling_rate`: Defines the sampling rate for all ECG recordings.
* `sampling_rate`: Defines the sampling rate for all ECG recordings and the template. Default is set to 500.

* `template`: A template ECG created with `create_template()` method. This template is
used as a reference for aligning R-peaks in the ECG signals.
* `seconds_len`: Determines the duration of all ECG recordings and the template in seconds.

* `select_lead`: Specifies the lead (e.g., 'Lead II', 'Lead V1') for R-peak detection. Different leads can provide varying levels of
clarity for these features. Selection via channel numbers 0,1,... .
* `template_bpm`: The desired normalized BPM value for the template.
This parameter sets the heart rate around which the QRST pattern
is structured, thereby standardizing the R-peak positions according to a specific BPM.

* `offset`: The offset specifies the starting point for the first normalized QRS complex in the
template. In percentage of sampling_rate. Default is set to 0.5.

* `select_lead`: Specifies the lead (e.g., 'Lead II', 'Lead V1') for R-peak detection.
Different leads can provide varying levels of clarity for these features.
Selection via channel numbers 0,1,... .

* `num_workers`: Determines the number of CPU cores to be utilized for
parallel processing. Increasing this number can speed up computations
Expand All @@ -65,15 +71,16 @@ ecg_aligned_80hz = normalizer.transform(ecg)
* `scale_method`: Selects the scaling method from options like 'resampling'
or 'hrc'. This choice dictates the interval used for resampling
the ECG signal, which can impact the quality of the processed signal.
Default is 'hrc'.

* `remove_fails`: Determines the behavior when scaling is not possible. If
set to True, the problematic ECG is excluded from the dataset. If False,
the original, unscaled ECG signal is returned instead.
the original, unscaled ECG signal is returned instead. Default is False.

* `median_beat`: Calculates the median from a set of aligned beats
and returns a single, representative beat.

* `silent`: Disable all warnings.
* `silent`: Disable all warnings. Default True.

## Citation
Please use the following citation:
Expand Down
190 changes: 93 additions & 97 deletions examples/example.ipynb

Large diffs are not rendered by default.

211 changes: 105 additions & 106 deletions examples/example_hz.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ requires = [

[project]
name='rlign'
version='1.0.post1'
version='1.0.post2'
description='Beat normalization for ECG data.'
readme = "README.md"
requires-python = ">=3.11"
Expand Down
116 changes: 82 additions & 34 deletions rlign/rlign.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@ class Rlign(BaseEstimator, TransformerMixin, auto_wrap_output_keys=None):
ECG, lead selection, and artifact correction.
Parameters:
sampling_rate: Defines the sampling rate for all ECG recordings.
sampling_rate: Defines the sampling rate for all ECG recordings and the template.
template: Path or identifier for the template ECG. This template is
used as a reference for identifying R-peaks in the ECG signals.
seconds_len: Determines the duration of all ECG recordings and the template in seconds.
template_bpm: The desired normalized BPM value for the template. This parameter sets the
heart rate around which the QRST pattern is structured, thereby standardizing the R-peak
positions according to a specific BPM.
offset: The offset specifies the starting point for the first normalized QRS complex in the
template. In percentage of sampling_rate. Default is set to 0.01.
select_lead: Specifies the lead (e.g., 'Lead II', 'Lead V1') for R-peak
and QRST point detection. Different leads can provide varying levels of
clarity for these features. Selection via channel numbers 0,1,... .
num_workers: Determines the number of CPU cores to be utilized for
parallel processing. Increasing this number can speed up computations
but requires more system resources.
but requires more system resources. Default is set to 4.
neurokit_method: Chooses the algorithm for R-peak detection from the
NeuroKit package. Different algorithms may offer varying performance
Expand All @@ -44,48 +50,57 @@ class Rlign(BaseEstimator, TransformerMixin, auto_wrap_output_keys=None):
scale_method: Selects the scaling method from options 'identity', 'linear'
or 'hrc'. This choice dictates the interval used for resampling
the ECG signal, which can impact the quality of the processed signal.
Default is 'hrc'.
remove_fails: Determines the behavior when scaling is not possible. If
set to True, the problematic ECG is excluded from the dataset. If False,
the original, unscaled ECG signal is returned instead.
the original, unscaled ECG signal is returned instead. Default is False.
median_beat: Calculates the median from a set of aligned beats and returns
a single, representative beat.
a single, representative beat. Default is False.
silent: Disable all warnings.
silent: Disable all warnings. Default True.
"""

__allowed = ("sampling_rate", "template", "resample_method",
"select_lead", "num_workers", "neurokit_method",
"correct_artifacts", "scale_method")
__allowed = ("sampling_rate", "resample_method", "select_lead",
"num_workers", "neurokit_method", "correct_artifacts",
"scale_method", "seconds_len", "template_bpm", "offset", "silent")

def __init__(
self,
sampling_rate: int = 500,
template: Template = None,
seconds_len: int = 10,
template_bpm: int = 60,
offset: float = .01,
select_lead: int = 1,
num_workers: int = 8, # only applies for multiprocessing
num_workers: int = 4,
neurokit_method: str = 'neurokit',
correct_artifacts: bool = True,
scale_method: str = 'hrc',
remove_fails: bool = False,
median_beat: bool = False,
silent: bool = True
):
self.sampling_rate = sampling_rate
if template:
self.template = template
else:
self.template = Template(
sampling_rate=sampling_rate
)

self._sampling_rate = sampling_rate
self._offset = offset
self._template_bpm = template_bpm
self._seconds_len = seconds_len

self.template = Template(
sampling_rate=self.sampling_rate,
offset=self.offset,
template_bpm=self.template_bpm,
seconds_len=self.seconds_len
)
self.select_lead = select_lead
self.num_workers = num_workers
self.neurokit_method = neurokit_method
self.correct_artifacts = correct_artifacts
self.remove_fails = remove_fails
self.fails = []
self.median_beat = median_beat
self.silent = silent

available_scale_methods = ['identity', 'linear', 'hrc']
if scale_method in available_scale_methods:
Expand All @@ -97,13 +112,10 @@ def __init__(
raise ValueError(f'No such scaling method, '
f'please use one of the following: {available_scale_methods}')

if silent:
if self.silent:
warnings.filterwarnings("ignore")

def update_configuration(
self,
**kwargs
):
def update_configuration(self, **kwargs):
"""
Enables modification of existing configuration settings as required.
Expand All @@ -116,6 +128,50 @@ def update_configuration(
assert (k in self.__class__.__allowed), f"Disallowed keyword passed: {k}"
setattr(self, k, kwargs[k])

@property
def offset(self):
return self._offset

@property
def template_bpm(self):
return self._template_bpm

@property
def seconds_len(self):
return self._seconds_len

@property
def sampling_rate(self):
return self._sampling_rate

@offset.setter
def offset(self, val):
self._offset = val
self._update_template()

@sampling_rate.setter
def sampling_rate(self, val):
self._sampling_rate = val
self._update_template()

@seconds_len.setter
def seconds_len(self, val):
self._seconds_len = val
self._update_template()

@template_bpm.setter
def template_bpm(self, val):
self._template_bpm = val
self._update_template()

def _update_template(self):
self.template = Template(
sampling_rate=self.sampling_rate,
offset=self.offset,
template_bpm=self.template_bpm,
seconds_len=self.seconds_len
)

def _normalize_interval(
self,
source_ecg: np.ndarray,
Expand Down Expand Up @@ -182,7 +238,6 @@ def _normalize_interval(
return source_ecg[:, :self.template.intervals[0]], 1

case "hrc":

template_rr_dist = template_starts[2] - template_starts[1]
dist_upper_template = int((self.template.bpm / 280+0.14) * template_rr_dist)
dist_lower_template = int((-self.template.bpm / 330 + 0.96) * template_rr_dist)
Expand Down Expand Up @@ -272,10 +327,7 @@ def fit(self, X: np.ndarray, y=None):
"""
return self

def _template_transform(
self,
ecg: np.ndarray,
):
def _template_transform(self, ecg: np.ndarray):
"""
Normalizes the ECG by recentering R-peaks at equally spaced intervals, as defined in the template.
The QRST-complexes are added, and the baselines are interpolated to match the new connections between complexes.
Expand Down Expand Up @@ -329,11 +381,7 @@ def _template_transform(
hr=hr
)

def transform(
self,
X: np.ndarray,
y=None
) -> np.ndarray:
def transform(self, X: np.ndarray, y=None) -> np.ndarray:
"""
Normalizes and returns the ECGs with multiprocessing.
Expand Down
11 changes: 7 additions & 4 deletions rlign/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Template(object):
def __init__(self,
seconds_len: float = 10,
sampling_rate: int = 500,
template_bpm: int = 40,
template_bpm: int = 60,
offset: float = 0.5):
"""
Creates a binary template pattern for an ECG, specifically designed for positioning the QRST interval.
Expand All @@ -38,21 +38,22 @@ def __init__(self,
positions according to a specific BPM.
offset: The offset specifies the starting point for the first normalized QRS complex in the
template.
template. In percentage of sampling_rate.
Note:
This class is integral to the ECG analysis process, providing a foundational template for
subsequent signal processing and interpretation.
"""

if not (0 <= offset < 1):
raise ValueError("Parameter offset must be between 0 and 1.")
offset = int(offset * sampling_rate)
total_len = seconds_len * sampling_rate
template_bpm_len = (template_bpm / 60) * seconds_len
self.bpm = template_bpm

# compute spacing between rpeaks
template_intervals = [int((total_len - offset) / template_bpm_len)] * int(template_bpm_len)
template_intervals = [int(total_len / template_bpm_len)] * int(template_bpm_len)

# get equally spaced r-peak positions
template_rpeaks = np.cumsum([offset, *template_intervals])
Expand Down Expand Up @@ -111,6 +112,8 @@ def find_rpeaks(
logging.warning(f'Failure in neurokit: {e}\n')
return None, None

if not len(rpeaks):
return None, None
return rpeaks


Expand Down
20 changes: 7 additions & 13 deletions test/test_rlign.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@ def test_multi_processing(self):
print(self.X.shape)


normalizer_single_cpu = rlign.Rlign(num_workers=1, select_lead=0)
normalizer_multiple_cpu = rlign.Rlign(num_workers=8, select_lead=0)
template_ = rlign.Template(template_bpm=40)
normalizer_single_cpu.update_configuration(template=template_)
normalizer_multiple_cpu.update_configuration(template=template_)
normalizer_single_cpu = rlign.Rlign(num_workers=1, select_lead=0, template_bpm=40)
normalizer_multiple_cpu = rlign.Rlign(num_workers=8, select_lead=0, template_bpm=40)

start_time = time.time()
normalizer_single_cpu.transform(self.X)
Expand All @@ -41,21 +38,18 @@ def test_multi_processing(self):
self.assertLess(diff_multiple, diff_single)

def test_scale_method(self):
normalizer_hrc = rlign.Rlign(num_workers=1, select_lead=0, scale_method="hrc")
template_ = rlign.Template(template_bpm=40)
normalizer_hrc.update_configuration(template=template_)
normalizer_hrc = rlign.Rlign(num_workers=1, select_lead=0, scale_method="hrc", template_bpm=40)
X_trans = normalizer_hrc.transform(self.X[:10])
self.assertEquals(X_trans.shape, (10, 1, 5000))

with self.assertRaises(ValueError):
rlign.Rlign(num_workers=1, select_lead=0, scale_method="equal")

def test_zero(self):
normalizer_hrc = rlign.Rlign(num_workers=1, select_lead=0, scale_method="hrc", remove_fails=False)
template_ = rlign.Template(template_bpm=40)
normalizer_hrc.update_configuration(template=template_)
X_trans = normalizer_hrc.transform(np.zeros((1,12,5000)))
self.assertTrue(np.array_equal(X_trans, np.zeros((1,12,5000))))
normalizer_hrc = rlign.Rlign(num_workers=1, select_lead=0, scale_method="hrc",
remove_fails=False, template_bpm=40)
X_trans = normalizer_hrc.transform(np.zeros((1, 12, 5000)))
self.assertTrue(np.array_equal(X_trans, np.zeros((1, 12, 5000))))

a = np.concatenate([np.zeros((1, 1, 5000)), self.X[:10]], axis=0)
X_trans = normalizer_hrc.transform(a)
Expand Down
10 changes: 10 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ def test_find_rpeaks(self):
ret = find_rpeaks(0*self.ecg_500hz_10s, sampling_rate=500)
self.assertIsNone(ret[0])
self.assertIsNone(ret[1])

def test_interval_length(self):
offset = .5
sampling_rate = 500
template60 = Template(seconds_len=10, sampling_rate=sampling_rate, template_bpm=60, offset=offset)
self.assertEquals(np.min(template60.intervals), sampling_rate)
self.assertEquals(np.max(template60.intervals), sampling_rate)
self.assertEquals(np.min(template60.rpeaks), int(offset*sampling_rate))
self.assertEquals(np.min(np.diff(template60.rpeaks)), sampling_rate)
self.assertEquals(np.max(np.diff(template60.rpeaks)), sampling_rate)

0 comments on commit 48f8185

Please sign in to comment.