From fb6b8b841de8c0c7d3f4f4c05954448228b8cb62 Mon Sep 17 00:00:00 2001 From: "Sean M. Law" <7473521+seanlaw@users.noreply.github.com> Date: Sat, 6 Jul 2024 16:16:00 -0400 Subject: [PATCH] Fixed #772 Add Ray Support (#986) * Initial commit for adding ray support * Fixed missing import * Fixed typo * Refactored coverage reporting * Force codecov to use coverage.py report instead of generating its own * Specify coverage.xml in PWD * Added verbose flag * Minor change * Added new codecov token * Added temporary break in test converage * Minor change * Changed name of coverage file * Expand coverage to all tests, removed break * Removed comments in workflow * Removed codecov patch/project status * Minor change * Added ability to specify coverage.xml file * Split xml report and displaying report * Updated docstrings to include `ray` example * Fixed flake8 problem * Added ray docstring examples * Fixed typo * Reverted missing docstring * Minor change * Reverted docstring * Minor changes * Added ray.shutdown() * Changed how `step` is calculated to be more precise * Fixed missing paratheses --- .github/workflows/github-actions.yml | 8 +- .gitignore | 3 +- codecov.yml | 4 + conda.sh | 15 ++ min.py => min_versions.py | 0 ray_python_version.py | 17 ++ stumpy/aamp_stimp.py | 12 +- stumpy/aamped.py | 163 ++++++++++++++++++-- stumpy/core.py | 87 ++++++++++- stumpy/maamped.py | 175 +++++++++++++++++++-- stumpy/mpdist.py | 16 +- stumpy/mstumped.py | 196 ++++++++++++++++++++++- stumpy/ostinato.py | 17 +- stumpy/stimp.py | 10 ++ stumpy/stumped.py | 215 ++++++++++++++++++++++++-- test.sh | 107 +++++++++++-- tests/test_ray.py | 222 +++++++++++++++++++++++++++ 17 files changed, 1200 insertions(+), 67 deletions(-) create mode 100644 codecov.yml rename min.py => min_versions.py (100%) create mode 100755 ray_python_version.py create mode 100644 tests/test_ray.py diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index 555e9742b..69d4e35f7 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -152,8 +152,12 @@ jobs: - name: Run Coverage Tests run: ./test.sh coverage shell: bash - - name: Check Coverage Report - run: coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min.py,stumpy/cache.py + - name: Generate Coverage Report + run: ./test.sh report coverage.stumpy.xml shell: bash - name: Upload Coverage Tests Results uses: codecov/codecov-action@v4 + with: + file: ./coverage.stumpy.xml + verbose: true + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index aaa73b8ac..4ba903a59 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ LOG* PID .coverage* coverage.xml +stumpy.coverage.xml dask-worker-space stumpy.egg-info build @@ -20,4 +21,4 @@ docs/_build .mypy_cache .directory test.py -*.nbconvert.ipynb \ No newline at end of file +*.nbconvert.ipynb diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..35cde5cd5 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,4 @@ +coverage: + status: + project: off + patch: off diff --git a/conda.sh b/conda.sh index 29721f6bc..cf39c48f6 100755 --- a/conda.sh +++ b/conda.sh @@ -13,6 +13,9 @@ if [[ $# -gt 0 ]]; then if [ $1 == "min" ]; then install_mode="min" echo "Installing minimum dependencies with install_mode=\"min\"" + elif [[ $1 == "ray" ]]; then + install_mode="ray" + echo "Installing ray dependencies with install_mode=\"ray\"" elif [[ $1 == "numba" ]] && [[ "${arch_name}" != "arm64" ]]; then install_mode="numba" echo "Installing numba release candidate dependencies with install_mode=\"numba\"" @@ -57,6 +60,14 @@ generate_numba_environment_yaml() grep -Ev "numba|python" environment.yml > environment.numba.yml } +generate_ray_environment_yaml() +{ + # Limit max Python version and append pip install ray + echo "Generating \"environment.ray.yml\" File" + ray_python=`./ray_python_version.py` + sed "/ - python/ s/$/,<=$ray_python/" environment.yml | cat - <(echo $' - pip\n - pip:\n - ray>=2.23.0') > environment.ray.yml +} + fix_libopenblas() { if [ ! -f $CONDA_PREFIX/lib/libopenblas.dylib ]; then @@ -71,6 +82,7 @@ clean_up() echo "Cleaning Up" rm -rf "environment.min.yml" rm -rf "environment.numba.yml" + rm -rf "environment.ray.yml" } ########### @@ -92,6 +104,9 @@ fi if [[ $install_mode == "min" ]]; then generate_min_environment_yaml mamba env update --name $conda_env --file environment.min.yml || conda env update --name $conda_env --file environment.min.yml +elif [[ $install_mode == "ray" ]]; then + generate_ray_environment_yaml + mamba env update --name $conda_env --file environment.ray.yml || conda env update --name $conda_env --file environment.ray.yml elif [[ $install_mode == "numba" ]]; then echo "" echo "Installing python=$python_version" diff --git a/min.py b/min_versions.py similarity index 100% rename from min.py rename to min_versions.py diff --git a/ray_python_version.py b/ray_python_version.py new file mode 100755 index 000000000..f175443eb --- /dev/null +++ b/ray_python_version.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +import requests +from packaging.version import Version + +classifiers = ( + requests.get("https://pypi.org/pypi/ray/json").json().get("info").get("classifiers") +) + +versions = [] +for c in classifiers: + x = c.split() + if "Python" in x: + versions.append(x[-1]) + +versions.sort(key=Version) +print(versions[-1]) diff --git a/stumpy/aamp_stimp.py b/stumpy/aamp_stimp.py index 81313e7d4..cf15725a8 100644 --- a/stumpy/aamp_stimp.py +++ b/stumpy/aamp_stimp.py @@ -486,16 +486,15 @@ def __init__( class aamp_stimped(_aamp_stimp): """ - Compute the Pan Matrix Profile with a distributed dask cluster + Compute the Pan Matrix Profile with a `dask`/`ray` cluster This is based on the SKIMP algorithm. Parameters ---------- client : client - A Dask or Ray Distributed client. Setting up a distributed cluster is beyond - the scope of this library. Please refer to the Dask or Ray Distributed - documentation. + A `dask`/`ray` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `dask`/`ray` documentation. T : numpy.ndarray The time series or sequence for which to compute the pan matrix profile @@ -556,9 +555,8 @@ def __init__( Parameters ---------- client : client - A Dask or Ray Distributed client. Setting up a distributed cluster is beyond - the scope of this library. Please refer to the Dask or Ray Distributed - documentation. + A `dask`/`ray` client. Setting up a cluster is beyond the scope of this + library. Please refer to the `dask`/`ray` documentation. T : numpy.ndarray The time series or sequence for which to compute the pan matrix profile diff --git a/stumpy/aamped.py b/stumpy/aamped.py index 5eb7f9d56..21132f281 100644 --- a/stumpy/aamped.py +++ b/stumpy/aamped.py @@ -24,7 +24,7 @@ def _dask_aamped( ): """ Compute the non-normalized (i.e., without z-normalization) matrix profile with a - distributed dask cluster + `dask` cluster This is a highly distributed implementation around the Numba JIT-compiled parallelized `_aamp` function which computes the non-normalized matrix profile @@ -33,17 +33,15 @@ def _dask_aamped( Parameters ---------- dask_client : client - A Dask Distributed client. Setting up a distributed cluster is beyond - the scope of this library. Please refer to the Dask Distributed - documentation. + A `dask` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `dask` documentation. T_A : numpy.ndarray The time series or sequence for which to compute the matrix profile T_B : numpy.ndarray The time series or sequence that will be used to annotate T_A. For every - subsequence in T_A, its nearest neighbor in T_B will be recorded. Default is - `None` which corresponds to a self-join. + subsequence in T_A, its nearest neighbor in T_B will be recorded. m : int Window size @@ -159,9 +157,157 @@ def _dask_aamped( return out +def _ray_aamped( + ray_client, + T_A, + T_B, + m, + T_A_subseq_isfinite, + T_B_subseq_isfinite, + p, + diags, + ignore_trivial, + k, +): + """ + Compute the non-normalized (i.e., without z-normalization) matrix profile with a + `ray` cluster + + This is a highly distributed implementation around the Numba JIT-compiled + parallelized `_aamp` function which computes the non-normalized matrix profile + according to AAMP. + + Parameters + ---------- + ray_client : client + A `ray` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `ray` documentation. + + T_A : numpy.ndarray + The time series or sequence for which to compute the matrix profile + + T_B : numpy.ndarray + The time series or sequence that will be used to annotate T_A. For every + subsequence in T_A, its nearest neighbor in T_B will be recorded. + + m : int + Window size + + T_A_subseq_isfinite : numpy.ndarray + A boolean array that indicates whether a subsequence in `T_A` contains a + `np.nan`/`np.inf` value (False) + + T_B_subseq_isfinite : numpy.ndarray + A boolean array that indicates whether a subsequence in `T_B` contains a + `np.nan`/`np.inf` value (False) + + p : float + The p-norm to apply for computing the Minkowski distance. Minkowski distance is + typically used with `p` being 1 or 2, which correspond to the Manhattan distance + and the Euclidean distance, respectively. + + diags : numpy.ndarray + The diagonal indices + + ignore_trivial : bool, default True + Set to `True` if this is a self-join. Otherwise, for AB-join, set this + to `False`. Default is `True`. + + k : int, default 1 + The number of top `k` smallest distances used to construct the matrix profile. + Note that this will increase the total computational time and memory usage + when k > 1. If you have access to a GPU device, then you may be able to + leverage `gpu_stump` for better performance and scalability. + + Returns + ------- + out : numpy.ndarray + When k = 1 (default), the first column consists of the matrix profile, + the second column consists of the matrix profile indices, the third column + consists of the left matrix profile indices, and the fourth column consists + of the right matrix profile indices. However, when k > 1, the output array + will contain exactly 2 * k + 2 columns. The first k columns (i.e., out[:, :k]) + consists of the top-k matrix profile, the next set of k columns + (i.e., out[:, k:2k]) consists of the corresponding top-k matrix profile + indices, and the last two columns (i.e., out[:, 2k] and out[:, 2k+1] or, + equivalently, out[:, -2] and out[:, -1]) correspond to the top-1 left + matrix profile indices and the top-1 right matrix profile indices, respectively. + """ + core.check_ray(ray_client) + + n_A = T_A.shape[0] + n_B = T_B.shape[0] + l = n_A - m + 1 + + nworkers = core.get_ray_nworkers(ray_client) + + ndist_counts = core._count_diagonal_ndist(diags, m, n_A, n_B) + diags_ranges = core._get_array_ranges(ndist_counts, nworkers, False) + diags_ranges += diags[0] + + # Scatter data to Ray cluster + T_A_ref = ray_client.put(T_A) + T_B_ref = ray_client.put(T_B) + T_A_subseq_isfinite_ref = ray_client.put(T_A_subseq_isfinite) + T_B_subseq_isfinite_ref = ray_client.put(T_B_subseq_isfinite) + + diags_refs = [] + for i in range(nworkers): + diags_ref = ray_client.put( + np.arange(diags_ranges[i, 0], diags_ranges[i, 1], dtype=np.int64) + ) + diags_refs.append(diags_ref) + + ray_aamp_func = ray_client.remote(core.deco_ray_tor(_aamp)) + + refs = [] + for i in range(nworkers): + refs.append( + ray_aamp_func.remote( + T_A_ref, + T_B_ref, + m, + T_A_subseq_isfinite_ref, + T_B_subseq_isfinite_ref, + p, + diags_refs[i], + ignore_trivial, + k, + ) + ) + + results = ray_client.get(refs) + # Must make a mutable copy from Ray's object store (ndarrays are immutable) + profile, profile_L, profile_R, indices, indices_L, indices_R = [ + arr.copy() for arr in results[0] + ] + + for i in range(1, nworkers): + P, PL, PR, I, IL, IR = results[i] # Read-only variables + # Update top-k matrix profile and matrix profile indices + core._merge_topk_PI(profile, P, indices, I) + + # Update top-1 left matrix profile and matrix profile index + mask = PL < profile_L + profile_L[mask] = PL[mask] + indices_L[mask] = IL[mask] + + # Update top-1 right matrix profile and matrix profile index + mask = PR < profile_R + profile_R[mask] = PR[mask] + indices_R[mask] = IR[mask] + + out = np.empty((l, 2 * k + 2), dtype=object) + out[:, :k] = profile + out[:, k : 2 * k + 2] = np.column_stack((indices, indices_L, indices_R)) + + return out + + def aamped(client, T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): """ Compute the non-normalized (i.e., without z-normalization) matrix profile + with a `dask`/`ray` cluster This is a highly distributed implementation around the Numba JIT-compiled parallelized `_aamp` function which computes the non-normalized matrix profile @@ -170,9 +316,8 @@ def aamped(client, T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): Parameters ---------- client : client - A Dask or Ray Distributed client. Setting up a distributed cluster is beyond - the scope of this library. Please refer to the Dask or Ray Distributed - documentation. + A `dask`/`ray` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `dask`/`ray` documentation. T_A : numpy.ndarray The time series or sequence for which to compute the matrix profile diff --git a/stumpy/core.py b/stumpy/core.py index d40cbc81c..af3bf7a3b 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -3733,10 +3733,10 @@ def _client_to_func(client): """ if client.__class__.__name__.startswith("Client"): prefix = "_dask_" - # elif inspect.ismodule(client) and str(client).startswith( - # "`__ + + discords : bool + When set to `True`, this reverses the distance profile to favor discords rather + than motifs. Note that indices in `include` are still maintained and respected. + + Returns + ------- + P : numpy.ndarray + The multi-dimensional matrix profile. Each row of the array corresponds + to each matrix profile for a given dimension (i.e., the first row is + the 1-D matrix profile and the second row is the 2-D matrix profile). + + I : numpy.ndarray + The multi-dimensional matrix profile index where each row of the array + corresponds to each matrix profile index for a given dimension. + """ + core.check_ray(ray_client) + + d, n = T_B.shape + l = n - m + 1 + P = np.empty((d, l), dtype=np.float64) + I = np.empty((d, l), dtype=np.int64) + + nworkers = core.get_ray_nworkers(ray_client) + + step = int(math.ceil(l / nworkers)) + + for start in range(0, l, step): + P[:, start], I[:, start] = _get_first_maamp_profile( + start, + T_A, + T_B, + m, + excl_zone, + T_B_subseq_isfinite, + p, + include, + discords, + ) + + # Put data into Ray object storage + T_A_ref = ray_client.put(T_A) + T_A_subseq_isfinite_ref = ray_client.put(T_A_subseq_isfinite) + T_B_subseq_isfinite_ref = ray_client.put(T_B_subseq_isfinite) + + p_norm_refs = [] + p_norm_first_refs = [] + + for start in range(0, l, step): + p_norm, p_norm_first = _get_multi_p_norm(start, T_A, m) + + p_norm_ref = ray_client.put(p_norm) + p_norm_first_ref = ray_client.put(p_norm_first) + + p_norm_refs.append(p_norm_ref) + p_norm_first_refs.append(p_norm_first_ref) + + ray_maamp_func = ray_client.remote(core.deco_ray_tor(_maamp)) + + refs = [] + for i, start in enumerate(range(0, l, step)): + stop = min(l, start + step) + + refs.append( + ray_maamp_func.remote( + T_A_ref, + m, + stop, + excl_zone, + T_A_subseq_isfinite_ref, + T_B_subseq_isfinite_ref, + p, + p_norm_refs[i], + p_norm_first_refs[i], + l, + start + 1, + include, + discords, + ) + ) + + results = ray_client.get(refs) + for i, start in enumerate(range(0, l, step)): + stop = min(l, start + step) + P[:, start + 1 : stop], I[:, start + 1 : stop] = results[i] + + return P, I + + def maamped(client, T, m, include=None, discords=False, p=2.0): """ Compute the multi-dimensional non-normalized (i.e., without z-normalization) matrix - profile with a distributed dask cluster + profile with a `dask`/`ray` cluster This is a highly distributed implementation around the Numba JIT-compiled parallelized `_maamp` function which computes the multi-dimensional matrix @@ -177,9 +329,8 @@ def maamped(client, T, m, include=None, discords=False, p=2.0): Parameters ---------- client : client - A Dask or Ray Distributed client. Setting up a distributed cluster is beyond - the scope of this library. Please refer to the Dask or Ray Distributed - documentation. + A `dask`/`ray` client. Setting up a cluster is beyond the scope of this + library. Please refer to the `dask`/`ray` documentation. T : numpy.ndarray The time series or sequence for which to compute the multi-dimensional diff --git a/stumpy/mpdist.py b/stumpy/mpdist.py index 74e1f3675..489caa9f3 100644 --- a/stumpy/mpdist.py +++ b/stumpy/mpdist.py @@ -356,13 +356,25 @@ def mpdisted( >>> import numpy as np >>> from dask.distributed import Client >>> if __name__ == "__main__": - ... with Client() as dask_client: - ... stumpy.mpdisted( + >>> with Client() as dask_client: + >>> stumpy.mpdisted( ... dask_client, ... np.array([-11.1, 23.4, 79.5, 1001.0]), ... np.array([584., -11., 23., 79., 1001., 0., -19.]), ... m=3) 0.00019935236191097894 + + Alternatively, you can also use `ray` + + >>> import ray + >>> if __name__ == "__main__": + >>> ray.init() + >>> stumpy.mpdisted( + ... ray, + ... np.array([-11.1, 23.4, 79.5, 1001.0]), + ... np.array([584., -11., 23., 79., 1001., 0., -19.]), + ... m=3) + >>> ray.shutdown() """ partial_mp_func = functools.partial( stumped, diff --git a/stumpy/mstumped.py b/stumpy/mstumped.py index 856a88844..5e369a20d 100644 --- a/stumpy/mstumped.py +++ b/stumpy/mstumped.py @@ -2,6 +2,8 @@ # Copyright 2019 TD Ameritrade. Released under the terms of the 3-Clause BSD license. # STUMPY is a trademark of TD Ameritrade IP Company, Inc. All rights reserved. +import math + import numpy as np from . import config, core @@ -25,8 +27,7 @@ def _dask_mstumped( discords, ): """ - Compute the multi-dimensional z-normalized matrix profile with a ``dask``/``ray`` - cluster + Compute the multi-dimensional z-normalized matrix profile with a `dask` cluster This is a highly distributed implementation around the Numba JIT-compiled parallelized `_mstump` function which computes the multi-dimensional matrix @@ -47,8 +48,7 @@ def _dask_mstumped( T_B : numpy.ndarray The time series or sequence that will be used to annotate T_A. For every - subsequence in T_A, its nearest neighbor in T_B will be recorded. Default is - `None` which corresponds to a self-join. + subsequence in T_A, its nearest neighbor in T_B will be recorded. m : int Window size @@ -107,9 +107,9 @@ def _dask_mstumped( hosts = list(dask_client.ncores().keys()) nworkers = len(hosts) - step = 1 + l // nworkers + step = int(math.ceil(l / nworkers)) - for i, start in enumerate(range(0, l, step)): + for start in range(0, l, step): P[:, start], I[:, start] = _get_first_mstump_profile( start, T_A, @@ -185,6 +185,177 @@ def _dask_mstumped( return P, I +def _ray_mstumped( + ray_client, + T_A, + T_B, + m, + excl_zone, + M_T, + Σ_T, + μ_Q, + σ_Q, + T_subseq_isconstant, + Q_subseq_isconstant, + include, + discords, +): + """ + Compute the multi-dimensional z-normalized matrix profile with a `ray` cluster + + This is a highly distributed implementation around the Numba JIT-compiled + parallelized `_mstump` function which computes the multi-dimensional matrix + profile according to STOMP. Note that only self-joins are supported. + + Parameters + ---------- + ray_client : client + A `ray` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `ray` documentation. + + T_A : numpy.ndarray + The time series or sequence for which to compute the multi-dimensional + matrix profile. Each row in `T_A` represents data from a different + dimension while each column in `T_A` represents data from the same + dimension. + + T_B : numpy.ndarray + The time series or sequence that will be used to annotate T_A. For every + subsequence in T_A, its nearest neighbor in T_B will be recorded. + + m : int + Window size + + excl_zone : int + The half width for the exclusion zone relative to the current + sliding window + + M_T : numpy.ndarray + Sliding mean of time series, `T` + + Σ_T : numpy.ndarray + Sliding standard deviation of time series, `T` + + μ_Q : numpy.ndarray + Mean of the query sequence, `Q`, relative to the current sliding window + + σ_Q : numpy.ndarray + Standard deviation of the query sequence, `Q`, relative to the current + sliding window + + T_subseq_isconstant : numpy.ndarray + A boolearn array representing Rolling isconstant for `T` + + Q_subseq_isconstant : numpy.ndarray + A boolearn array representing Rolling isconstant for `Q` + + include : numpy.ndarray + A list of (zero-based) indices corresponding to the dimensions in `T` that + must be included in the constrained multidimensional motif search. + For more information, see Section IV D in: + + `DOI: 10.1109/ICDM.2017.66 \ + `__ + + discords : bool + When set to `True`, this reverses the distance profile to favor discords rather + than motifs. Note that indices in `include` are still maintained and respected. + + Returns + ------- + P : numpy.ndarray + The multi-dimensional matrix profile. Each row of the array corresponds + to each matrix profile for a given dimension (i.e., the first row is + the 1-D matrix profile and the second row is the 2-D matrix profile). + + I : numpy.ndarray + The multi-dimensional matrix profile index where each row of the array + corresponds to each matrix profile index for a given dimension. + """ + core.check_ray(ray_client) + + d, n = T_B.shape + l = n - m + 1 + P = np.empty((d, l), dtype=np.float64) + I = np.empty((d, l), dtype=np.int64) + + nworkers = core.get_ray_nworkers(ray_client) + + step = int(math.ceil(l / nworkers)) + + for start in range(0, l, step): + P[:, start], I[:, start] = _get_first_mstump_profile( + start, + T_A, + T_B, + m, + excl_zone, + M_T, + Σ_T, + μ_Q, + σ_Q, + T_subseq_isconstant, + Q_subseq_isconstant, + include, + discords, + ) + + # Put data into Ray object storage + T_A_ref = ray_client.put(T_A) + M_T_ref = ray_client.put(M_T) + Σ_T_ref = ray_client.put(Σ_T) + μ_Q_ref = ray_client.put(μ_Q) + σ_Q_ref = ray_client.put(σ_Q) + T_subseq_isconstant_ref = ray_client.put(T_subseq_isconstant) + Q_subseq_isconstant_ref = ray_client.put(Q_subseq_isconstant) + + QT_refs = [] + QT_first_refs = [] + + for start in range(0, l, step): + QT, QT_first = _get_multi_QT(start, T_A, m) + + QT_ref = ray_client.put(QT) + QT_first_ref = ray_client.put(QT_first) + + QT_refs.append(QT_ref) + QT_first_refs.append(QT_first_ref) + + ray_mstump_func = ray_client.remote(core.deco_ray_tor(_mstump)) + + refs = [] + for i, start in enumerate(range(0, l, step)): + stop = min(l, start + step) + + refs.append( + ray_mstump_func.remote( + T_A_ref, + m, + stop, + excl_zone, + M_T_ref, + Σ_T_ref, + QT_refs[i], + QT_first_refs[i], + μ_Q_ref, + σ_Q_ref, + T_subseq_isconstant_ref, + Q_subseq_isconstant_ref, + l, + start + 1, + include, + discords, + ) + ) + + results = ray_client.get(refs) + for i, start in enumerate(range(0, l, step)): + stop = min(l, start + step) + P[:, start + 1 : stop], I[:, start + 1 : stop] = results[i] + + return P, I + + @core.non_normalized( maamped, exclude=["normalize", "T_subseq_isconstant"], @@ -291,6 +462,7 @@ def mstumped( >>> if __name__ == "__main__": ... with Client() as dask_client: ... stumpy.mstumped( + ... dask_client, ... np.array([[584., -11., 23., 79., 1001., 0., -19.], ... [ 1., 2., 4., 8., 16., 0., 32.]]), ... m=3) @@ -298,6 +470,18 @@ def mstumped( [0.777905 , 2.36179922, 1.50004632, 2.92246722, 0.777905 ]]), array([[2, 4, 0, 1, 0], [4, 4, 0, 1, 0]])) + + Alternatively, you can also use `ray` + + >>> import ray + >>> if __name__ == "__main__": + >>> ray.init() + >>> stumpy.mstumped( + ... ray, + ... np.array([[584., -11., 23., 79., 1001., 0., -19.], + ... [ 1., 2., 4., 8., 16., 0., 32.]]), + ... m=3) + >>> ray.shutdown() """ T_A = T T_B = T_A diff --git a/stumpy/ostinato.py b/stumpy/ostinato.py index ff4bb1ec3..424da525f 100644 --- a/stumpy/ostinato.py +++ b/stumpy/ostinato.py @@ -484,14 +484,27 @@ def ostinatoed(client, Ts, m, normalize=True, p=2.0, Ts_subseq_isconstant=None): >>> import numpy as np >>> from dask.distributed import Client >>> if __name__ == "__main__": - ... with Client() as dask_client: - ... stumpy.ostinatoed( + >>> with Client() as dask_client: + >>> stumpy.ostinatoed( ... dask_client, ... [np.array([584., -11., 23., 79., 1001., 0., 19.]), ... np.array([600., -10., 23., 17.]), ... np.array([ 1., 9., 6., 0.])], ... m=3) (1.2370237678153826, 0, 4) + + Alternatively, you can also use `ray` + + >>> import ray + >>> if __name__ == "__main__": + >>> ray.init() + >>> stumpy.ostinatoed( + ... ray, + ... [np.array([584., -11., 23., 79., 1001., 0., 19.]), + ... np.array([600., -10., 23., 17.]), + ... np.array([ 1., 9., 6., 0.])], + ... m=3) + >>> ray.shutdown() """ if not isinstance(Ts, list): # pragma: no cover raise ValueError(f"`Ts` is of type `{type(Ts)}` but a `list` is expected") diff --git a/stumpy/stimp.py b/stumpy/stimp.py index 4d0d4c64b..446b28a2a 100644 --- a/stumpy/stimp.py +++ b/stumpy/stimp.py @@ -644,6 +644,16 @@ class stimped(_stimp): ... pmp.PAN_ array([[0., 1., 1., 1., 1., 1., 1.], [0., 1., 1., 1., 1., 1., 1.]]) + + Alternatively, you can also use `ray` + + >>> import ray + >>> if __name__ == "__main__": + >>> ray.init() + >>> pmp = stumpy.stimped( + ... ray, + ... np.array([584., -11., 23., 79., 1001., 0., -19.])) + >>> ray.shutdown() """ def __init__( diff --git a/stumpy/stumped.py b/stumpy/stumped.py index 35866ca72..e5abc75c1 100644 --- a/stumpy/stumped.py +++ b/stumpy/stumped.py @@ -30,7 +30,7 @@ def _dask_stumped( k, ): """ - Compute the z-normalized (top-k) matrix profile with a distributed dask cluster + Compute the z-normalized (top-k) matrix profile with a `dask` cluster This is a highly distributed implementation around the Numba JIT-compiled parallelized `_stump` function which computes the (top-k) matrix profile according @@ -39,17 +39,15 @@ def _dask_stumped( Parameters ---------- dask_client : client - A Dask Distributed client. Setting up a distributed cluster is beyond - the scope of this library. Please refer to the Dask Distributed - documentation. + A `dask` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `dask` documentation. T_A : numpy.ndarray The time series or sequence for which to compute the matrix profile T_B : numpy.ndarray The time series or sequence that will be used to annotate T_A. For every - subsequence in T_A, its nearest neighbor in T_B will be recorded. Default is - `None` which corresponds to a self-join. + subsequence in T_A, its nearest neighbor in T_B will be recorded. m : int Window size @@ -157,7 +155,7 @@ def _dask_stumped( diags_futures.append(diags_future) futures = [] - for i in range(len(hosts)): + for i in range(nworkers): futures.append( dask_client.submit( _stump, @@ -183,7 +181,7 @@ def _dask_stumped( results = dask_client.gather(futures) profile, profile_L, profile_R, indices, indices_L, indices_R = results[0] - for i in range(1, len(hosts)): + for i in range(1, nworkers): P, PL, PR, I, IL, IR = results[i] # Update top-k matrix profile and matrix profile indices core._merge_topk_PI(profile, P, indices, I) @@ -205,6 +203,194 @@ def _dask_stumped( return out +def _ray_stumped( + ray_client, + T_A, + T_B, + m, + M_T, + μ_Q, + Σ_T_inverse, + σ_Q_inverse, + M_T_m_1, + μ_Q_m_1, + T_A_subseq_isfinite, + T_B_subseq_isfinite, + T_A_subseq_isconstant, + T_B_subseq_isconstant, + diags, + ignore_trivial, + k, +): + """ + Compute the z-normalized (top-k) matrix profile with a `ray` cluster + + This is a highly distributed implementation around the Numba JIT-compiled + parallelized `_stump` function which computes the (top-k) matrix profile according + to STOMPopt with Pearson correlations. + + Parameters + ---------- + ray_client : client + A `ray` client. Setting up a cluster is beyond the scope of this library. + Please refer to the `ray` documentation. + + T_A : numpy.ndarray + The time series or sequence for which to compute the matrix profile + + T_B : numpy.ndarray + The time series or sequence that will be used to annotate T_A. For every + subsequence in T_A, its nearest neighbor in T_B will be recorded. + + m : int + Window size + + M_T : numpy.ndarray + Sliding mean of time series, `T` + + μ_Q : numpy.ndarray + Mean of the query sequence, `Q`, relative to the current sliding window + + Σ_T_inverse : numpy.ndarray + Inverse sliding standard deviation of time series, `T` + + σ_Q_inverse : numpy.ndarray + Inverse standard deviation of the query sequence, `Q`, relative to the current + + M_T_m_1 : numpy.ndarray + Sliding mean of time series, `T`, using a window size of `m-1` + + μ_Q_m_1 : numpy.ndarray + Mean of the query sequence, `Q`, relative to the current sliding window and + using a window size of `m-1` + + T_A_subseq_isfinite : numpy.ndarray + A boolean array that indicates whether a subsequence in `T_A` contains a + `np.nan`/`np.inf` value (False) + + T_B_subseq_isfinite : numpy.ndarray + A boolean array that indicates whether a subsequence in `T_B` contains a + `np.nan`/`np.inf` value (False) + + T_A_subseq_isconstant : numpy.ndarray + A boolean array that indicates whether a subsequence in `T_A` is constant (True) + + T_B_subseq_isconstant : numpy.ndarray + A boolean array that indicates whether a subsequence in `T_B` is constant (True) + + diags : numpy.ndarray + The diagonal indices + + ignore_trivial : bool, default True + Set to `True` if this is a self-join. Otherwise, for AB-join, set this + to `False`. Default is `True`. + + k : int, default 1 + The number of top `k` smallest distances used to construct the matrix profile. + Note that this will increase the total computational time and memory usage + when k > 1. If you have access to a GPU device, then you may be able to + leverage `gpu_stump` for better performance and scalability. + + Returns + ------- + out : numpy.ndarray + When k = 1 (default), the first column consists of the matrix profile, + the second column consists of the matrix profile indices, the third column + consists of the left matrix profile indices, and the fourth column consists + of the right matrix profile indices. However, when k > 1, the output array + will contain exactly 2 * k + 2 columns. The first k columns (i.e., out[:, :k]) + consists of the top-k matrix profile, the next set of k columns + (i.e., out[:, k:2k]) consists of the corresponding top-k matrix profile + indices, and the last two columns (i.e., out[:, 2k] and out[:, 2k+1] or, + equivalently, out[:, -2] and out[:, -1]) correspond to the top-1 left + matrix profile indices and the top-1 right matrix profile indices, respectively. + """ + core.check_ray(ray_client) + + n_A = T_A.shape[0] + n_B = T_B.shape[0] + l = n_A - m + 1 + + nworkers = core.get_ray_nworkers(ray_client) + + ndist_counts = core._count_diagonal_ndist(diags, m, n_A, n_B) + diags_ranges = core._get_array_ranges(ndist_counts, nworkers, False) + diags_ranges += diags[0] + + # Put data in the Ray object store + T_A_ref = ray_client.put(T_A) + T_B_ref = ray_client.put(T_B) + M_T_ref = ray_client.put(M_T) + μ_Q_ref = ray_client.put(μ_Q) + Σ_T_inverse_ref = ray_client.put(Σ_T_inverse) + σ_Q_inverse_ref = ray_client.put(σ_Q_inverse) + M_T_m_1_ref = ray_client.put(M_T_m_1) + μ_Q_m_1_ref = ray_client.put(μ_Q_m_1) + T_A_subseq_isfinite_ref = ray_client.put(T_A_subseq_isfinite) + T_B_subseq_isfinite_ref = ray_client.put(T_B_subseq_isfinite) + T_A_subseq_isconstant_ref = ray_client.put(T_A_subseq_isconstant) + T_B_subseq_isconstant_ref = ray_client.put(T_B_subseq_isconstant) + + diags_refs = [] + for i in range(nworkers): + diags_ref = ray_client.put( + np.arange(diags_ranges[i, 0], diags_ranges[i, 1], dtype=np.int64), + ) + diags_refs.append(diags_ref) + + ray_stump_func = ray_client.remote(core.deco_ray_tor(_stump)) + + refs = [] + for i in range(nworkers): + refs.append( + ray_stump_func.remote( + T_A_ref, + T_B_ref, + m, + μ_Q_ref, + M_T_ref, + σ_Q_inverse_ref, + Σ_T_inverse_ref, + μ_Q_m_1_ref, + M_T_m_1_ref, + T_A_subseq_isfinite_ref, + T_B_subseq_isfinite_ref, + T_A_subseq_isconstant_ref, + T_B_subseq_isconstant_ref, + diags_refs[i], + ignore_trivial, + k, + ) + ) + + results = ray_client.get(refs) + # Must make a mutable copy from Ray's object store (ndarrays are immutable) + profile, profile_L, profile_R, indices, indices_L, indices_R = [ + arr.copy() for arr in results[0] + ] + + for i in range(1, nworkers): + P, PL, PR, I, IL, IR = results[i] # Read-only variables + # Update top-k matrix profile and matrix profile indices + core._merge_topk_PI(profile, P, indices, I) + + # Update top-1 left matrix profile and matrix profile index + mask = PL < profile_L + profile_L[mask] = PL[mask] + indices_L[mask] = IL[mask] + + # Update top-1 right matrix profile and matrix profile index + mask = PR < profile_R + profile_R[mask] = PR[mask] + indices_R[mask] = IR[mask] + + out = np.empty((l, 2 * k + 2), dtype=object) + out[:, :k] = profile + out[:, k:] = np.column_stack((indices, indices_L, indices_R)) + + return out + + @core.non_normalized(aamped) def stumped( client, @@ -371,7 +557,7 @@ def stumped( >>> from dask.distributed import Client >>> if __name__ == "__main__": ... with Client() as dask_client: - ... mp = stumpy.stumped( + ... stumpy.stumped( ... dask_client, ... np.array([584., -11., 23., 79., 1001., 0., -19.]), ... m=3) @@ -385,6 +571,17 @@ def stumped( mparray([0.11633857, 2.69407392, 3.00009263, 2.69407392, 0.11633857]) >>> mp.I_ mparray([4, 3, 0, 1, 0]) + + Alternatively, you can also use `ray` + + >>> import ray + >>> if __name__ == "__main__": + >>> ray.init() + >>> stumpy.stumped( + ... ray, + ... np.array([584., -11., 23., 79., 1001., 0., -19.]), + ... m=3) + >>> ray.shutdown() """ if T_B is None: T_B = T_A diff --git a/test.sh b/test.sh index c8630158c..ad29f9863 100755 --- a/test.sh +++ b/test.sh @@ -5,7 +5,7 @@ print_mode="verbose" custom_testfiles=() max_iter=10 site_pkgs=$(python -c 'import site; print(site.getsitepackages()[0])') - +fcoveragexml="coverage.stumpy.xml" # Parse command line arguments for var in "$@" do @@ -21,12 +21,16 @@ do test_mode="show" elif [[ $var == "custom" ]]; then test_mode="custom" + elif [[ $var == "report" ]]; then + test_mode="report" elif [[ $var == "silent" || $var == "print" ]]; then print_mode="silent" elif [[ "$var" == *"test_"*".py"* ]]; then custom_testfiles+=("$var") elif [[ $var =~ ^[\-0-9]+$ ]]; then max_iter=$var + elif [[ "$var" == *".xml" ]]; then + fcoveragexml=$var elif [[ "$var" == "links" ]]; then test_mode="links" else @@ -101,6 +105,55 @@ check_naive() done } +check_ray() +{ + if ! command -v ray &> /dev/null + then + echo "Ray Not Installed" + else + echo "Ray Installed" + fi +} + +gen_ray_coveragerc() +{ + # Generate a .coveragerc_ray file that excludes Ray functions and tests + echo "[report]" > .coveragerc_ray + echo "; Regexes for lines to exclude from consideration" >> .coveragerc_ray + echo "exclude_also =" >> .coveragerc_ray + echo " def .*_ray_*" >> .coveragerc_ray + echo " def ,*_ray\(*" >> .coveragerc_ray + echo " def ray_.*" >> .coveragerc_ray + echo " def test_.*_ray*" >> .coveragerc_ray +} + +set_ray_coveragerc() +{ + # If `ray` command is not found then generate a .coveragerc_ray file + if ! command -v ray &> /dev/null + then + echo "Ray Not Installed" + gen_ray_coveragerc + fcoveragerc="--rcfile=.coveragerc_ray" + else + echo "Ray Installed" + fcoveragerc="" + fi +} + +show_coverage_report() +{ + set_ray_coveragerc + coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc +} + +gen_coverage_xml_report() +{ + # This function saves the coverage report in Cobertura XML format, which is compatible with codecov + set_ray_coveragerc + coverage xml -o $fcoveragexml --fail-under=100 --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc +} + test_custom() { # export NUMBA_DISABLE_JIT=1 @@ -127,7 +180,7 @@ test_custom() for i in $(seq $max_iter) do echo "Custom Test: $i / $max_iter" - for testfile in "${custom_testfiles[@]}" + for testfile in "${custom_testfiles[@]}"; do pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile check_errs $? @@ -141,11 +194,19 @@ test_custom() test_unit() { echo "Testing Numba JIT Compiled Functions" - for testfile in tests/test_*.py - do - pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile - check_errs $? - done + if [[ ${#custom_testfiles[@]} -eq "0" ]]; then + for testfile in tests/test_*.py + do + pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile + check_errs $? + done + else + for testfile in "${custom_testfiles[@]}"; + do + pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile + check_errs $? + done + fi } test_coverage() @@ -160,12 +221,25 @@ test_coverage() echo "Testing Code Coverage" coverage erase - for testfile in tests/test_*.py - do - coverage run --append --source=. -m pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile - check_errs $? - done - coverage report -m --fail-under=100 --skip-covered --omit=setup.py,docstring.py,min.py,stumpy/cache.py + + # We always attempt to test everything but we may ignore things (ray, helper scripts) when we generate the coverage report + + if [[ ${#custom_testfiles[@]} -eq "0" ]]; then + # Execute all tests + for testfile in tests/test_*.py; + do + coverage run --append --source=. -m pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile + check_errs $? + done + else + # Execute custom tests + for testfile in "${custom_testfiles[@]}"; + do + coverage run --append --source=. -m pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning -W ignore::UserWarning $testfile + check_errs $? + done + fi + show_coverage_report } test_gpu() @@ -205,6 +279,7 @@ clean_up() rm -rf "tests/__pycache__/" rm -rf build dist stumpy.egg-info __pycache__ rm -f docs/*.nbconvert.ipynb + rm -rf ".coveragerc_ray" if [ -d "$site_pkgs/stumpy/__pycache__" ]; then rm -rf $site_pkgs/stumpy/__pycache__/*nb* fi @@ -237,6 +312,7 @@ check_flake check_docstrings check_print check_naive +check_ray if [[ $test_mode == "notebooks" ]]; then echo "Executing Tutorial Notebooks Only" @@ -254,6 +330,11 @@ elif [[ $test_mode == "custom" ]]; then # export NUMBA_DISABLE_JIT=1 # export NUMBA_ENABLE_CUDASIM=1 test_custom +elif [[ $test_mode == "report" ]]; then + echo "Generate Coverage Report Only" + # Assume coverage tests have already been executed + # and a coverage file exists + gen_coverage_xml_report elif [[ $test_mode == "gpu" ]]; then echo "Executing GPU Unit Tests Only" test_gpu diff --git a/tests/test_ray.py b/tests/test_ray.py new file mode 100644 index 000000000..6f86426d9 --- /dev/null +++ b/tests/test_ray.py @@ -0,0 +1,222 @@ +import naive +import numpy as np +import numpy.testing as npt +import pytest + +try: # pragma: no cover + import ray + + RAY_IMPORTED = True +except ImportError: # pragma: no cover + RAY_IMPORTED = False +from stumpy import aamp_stimped, aamped, maamped, mstumped, stimped, stumped + + +@pytest.fixture(scope="module") +def ray_cluster(): + try: + if not ray.is_initialized(): + ray.init() + yield None + if ray.is_initialized(): + ray.shutdown() + except NameError: # pragma: no cover + # Ray not installed + yield None + + +test_data = [ + ( + np.array([9, 8100, -60, 7], dtype=np.float64), + np.array([584, -11, 23, 79, 1001, 0, -19], dtype=np.float64), + ), + ( + np.random.uniform(-1000, 1000, [8]).astype(np.float64), + np.random.uniform(-1000, 1000, [64]).astype(np.float64), + ), +] + +test_mdata = [ + (np.array([[584, -11, 23, 79, 1001, 0, -19]], dtype=np.float64), 3), + (np.random.uniform(-1000, 1000, [5, 20]).astype(np.float64), 5), +] + +T = [ + np.array([584, -11, 23, 79, 1001, 0, -19], dtype=np.float64), + np.random.uniform(-1000, 1000, [64]).astype(np.float64), +] + + +@pytest.mark.filterwarnings("ignore:numpy.dtype size changed") +@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed") +@pytest.mark.filterwarnings("ignore:numpy.ndarray size changed") +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_stumped_ray_self_join(T_A, T_B, ray_cluster): + if not RAY_IMPORTED: # pragma: no cover + pytest.skip("Skipping Test Ray Not Installed") + + m = 3 + zone = int(np.ceil(m / 4)) + ref_mp = naive.stump(T_B, m, exclusion_zone=zone) + comp_mp = stumped(ray, T_B, m, ignore_trivial=True) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.filterwarnings("ignore:numpy.dtype size changed") +@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed") +@pytest.mark.filterwarnings("ignore:numpy.ndarray size changed") +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_aamped_ray_self_join(T_A, T_B, ray_cluster): + if not RAY_IMPORTED: # pragma: no cover + pytest.skip("Skipping Test Ray Not Installed") + + m = 3 + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_B, m, p=p) + comp_mp = aamped(ray, T_B, m, p=p) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T, m", test_mdata) +def test_mstumped_ray(T, m, ray_cluster): + if not RAY_IMPORTED: # pragma: no cover + pytest.skip("Skipping Test Ray Not Installed") + + excl_zone = int(np.ceil(m / 4)) + + ref_P, ref_I = naive.mstump(T, m, excl_zone) + comp_P, comp_I = mstumped(ray, T, m) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + + +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T, m", test_mdata) +def test_maamped_ray(T, m, ray_cluster): + if not RAY_IMPORTED: # pragma: no cover + pytest.skip("Skipping Test Ray Not Installed") + + excl_zone = int(np.ceil(m / 4)) + + ref_P, ref_I = naive.maamp(T, m, excl_zone) + comp_P, comp_I = maamped(ray, T, m) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + + +@pytest.mark.filterwarnings("ignore:numpy.dtype size changed") +@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed") +@pytest.mark.filterwarnings("ignore:numpy.ndarray size changed") +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T", T) +def test_stimped_ray(T, ray_cluster): + if not RAY_IMPORTED: # pragma: no cover + pytest.skip("Skipping Test Ray Not Installed") + + threshold = 0.2 + min_m = 3 + n = T.shape[0] - min_m + 1 + + pan = stimped( + ray, + T, + min_m=min_m, + max_m=None, + step=1, + # normalize=True, + ) + + for i in range(n): + pan.update() + + ref_PAN = np.full((pan.M_.shape[0], T.shape[0]), fill_value=np.inf) + + for idx, m in enumerate(pan.M_[:n]): + zone = int(np.ceil(m / 4)) + ref_mp = naive.stump(T, m, T_B=None, exclusion_zone=zone) + ref_PAN[pan._bfs_indices[idx], : ref_mp.shape[0]] = ref_mp[:, 0] + + # Compare raw pan + cmp_PAN = pan._PAN + + naive.replace_inf(ref_PAN) + naive.replace_inf(cmp_PAN) + + npt.assert_almost_equal(ref_PAN, cmp_PAN) + + # Compare transformed pan + cmp_pan = pan.PAN_ + ref_pan = naive.transform_pan( + pan._PAN, pan._M, threshold, pan._bfs_indices, pan._n_processed + ) + + naive.replace_inf(ref_pan) + naive.replace_inf(cmp_pan) + + npt.assert_almost_equal(ref_pan, cmp_pan) + + +@pytest.mark.filterwarnings("ignore:numpy.dtype size changed") +@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed") +@pytest.mark.filterwarnings("ignore:numpy.ndarray size changed") +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T", T) +def test_aamp_stimped_ray(T, ray_cluster): + if not RAY_IMPORTED: # pragma: no cover + pytest.skip("Skipping Test Ray Not Installed") + + threshold = 0.2 + min_m = 3 + n = T.shape[0] - min_m + 1 + + pan = aamp_stimped( + ray, + T, + min_m=min_m, + max_m=None, + step=1, + ) + + for i in range(n): + pan.update() + + ref_PAN = np.full((pan.M_.shape[0], T.shape[0]), fill_value=np.inf) + + for idx, m in enumerate(pan.M_[:n]): + zone = int(np.ceil(m / 4)) + ref_mp = naive.aamp(T, m, T_B=None, exclusion_zone=zone) + ref_PAN[pan._bfs_indices[idx], : ref_mp.shape[0]] = ref_mp[:, 0] + + # Compare raw pan + cmp_PAN = pan._PAN + + naive.replace_inf(ref_PAN) + naive.replace_inf(cmp_PAN) + + npt.assert_almost_equal(ref_PAN, cmp_PAN) + + # Compare transformed pan + cmp_pan = pan.PAN_ + ref_pan = naive.transform_pan( + pan._PAN, + pan._M, + threshold, + pan._bfs_indices, + pan._n_processed, + np.min(T), + np.max(T), + ) + + naive.replace_inf(ref_pan) + naive.replace_inf(cmp_pan) + + npt.assert_almost_equal(ref_pan, cmp_pan)