From 34dea83dbb894d2dec147f0b07426298a6b5cd11 Mon Sep 17 00:00:00 2001
From: "BRENOT Laure S." <133853397+LaureBrenot@users.noreply.github.com>
Date: Tue, 17 Sep 2024 23:27:03 -0800
Subject: [PATCH] Adding Wavelet-based DVV to the default workflow (#362)

* Update default.csv: wct param

* Update msnoise.py: WCT job

* Create s08compute_wct.py

* Update s08compute_wct.py: add wavelet_type param and def

* Update default.csv: add wavelet_type

* Update msnoise.py

* Update default.csv

* Update s08compute_wct.py

* Add files via upload

* Update tests.py: add compute_wct and plot tests

* Update msnoise.py with plot wct

* Update wct_dvv.py

* Update s08compute_wct.py

* Update wct_dvv.py

* Update msnoise.py

* Update wct_dvv.py: mov_stack/rolling

if mov_stack in params list use it, else rolling = mov_stack and applyied to the plot

* Update api.py: xr_save_wct in api

* Update s08compute_wct.py: move xr_save_wct in api

I don't know why tests don't build for PRs, but it's OK, we'll see later :-) (unsafe comment, don't remember it)
---
 msnoise/api.py             |  55 +++++
 msnoise/default.csv        |  10 +-
 msnoise/plots/wct_dvv.py   | 286 +++++++++++++++++++++++
 msnoise/s08compute_wct.py  | 463 +++++++++++++++++++++++++++++++++++++
 msnoise/scripts/msnoise.py |  53 ++++-
 msnoise/test/tests.py      |  23 ++
 6 files changed, 887 insertions(+), 3 deletions(-)
 create mode 100644 msnoise/plots/wct_dvv.py
 create mode 100644 msnoise/s08compute_wct.py

diff --git a/msnoise/api.py b/msnoise/api.py
index d48a0f7..2916c73 100644
--- a/msnoise/api.py
+++ b/msnoise/api.py
@@ -2556,3 +2556,58 @@ def compute_dvv(session, filterid, mov_stack, pairs=None, components=None, param
         stats[(c, "trimmed_mean")], stats[(c, "trimmed_std")] = trim(all, c, kwargs.get("limits", None))
 
     return stats.sort_index(axis=1)
+
+def xr_save_wct(station1, station2, components, filterid, mov_stack, taxis, dvv_df, err_df, coh_df):
+    """
+    Save the Wavelet Coherence Transform (WCT) results as a NetCDF file.
+    
+    Parameters
+    ----------
+    station1 : str
+        The first station in the pair.
+    station2 : str
+        The second station in the pair.
+    components : str
+        The components (e.g., Z, N, E) being analyzed.
+    filterid : int
+        Filter ID used in the analysis.
+    mov_stack : tuple
+        Tuple of (start, end) representing the moving stack window.
+    taxis : array-like
+        Time axis corresponding to the WCT data.
+    dvv_df : pandas.DataFrame
+        DataFrame containing dvv data (2D).
+    err_df : pandas.DataFrame
+        DataFrame containing err data (2D).
+    coh_df : pandas.DataFrame
+        DataFrame containing coh data (2D).
+    
+    Returns
+    -------
+    None
+    """
+    # Construct the file path
+    fn = os.path.join("WCT", f"{filterid:02d}", f"{mov_stack[0]}_{mov_stack[1]}",
+                      components, f"{station1}_{station2}.nc")
+
+    # Ensure the directory exists
+    os.makedirs(os.path.dirname(fn), exist_ok=True)
+
+    # Convert DataFrames to xarray.DataArrays
+    dvv_da = xr.DataArray(dvv_df.values, coords=[dvv_df.index, dvv_df.columns], dims=['times', 'frequency'])
+    err_da = xr.DataArray(err_df.values, coords=[err_df.index, err_df.columns], dims=['times', 'frequency'])
+    coh_da = xr.DataArray(coh_df.values, coords=[coh_df.index, coh_df.columns], dims=['times', 'frequency'])
+
+    # Combine into a single xarray.Dataset
+    ds = xr.Dataset({
+        'dvv': dvv_da,
+        'err': err_da,
+        'coh': coh_da
+    })
+
+    # Save the dataset to a NetCDF file
+    ds.to_netcdf(fn)
+
+    logger.debug(f"Saved WCT data to {fn}")
+    # Clean up
+    del dvv_da, err_da, coh_da, ds
diff --git a/msnoise/default.csv b/msnoise/default.csv
index 5890e74..af1a7c7 100644
--- a/msnoise/default.csv
+++ b/msnoise/default.csv
@@ -50,6 +50,14 @@ dtt_sides,both,Which sides to use,str,both/left/right,,
 dtt_mincoh,0.65,"Minimum coherence on dt measurement, MWCS points with values lower than that will not be used in the WLS, [0:1]",float,,,
 dtt_maxerr,0.1,"Maximum error on dt measurement, MWCS points with values larger than that will not be used in the WLS [0:1]",float,,,
 dtt_maxdt,0.1,"Maximum dt values, MWCS points with values larger than that will not be used in the WLS (in seconds)",float,,,
+wct_ns,5,"smoothing parameter in frequency",float,,,
+wct_nt,5,"smoothing parameter in time",float,,,
+wct_vpo,20,"spacing param between discrete scales",float,,,
+wct_nptsfreq,300,"number of freq points between min and max",float,,,
+dtt_codacycles,20,"number of cycles of period (1/freq) between lag_min and lag_max",int,,,
+dvv_min_nonzero,0.25,"percentage of data points with non-zero weighting required for regression otherwise nan (0 to 1)",float,,,
+wct_norm,Y,"Is the REF and CCF are normalized before computing wavelet? [Y]/N",bool,Y/N,,
+wavelet_type,"('Morlet',6.)","Wavelet type and optional associated parameter",eval,Morlet/Paul/DOG/MexicanHat,,
 plugins,,Comma separated list of plugin names. Plugins names should be importable Python modules.,str,,,
 hpc,N,Is MSNoise going to run on an HPC?,bool,Y/N,,
 stretching_max,0.01,"Maximum stretching coefficient, e.g. 0.5 = 50%, 0.01 = 1%",float,,,
@@ -62,4 +70,4 @@ qc_ppsd_period_step_octaves,0.0125,Step length on frequency axis in fraction of
 qc_ppsd_period_limits,"(0.01,100)","Set custom lower and upper end of period range (e.g. ``(0.01, 100)`` seconds). The specified lower end of period range will be set as the central period of the first bin (geometric mean of left/right edges of smoothing interval). At the upper end of the specified period range, no more additional bins will be added after the bin whose center frequency exceeds the given upper end for the first time.",eval,,,
 qc_ppsd_db_bins,"(-200, -50, 1.)",Specify the lower and upper boundary and the width of the db bins. The bin width might get adjusted to fit a number of equally spaced bins in between the given boundaries.,eval,,,
 qc_rms_frequency_ranges,"[(1.0, 20.0), (4.0, 14.0), (4.0, 40.0), (4.0, 9.0)]",Specify the frequency bounds (in Hz) to compute the RMS from PSDs,eval,,,
-qc_rms_type,DISP,What units do you want for the exported RMS,str,DISP/VEL/ACC,,
\ No newline at end of file
+qc_rms_type,DISP,What units do you want for the exported RMS,str,DISP/VEL/ACC,,
diff --git a/msnoise/plots/wct_dvv.py b/msnoise/plots/wct_dvv.py
new file mode 100644
index 0000000..625a6b0
--- /dev/null
+++ b/msnoise/plots/wct_dvv.py
@@ -0,0 +1,286 @@
+"""
+This plot shows the final output of MSNoise using the wavelet.
+
+
+Example:
+
+``msnoise cc dvv plot wct``
+
+"""
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+from matplotlib.colors import Normalize
+from matplotlib.lines import Line2D
+import matplotlib.dates as mdates
+import pandas as pd
+from ..api import *
+from datetime import datetime, timedelta
+
+def plot_dvv_heatmap(data_type, dvv_df, pair, rolling, start, end, low, high, logger, mincoh=0.5):
+    # Extracting relevant data from dvv_df
+    dvv_df = dvv_df.loc[start:end]
+    if dvv_df.empty:
+        logger.error(f"No data available for {pair} between {start} and {end}. Exiting function.")
+        return
+    rolling_window = int(rolling)
+    
+    dvv_freq = dvv_df['dvv']
+    coh_freq = dvv_df['coh']
+
+    dvv_freq = dvv_freq.rolling(window=rolling_window, min_periods=1).mean()
+    coh_freq = coh_freq.rolling(window=rolling_window, min_periods=1).mean()
+
+    fig, ax = plt.subplots(figsize=(16, 10))   
+    # Scatter plot of dv/v data
+    #norm1 = plt.Normalize(vmin=np.min(dvv_freq.T), vmax=np.max(dvv_freq.T))
+
+    if data_type == 'dvv':
+        low_per = np.nanpercentile(dvv_freq, 1)
+        high_per = np.nanpercentile(dvv_freq, 99)
+        
+        ax.pcolormesh(np.asarray(dvv_freq.index), np.asarray(dvv_freq.columns), dvv_freq.T,
+            cmap=mpl.cm.seismic, edgecolors='none', vmin=low_per, vmax=high_per)
+        save_name = f"{pair[0]}_{low}_{high}_Hz_m{rolling_window}_dvv_heatmap"
+        color_bar_label = 'dv/v (%)'
+        
+    elif data_type == 'coh':
+        ax.pcolormesh(np.asarray(coh_freq.index), np.asarray(coh_freq.columns), coh_freq.T,
+                cmap='RdYlGn', edgecolors='none', vmin=mincoh, vmax=1)
+        save_name = f"{pair[0]}_{low}_{high}_Hz_m{rolling_window}_coh_heatmap"
+        color_bar_label = 'Coherence value'
+    else:
+        logger.error("Unknown data type: %s, write 'dvv' or 'coh'? " % data_type)
+        return None, None, None
+
+    #if current_config.get('plot_event', False):
+    #    plot_events(ax, current_config['event_list'], start, end)
+                
+    ax.set_xlim(pd.to_datetime(start), pd.to_datetime(end))
+        
+    ax.set_ylabel('Frequency (Hz)', fontsize=18)
+    ax.set_title(save_name, fontsize=22)
+
+    cbar1 = plt.colorbar(ax.collections[0], ax=ax, pad=0.02)
+    cbar1.set_label(color_bar_label, fontsize=18)
+    #norm1 = Normalize(vmin=0, vmax=1)
+
+    ax.tick_params(axis='both', which='both', labelsize=16, width=2, length=5)
+    ax.xaxis.set_minor_locator(mdates.MonthLocator())
+    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
+    fig.autofmt_xdate()
+    
+    # Adjust the layout if necessary
+    fig.subplots_adjust(right=0.85)
+    fig.tight_layout()
+    return fig, save_name
+
+def plot_dvv_scatter(dvv_df, pair, rolling, start, end, ranges, logger):
+    # Extracting relevant data from dvv_df
+    dvv_df = dvv_df.loc[start:end]
+    if dvv_df.empty:
+        logger.error(f"No data available for {pair} between {start} and {end}. Exiting function.")
+        return
+    rolling_window = int(rolling)
+
+    color = ['Blues', 'Reds','Greens','Greys'] #'Purples'
+    color2 = ['blue', 'red', 'green', 'grey']  # Colors for different frequency ranges
+    freq_names = []
+    legend_handles = []
+    ranges_list  = [list(map(float, r.strip().strip('[]').split(','))) for r in ranges.split('], [')]
+
+    fig, ax = plt.subplots(figsize=(16, 10))
+    # Loop through the frequency ranges specified in current_config
+    for i, freqrange in enumerate(ranges_list):#[[0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]):#, [0.5, 2.0]]):
+        freq_name = f"{freqrange[0]}-{freqrange[1]} Hz"
+        freq_names.append(freq_name)
+
+        freqs = np.asarray(dvv_df['dvv'].columns)
+        filtered_freqs = freqs[(freqs >= freqrange[0]) & (freqs <= freqrange[1])].tolist()
+        dvv_freq = dvv_df['dvv'][filtered_freqs]        
+        coh_freq = dvv_df['coh'][filtered_freqs]
+
+        dvv_freq = dvv_freq.rolling(window=rolling_window, min_periods=1).mean()
+        coh_freq = coh_freq.rolling(window=rolling_window, min_periods=1).mean()
+
+        # Scatter plot of dv/v data
+        norm1 = plt.Normalize(vmin=0, vmax=1)
+        ax.scatter([0,1], [0,1], c=[0,1], cmap=color[-1])
+
+        sc = ax.scatter(dvv_freq.index, dvv_freq.mean(axis=1), c=coh_freq.mean(axis=1), cmap=color[i], norm=norm1, label=freq_name)
+        legend_handles.append(Line2D([0], [0], marker='o', color=color2[i], markerfacecolor=color2[i], markersize=10, label=freq_name))
+
+    #if current_config.get('plot_event', False):
+    #    plot_events(ax, current_config['event_list'], start, end)
+                
+    ax.set_xlim(pd.to_datetime(start), pd.to_datetime(end))
+    #if current_config.get('same_dvv_scale', False):
+    #    ax.set_ylim(current_config['dvv_min'], current_config['dvv_max'])
+        
+    ax.set_ylabel('dv/v (%)', fontsize=18)
+    ax.set_title(f"{pair[0]} dv/v scatter plot", fontsize=22)
+
+    legend1 = ax.legend(handles=legend_handles, fontsize=22, loc='upper left')
+    ax.add_artist(legend1)
+
+    cbar1 = plt.colorbar(ax.collections[0], ax=ax, pad=0.02)
+    cbar1.set_label('Coherence value \n(darkness of the point)', fontsize=18)
+    norm1 = Normalize(vmin=0, vmax=1)
+
+    ax.tick_params(axis='both', which='both', labelsize=16, width=2, length=5)
+    ax.xaxis.set_minor_locator(mdates.MonthLocator())
+    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
+    fig.autofmt_xdate()
+    
+    # Adjust the layout if necessary
+    fig.subplots_adjust(right=0.85)
+    fig.tight_layout()
+    
+    return fig, f"{pair[0]}_{ np.min(ranges_list)}_{np.max(ranges_list)}_Hz_m{rolling_window}_wctscatter"
+    
+
+def save_figure(fig, filename, logger, mov_stack, components,filterid, visualize, plot_all_period=False, start=None, end=None, outfile=None):
+    fig_path = os.path.join('Figures' if plot_all_period else 'Figures/Zooms')
+    create_folder(fig_path, logger)
+    mov_stack= mov_stack[0]
+    if start and end:
+        filename = f'{filename}_{str(start)[:10]}_{str(end)[:10]}'
+    filepath = os.path.join(fig_path, f'{filename}.png')
+
+    if outfile:
+        if outfile.startswith("?"):
+            if len(mov_stack) == 1:
+                outfile = outfile.replace('?', '%s-f%i-m%s_%s-%s' % (components,
+                                                                   filterid,
+                                                                   mov_stack[0],
+                                                                   mov_stack[1],
+                                                                   visualize))
+            else:
+                outfile = outfile.replace('?', '%s-f%i-%s' % (components,
+                                                               filterid,
+                                                               visualize))
+        filepath = "wct " + outfile
+        logger.info("output to: %s" % outfile)
+
+    fig.savefig(filepath, dpi=300, bbox_inches='tight', transparent=True)
+    
+def create_folder(folder_path, logger):
+    try:
+        os.makedirs(folder_path)
+        logger.info(f"Folder '{folder_path}' created successfully.")
+    except FileExistsError:
+        pass
+
+def xr_get_wct_pair(pair, components, filterid, mov_stack, logger):
+    fn = os.path.join("WCT", "%02i" % filterid,
+                      "%s_%s" % (mov_stack[0], mov_stack[1]),
+                      "%s" % components, "%s.nc" % (pair))
+    if not os.path.isfile(fn):
+        logger.error("FILE DOES NOT EXIST: %s, skipping" % fn)
+    
+    data = xr_create_or_open(fn, name="WCT")
+    data = data.to_dataframe().unstack(level='frequency')
+    return data
+
+def xr_get_wct(components, filterid, mov_stack, logger):
+    fn = os.path.join("WCT", "%02i" % filterid,
+                      "%s_%s" % (mov_stack[0], mov_stack[1]),
+                      "%s" % components, "*.nc" )
+    matching_files = glob.glob(fn)
+    if not matching_files:
+       logger.error(f"No files found matching pattern: {fn}")
+
+    all_wct = []
+    for fil in matching_files:
+        data = xr_create_or_open(fil, name="WCT")
+        data = data.to_dataframe().unstack(level='frequency')
+        all_wct.append(data)
+    combined = pd.concat(all_wct, axis=0)
+    dvv = combined.groupby(combined.index).mean()
+    return dvv
+
+def validate_and_adjust_date(date_string, end_date, logger):
+    try:
+        start_date = datetime.strptime(date_string, '%Y-%m-%d')
+    except ValueError:
+        try:
+            days_delta = int(date_string)
+            start_date = datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=days_delta)
+        except ValueError:
+            logger.error(f"Invalid start string: {date_string}")
+            return None
+    return start_date
+
+def main(mov_stackid=None, components='ZZ', filterid=1,
+        pairs=[], showALL=False, start="1970-01-01", end="2100-01-01", visualize='dvv', ranges="[0.5, 1.0], [1.0, 2.0], [2.0, 4.0]", show=True,outfile=None, loglevel="INFO"):
+    logger = get_logger('msnoise.cc_dvv_plot_dvv', loglevel,
+                        with_pid=True)
+    db = connect()
+    params = get_params(db)
+    mincoh = params.dtt_mincoh
+
+    # Check start and end dates
+    if start == "1970-01-01":
+         start= params.startdate
+    else:
+        start = validate_and_adjust_date(start, end, logger)
+    if end == "2100-01-01":
+        end = params.enddate
+
+    # TODO clearer  mov_stackid to additionnal rolling
+    if mov_stackid and mov_stackid != "": #if mov_stackid given
+        try:
+            mov_stack = params.mov_stack[mov_stackid - 1]
+            if mov_stack in params.mov_stack:  # Check if mov_stack is in params.mov_stack
+                mov_stacks = [mov_stack, ]
+                rolling = 1
+            else:
+                rolling = mov_stackid  # Assign  mov_stack to rolling
+        except:
+            mov_stack = params.mov_stack[0]
+            if mov_stack in params.mov_stack:  # Check if mov_stack is in params.mov_stack new format
+                mov_stacks = [mov_stack, ]
+                rolling = 1  # Keeping the mov_stack result
+            else:
+                rolling = mov_stack  # Assign mov_stack to rolling
+    else:
+        mov_stacks = params.mov_stack
+        rolling = int(params.mov_stack[0][0][0])
+
+    if components.count(","):
+        components = components.split(",")
+    else:
+        components = [components, ]
+
+    filter = get_filters(db, ref=filterid)
+    low = float(filter.low)
+    high = float(filter.high)
+
+    for i, mov_stack in enumerate(mov_stacks):
+        for comps in components:
+            # Get the data
+            if not pairs:
+                dvv = xr_get_wct(comps, filterid, mov_stack, logger)
+                pairs = ["all stations",]
+            else:
+                try:
+                    dvv = xr_get_wct_pair(pairs, comps, filterid, mov_stack, logger)    
+                except FileNotFoundError as fullpath:
+                    logger.error("FILE DOES NOT EXIST: %s, skipping" % fullpath)
+                    continue
+            # Plotting
+            if visualize == 'dvv':
+                fig, savename = plot_dvv_heatmap('dvv', dvv, pairs, rolling, start, end, low, high, logger, mincoh)
+            elif visualize == 'coh':
+                fig, savename = plot_dvv_heatmap('coh', dvv, pairs, rolling, start, end, low, high, logger, mincoh)
+            elif visualize == 'curve':
+                fig, savename = plot_dvv_scatter(dvv, pairs, rolling, start, end, ranges, logger)
+            else:
+                looger.error("PLOT TYPE DOES NOT EXIST: %s" % visualize)
+            # Save and show the figure
+            save_figure(fig, savename, logger, mov_stacks, comps, filterid, visualize, plot_all_period=False, start=start, end=end, outfile=outfile)
+            if show:
+                plt.show()
+
+if __name__ == "__main__":
+    main()
+
diff --git a/msnoise/s08compute_wct.py b/msnoise/s08compute_wct.py
new file mode 100644
index 0000000..9ff3b30
--- /dev/null
+++ b/msnoise/s08compute_wct.py
@@ -0,0 +1,463 @@
+"""
+Wavelet Coherence Transform (WCT) Computation
+This script performs the computation of the Wavelet Coherence Transform (WCT), a tool used to analyze the correlation between two time series in the time-frequency domain. The script supports parallel processing and interacts with a database to manage job statuses.
+Filter Configuration Parameters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+* |dtt_minlag| : Minimum lag time for DTT (Differential Time Travel) analysis.
+* |dtt_maxdt| : Maximum allowable time difference for coherence calculation.
+* |dtt_mincoh| : Minimum coherence required for DTT job inclusion.
+* |dtt_codacycles| : Number of coda cycles to consider in the computation.
+* |wct_ns| : Smoothing parameter for wavelet coherence transform.
+* |wct_nt| : Smoothing parameter for wavelet coherence transform.
+* |wct_vpo| : Spacing parameter between discrete scales in the wavelet transform.
+* |wct_nptsfreq| : Number of frequency points between `freqmin` and `freqmax`.
+* |dtt_min_nonzero| : Minimum percentage of non-zero weights required for the DTT computation.
+* |wct_norm| : Flag indicating whether to normalize waveforms before processing.
+* |hpc| | 
+This process is job-based, so it is possible to run several instances in
+parallel.
+Once done, each job is marked "D"one in the database and, unless ``hpc`` is 
+``Y``, DTT jobs are inserted/updated in the database.
+To run this step:
+.. code-block:: sh
+    $ msnoise cc dvv compute_wct
+This step also supports parallel processing/threading:
+.. code-block:: sh
+    $ msnoise -t 4 cc dvv compute_wct
+will start 4 instances of the code (after 1 second delay to avoid database
+conflicts). This works both with SQLite and MySQL but be aware problems
+could occur with SQLite.
+    Parallel Processing
+"""
+
+import os
+import time
+import numpy as np
+import pandas as pd
+import xarray as xr
+import scipy.optimize
+import scipy.signal
+import pycwt as wavelet
+from scipy.signal import convolve2d
+from obspy.signal.regression import linear_regression
+from .api import *
+import logbook
+import scipy
+import scipy.fft as sf
+
+def get_avgcoh(freqs, tvec, wcoh, freqmin, freqmax, lag_min=5, coda_cycles=20):
+    """
+    Calculate the average coherence over specified frequency range and time lags.
+    Parameters
+    ----------
+    freqs : numpy.ndarray
+        Array of frequency values.
+    tvec : numpy.ndarray
+        Time vector.
+    wcoh : numpy.ndarray
+        Wavelet coherence array.
+    freqmin : float
+        Minimum frequency for coherence calculation.
+    freqmax : float
+        Maximum frequency for coherence calculation.
+    lag_min : int, optional
+        Minimum lag in seconds for coherence calculation. Default is 5.
+    coda_cycles : int, optional
+        Number of coda cycles to consider. Default is 20.
+    Returns
+    -------
+    numpy.ndarray
+        Average coherence values over the specified frequency range.
+    """
+    inx = np.where((freqs>=freqmin) & (freqs<=freqmax)) 
+    coh = np.zeros(inx[0].shape) # Create empty vector for coherence
+
+    for ii, ifreq in enumerate(inx[0]): # Loop through frequencies index     
+        period = 1.0/freqs[ifreq]
+        lag_max = lag_min + (period*coda_cycles) 
+        tindex = np.where(((tvec >= -lag_max) & (tvec <= -lag_min)) | ((tvec >= lag_min) & (tvec <= lag_max)))[0] # Index of the coda
+
+        if len(tvec)>2: # check time vector size
+            if not np.any(wcoh[ifreq]): # check non-empty dt array
+                continue
+            c = np.nanmean(wcoh[ifreq][tindex])
+            coh[ii] = c
+
+        else:
+            logger.debug('not enough points to compute average coherence') #not sure why it would ever get here, but just in case.
+            coh[ii] = np.nan
+
+    return coh
+
+def smoothCFS(cfs, scales, dt, ns, nt):
+    """
+    Smooth the continuous wavelet transform coefficients using a Fourier domain approach.
+    Parameters
+    ----------
+    cfs : numpy.ndarray
+        Continuous wavelet transform coefficients.
+    scales : numpy.ndarray
+        Scales used in the wavelet transform.
+    dt : float
+        Sampling interval.
+    ns : int
+        Smoothing parameter for the moving average filter.
+    nt : float
+        Smoothing parameter for the Gaussian filter.
+    Returns
+    -------
+    numpy.ndarray
+        Smoothed continuous wavelet transform coefficients.
+    """
+    N = cfs.shape[1]
+    npad = sf.next_fast_len(N, real=True)
+    omega = np.arange(1, np.fix(npad / 2) + 1, 1).tolist()
+    omega = np.array(omega) * ((2 * np.pi) / npad)
+    omega_save = -omega[int(np.fix((npad - 1) / 2)) - 1:0:-1]
+    omega_2 = np.concatenate((0., omega), axis=None)
+    omega_2 = np.concatenate((omega_2, omega_save), axis=None)
+    omega = np.concatenate((omega_2, -omega[0]), axis=None)
+    # Normalize scales by DT because we are not including DT in the angular frequencies here.
+    # The smoothing is done by multiplication in the Fourier domain.
+    normscales = scales / dt
+
+    for kk in range(0, cfs.shape[0]):
+        F = np.exp(-nt * (normscales[kk] ** 2) * omega ** 2)
+        smooth = np.fft.ifft(F * np.fft.fft(cfs[kk - 1], npad))
+        cfs[kk - 1] = smooth[0:N]
+    # Convolve the coefficients with a moving average smoothing filter across scales.
+    H = 1 / ns * np.ones((ns, 1))
+
+    cfs = conv2(cfs, H)
+    return cfs
+
+def conv2(x, y, mode='same'):
+    """
+    Perform 2D convolution of matrices x and y
+    """
+    return np.rot90(convolve2d(np.rot90(x, 2), np.rot90(y, 2), mode=mode), 2)
+
+def get_wavelet_type(wavelet_type):
+    """
+    return a wavelet object based on the specified wavelet type and associated parameter
+    """
+    # Default parameters for each wavelet type
+    default_params = {
+        'Morlet': 6,
+        'Paul': 4,
+        'DOG': 2,
+        'MexicanHat': 2  # MexicanHat inherits from DOG with m=2
+    }
+
+    wavelet_name = wavelet_type[0]
+
+    # If a second argument is provided, use it; otherwise, use the default value
+    if len(wavelet_type) == 2:
+        param = float(wavelet_type[1])
+    else:
+        param = default_params[wavelet_name]
+
+    # Get the corresponding wavelet object
+    if wavelet_name == 'Morlet':
+        return wavelet.Morlet(param)
+    elif wavelet_name == 'Paul':
+        return wavelet.Paul(param)
+    elif wavelet_name == 'DOG':
+        return wavelet.DOG(param)
+    elif wavelet_name == 'MexicanHat':
+        return wavelet.MexicanHat()  # Uses m=2, so no need for param
+    else:
+        raise logger.error(f"Unknown wavelet type: {wavelet_name}")
+
+def compute_wct_dvv(freqs, tvec, WXamp, Wcoh, delta_t, lag_min=5, coda_cycles=20, mincoh=0.5, maxdt=0.2, 
+            min_nonzero=0.25, freqmin=0.1, freqmax=2.0):
+    """
+    Compute the dv/v values and associated errors from the wavelet transform results.
+    Parameters
+    ----------
+    freqs : numpy.ndarray
+        Frequency values corresponding to the wavelet transform.
+    tvec : numpy.ndarray
+        Time vector.
+    WXamp : numpy.ndarray
+        Amplitude of the cross-wavelet transform.
+    Wcoh : numpy.ndarray
+        Wavelet coherence.
+    delta_t : numpy.ndarray
+        Time delays between signals.
+    lag_min : int, optional
+        Minimum lag in seconds. Default is 5.
+    coda_cycles : int, optional
+        Number of coda cycles to consider. Default is 20.
+    mincoh : float, optional
+        Minimum coherence value for weighting. Default is 0.5.
+    maxdt : float, optional
+        Maximum time delay for weighting. Default is 0.2.
+    min_nonzero : float, optional
+        Minimum percentage of non-zero weights required for valid estimation. Default is 0.25.
+    freqmin : float, optional
+        Minimum frequency for calculation. Default is 0.1 Hz.
+    freqmax : float, optional
+        Maximum frequency for calculation. Default is 2.0 Hz.
+    Returns
+    -------
+    tuple
+        dvv values (percentage), errors (percentage), and weighting function used.
+    """   
+    inx = np.where((freqs >= freqmin) & (freqs <= freqmax))  # Filter frequencies within the specified range
+    dvv, err = np.zeros(len(inx[0])), np.zeros(len(inx[0])) # Initialize dvv and err arrays
+
+    # Weighting function based on WXamp
+    weight_func = np.log(np.abs(WXamp)) / np.log(np.abs(WXamp)).max()
+    zero_idx = np.where((Wcoh < mincoh) | (delta_t > maxdt))
+    wf = (weight_func + abs(np.nanmin(weight_func))) / weight_func.max()
+    wf[zero_idx] = 0
+
+    # Loop through frequency indices for linear regression
+    for ii, ifreq in enumerate(inx[0]):
+        period = 1.0 / freqs[ifreq]
+        lag_max = lag_min + (period * coda_cycles)
+
+        # Coda selection
+        tindex = np.where(((tvec >= -lag_max) & (tvec <= -lag_min)) | ((tvec >= lag_min) & (tvec <= lag_max)))[0]
+
+        if len(tvec) > 2:
+            if not np.any(delta_t[ifreq]):
+                continue
+
+            delta_t[ifreq][tindex] = np.nan_to_num(delta_t[ifreq][tindex])
+            w = wf[ifreq]  # Weighting function for the specific frequency
+            w[~np.isfinite(w)] = 1.0
+
+            # Percentage of non-zero weights
+            nzc_perc = np.count_nonzero(w[tindex]) / len(tindex)
+            if nzc_perc >= min_nonzero:
+                m, em = linear_regression(tvec[tindex], delta_t[ifreq][tindex], w[tindex], intercept_origin=True)
+                dvv[ii], err[ii] = -m, em
+            else:
+                dvv[ii], err[ii] = np.nan, np.nan
+        else:
+            logger.debug('Not enough points to estimate dv/v for WCT')
+    
+    return dvv * 100, err * 100, wf
+
+def xwt(trace_ref, trace_current, fs, ns=3, nt=0.25, vpo=12, freqmin=0.1, freqmax=8.0, nptsfreq=100, wavelet_type=('Morlet',6.)):
+    """
+    Wavelet coherence transform (WCT) on two time series..
+    The WCT finds regions in time frequency space where the two time
+    series co-vary, but do not necessarily have high power.
+    
+    Modified from https://github.com/Qhig/cross-wavelet-transform
+    Parameters
+    ----------
+    trace_ref, trace_current : numpy.ndarray, list
+        Input signals.
+    fs : float
+        Sampling frequency.
+    ns : smoothing parameter. 
+        Default value is 3
+    nt : smoothing parameter. 
+        Default value is 0.25
+    vpo : float,
+        Spacing parameter between discrete scales. Default value is 12.
+        Higher values will result in better scale resolution, but
+        slower calculation and plot.
+        
+    freqmin : float,
+        Smallest frequency
+        Default value is 0.1 Hz
+    freqmax : float,
+        Highest frequency
+        Default value is 8.0 Hz
+    nptsfreq : int,
+        Number of frequency points between freqmin and freqmax.
+        Default value is 100 points
+    wavelet_type: list,
+        Wavelet type and associated parameter.
+        Default Morlet wavelet with a central frequency w0 = 6
+       
+    Returns
+        ----------
+    WXamp : numpy.ndarray
+        Amplitude of the cross-wavelet transform.
+    WXspec : numpy.ndarray
+        Complex cross-wavelet transform, representing both magnitude and phase information.
+    WXangle : numpy.ndarray
+        Phase angles of the cross-wavelet transform, indicating the phase relationship between the input signals.
+    Wcoh : numpy.ndarray
+        Wavelet coherence, representing the degree of correlation between the two signals in time-frequency space.
+    WXdt : numpy.ndarray
+        Time delay between the signals, estimated from the phase angles.
+    freqs : numpy.ndarray
+        Frequencies corresponding to the scales of the wavelet transform.
+    coi : numpy.ndarray
+        Cone of influence, representing the region of the wavelet spectrum where edge effects become significant.
+    
+    """
+    
+    mother = get_wavelet_type(wavelet_type) # mother wavelet class: Morlet, Paul, DOG, MexicanHat param 
+    # nx represent the number of element in the trace_current array
+    nx = np.size(trace_current)
+    x_reference = np.transpose(trace_ref)
+    x_current = np.transpose(trace_current)
+    # Sampling interval
+    dt = 1 / fs
+    # Spacing between discrete scales, the default value is 1/12
+    dj = 1 / vpo 
+    # Number of scales less one, -1 refers to the default value which is J = (log2(N * dt / so)) / dj.
+    J = -1
+    # Smallest scale of the wavelet, default value is 2*dt
+    s0 = 2 * dt  # Smallest scale of the wavelet, default value is 2*dt
+
+    # Creation of the frequency vector that we will use in the continuous wavelet transform 
+    freqlim = np.linspace(freqmax, freqmin, num=nptsfreq, endpoint=True, retstep=False, dtype=None, axis=0)
+
+    # Calculation of the two wavelet transform independently
+    # scales are calculated using the wavelet Fourier wavelength
+    # fft : Normalized fast Fourier transform of the input trace
+    # fftfreqs : Fourier frequencies for the calculated FFT spectrum.
+    ###############################################################################################################
+    cwt_reference, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(x_reference, dt, dj, s0, J, mother, freqs=freqlim)
+    cwt_current, _, _, _, _, _ = wavelet.cwt(x_current, dt, dj, s0, J, mother, freqs=freqlim)
+    ###############################################################################################################
+
+    scales = np.array([[kk] for kk in scales])
+    invscales = np.kron(np.ones((1, nx)), 1 / scales)
+
+    cfs2 = smoothCFS(invscales * abs(cwt_current) ** 2, scales, dt, ns, nt)
+    cfs1 = smoothCFS(invscales * abs(cwt_reference) ** 2, scales, dt, ns, nt)
+    crossCFS = cwt_reference * np.conj(cwt_current)
+    WXamp = abs(crossCFS)
+    # cross-wavelet transform operation with smoothing
+    crossCFS = smoothCFS(invscales * crossCFS, scales, dt, ns, nt)
+    WXspec = crossCFS / (np.sqrt(cfs1) * np.sqrt(cfs2))
+    WXangle = np.angle(WXspec)
+    Wcoh = abs(crossCFS) ** 2 / (cfs1 * cfs2)
+    pp = 2 * np.pi * freqs
+    pp2 = np.array([[kk] for kk in pp])
+    WXdt = WXangle / np.kron(np.ones((1, nx)), pp2)
+
+
+    return WXamp, WXspec, WXangle, Wcoh, WXdt, freqs, coi
+
+def main(loglevel="INFO"):
+    # Reconfigure logger to show the pid number in log records
+    global logger
+    logger = get_logger('msnoise.compute_wct_child', loglevel,
+                        with_pid=True)
+    logger.info('*** Starting: Compute WCT ***')
+
+    db = connect()
+    params = get_params(db)
+    taxis = get_t_axis(db)
+
+    ns = params.wct_ns
+    nt = params.wct_nt 
+    vpo = params.wct_vpo 
+    nptsfreq = params.wct_nptsfreq
+    coda_cycles = params.dtt_codacycles 
+    min_nonzero = params.dvv_min_nonzero
+    wct_norm = params.wct_norm
+    wavelet_type = params.wavelet_type
+    
+    mov_stacks = params.mov_stack
+    goal_sampling_rate = params.cc_sampling_rate
+    lag_min = params.dtt_minlag
+    maxdt = params.dtt_maxdt
+    mincoh = params.dtt_mincoh
+
+    logger.debug('Ready to compute')
+    # Then we compute the jobs
+    filters = get_filters(db, all=False)
+    time.sleep(np.random.random() * 5)
+
+    while is_dtt_next_job(db, flag='T', jobtype='WCT'):
+        # TODO would it be possible to make the next 8 lines in the API ?
+        jobs = get_dtt_next_job(db, flag='T', jobtype='WCT')
+
+        if not len(jobs):
+            # edge case, should only occur when is_next returns true, but
+            # get_next receives no jobs (heavily parallelised calls).
+            time.sleep(np.random.random())
+            continue
+        pair = jobs[0].pair
+        refs, days = zip(*[[job.ref, job.day] for job in jobs])
+
+        logger.info(
+            "There are WCT jobs for some days to recompute for %s" % pair)
+        for f in filters:
+            filterid = int(f.ref)
+            freqmin = f.low
+            freqmax = f.high
+
+            station1, station2 = pair.split(":")
+            if station1 == station2:
+                components_to_compute = params.components_to_compute_single_station
+            else:
+                components_to_compute = params.components_to_compute
+
+            for components in components_to_compute:
+                try:
+                    ref = xr_get_ref(station1, station2, components, filterid, taxis)
+                    ref = ref.CCF.values
+                    if wct_norm:
+                        ori_waveform = (ref/ref.max()) 
+                    else:
+                        ori_waveform = ref
+                except FileNotFoundError as fullpath:
+                    logger.error("FILE DOES NOT EXIST: %s, skipping" % fullpath)
+                    continue
+                if not len(ref):
+                    continue
+
+                for mov_stack in mov_stacks:
+                    dvv_list = []
+                    err_list = []
+                    coh_list = []
+                    data_dates=[]
+                    try:
+                        data = xr_get_ccf(station1, station2, components, filterid, mov_stack, taxis)
+                    except FileNotFoundError as fullpath:
+                        logger.error("FILE DOES NOT EXIST: %s, skipping" % fullpath)
+                        continue
+                    logger.debug("Processing %s:%s f%i m%s %s" % (station1, station2, filterid, mov_stack, components))
+
+                    to_search = pd.to_datetime(days)
+                    to_search = to_search.append(pd.DatetimeIndex([to_search[-1]+pd.Timedelta("1d"),]))
+                    data = data[data.index.floor('d').isin(to_search)]
+                    data = data.dropna()
+
+                    cur = data#.CCF.values
+                    if wct_norm:
+                        new_waveform = (cur/cur.max()) 
+                    else:
+                        new_waveform = cur
+
+                    for date, row in new_waveform.iterrows():
+                        WXamp, WXspec, WXangle, Wcoh, WXdt, freqs, coi = xwt(ori_waveform, row.values, goal_sampling_rate, int(ns), int(nt), int(vpo), freqmin, freqmax, int(nptsfreq), wavelet_type)
+                        dvv, err, wf = compute_wct_dvv(freqs, taxis, WXamp, Wcoh, WXdt, lag_min=int(lag_min), coda_cycles=coda_cycles, mincoh=mincoh, maxdt=maxdt, min_nonzero=min_nonzero, freqmin=freqmin, freqmax=freqmax)
+                        coh = get_avgcoh(freqs, taxis, Wcoh, freqmin, freqmax, lag_min=int(lag_min), coda_cycles=coda_cycles)
+                        dvv_list.append(dvv)
+                        err_list.append(err)
+                        coh_list.append(coh)
+                        data_dates.append(date)
+
+                    if len(dvv_list) > 0:#1:
+                        inx = np.where((freqs >= freqmin) & (freqs <= freqmax))
+
+                        dvv_df = pd.DataFrame(dvv_list, columns=freqs[inx], index=data_dates)
+                        err_df = pd.DataFrame(err_list, columns=freqs[inx], index=data_dates)
+                        coh_df = pd.DataFrame(coh_list, columns=freqs[inx], index=data_dates)
+
+                        # Saving
+                        xr_save_wct(station1, station2, components, filterid, mov_stack, taxis, dvv_df, err_df, coh_df)
+
+                        del dvv_df, err_df, coh_df
+                    del cur
+
+        massive_update_job(db, jobs, "D")
+
+    logger.info('*** Finished: Compute WCT ***')
+
+if __name__ == "__main__":
+    main()
diff --git a/msnoise/scripts/msnoise.py b/msnoise/scripts/msnoise.py
index 3c4cc44..5839f40 100644
--- a/msnoise/scripts/msnoise.py
+++ b/msnoise/scripts/msnoise.py
@@ -200,7 +200,7 @@ def info_jobs(db):
 
     jobtypes = {}
     jobtypes["QC"] = ["PSD", "PSD2HDF", "HDF2RMS"]
-    jobtypes["CC"] = ["CC", "STACK", "MWCS", "DTT", "DVV"]
+    jobtypes["CC"] = ["CC", "STACK", "MWCS", "DTT", "DVV", "WCT"]
 
     click.echo("Jobs:")
     for category in ["QC", "CC"]:
@@ -1238,6 +1238,27 @@ def dvv_compute_dvv(ctx):
         for p in processes:
             p.join()
 
+@dvv.command(name='compute_wct')
+@click.pass_context
+def dvv_compute_wct(ctx):
+    """Computes the wavelet dv/v jobs based on the new STACK data"""
+    from ..s08compute_wct import main
+    threads = ctx.obj['MSNOISE_threads']
+    delay = ctx.obj['MSNOISE_threadsdelay']
+    loglevel = ctx.obj['MSNOISE_verbosity']
+    if threads == 1:
+        main(loglevel=loglevel)
+    else:
+        from multiprocessing import Process
+        processes = []
+        for i in range(threads):
+            p = Process(target=main, kwargs={"loglevel": loglevel})
+            p.start()
+            processes.append(p)
+            time.sleep(delay)
+        for p in processes:
+            p.join()
+
 @dvv.group(name="plot")
 def dvv_plot():
     """Commands to trigger different plots"""
@@ -1319,7 +1340,35 @@ def dvv_plot_dtt(ctx, sta1, sta2, filterid, day, comp, mov_stack, show, outfile)
         from ..plots.dtt import main
     main(sta1, sta2, filterid, comp, day, mov_stack, show, outfile, loglevel=loglevel)
 
-
+@dvv_plot.command(name="wct")
+@click.option('-f', '--filterid', default=1, help='Filter ID')
+@click.option('-c', '--comp', default="ZZ", help='Components (ZZ, ZE, NZ, 1E,...). Defaults to ZZ')
+@click.option('-m', '--mov_stack', default=0, help='Plot specific mov stacks')
+@click.option('-p', '--pair', default=None, help='Plot a specific pair',
+              multiple=True)
+@click.option('-A', '--all', help='Show the ALL line?', is_flag=True)
+@click.option('-e', '--end', default="2100-01-01", help='Plot until which date? (default=2100-01-01 or enddate)')
+@click.option('-b', '--begin',default="1970-01-01",  help="Plot from which date, can be relative to the endate ('-100'days)?(default=1970-01-01 or startdate)")
+@click.option('-v', '--visualize',default="dvv",  help="Which plot : wavelet 'dvv' heat map, wavelet 'coh'erence heat map, dv/v 'curve' with coherence color?", type=str)
+@click.option('-r', '--ranges',default="[0.5, 1.0], [1.0, 2.0], [2.0, 4.0]",  help="With visualize = 'curve', which frequency ranges to use?", type=str)
+@click.option('-s', '--show', help='Show interactively?',
+              default=True, type=bool)
+@click.option('-o', '--outfile', help='Output filename (?=auto). Defaults to PNG format, but can be anything matplotlib outputs, e.g. ?.pdf will save to PDF with an automatic file naming.',
+              default=None, type=str)
+@click.pass_context
+def dvv_plot_wct(ctx, mov_stack, comp, filterid, pair, all, begin, end, visualize,ranges, show,  outfile):
+    """Plots the dv/v (parses the dt/t results)
+    Individual pairs can be plotted extra using the -p flag one or more times.
+    Example: msnoise plot dvv -p ID_KWUI_ID_POSI
+    Example: msnoise plot dvv -p ID_KWUI_ID_POSI -p ID_KWUI_ID_TRWI
+    Remember to order stations alphabetically !
+    """
+    loglevel = ctx.obj['MSNOISE_verbosity']
+    if ctx.obj['MSNOISE_custom']:
+        from wct_dvv import main # NOQA
+    else:
+        from ..plots.wct_dvv import main
+    main(mov_stack, comp, filterid, pair, all, begin, end, visualize, ranges, show, outfile, loglevel=loglevel)
 
 @dvv_plot.command(name="timing")
 @click.option('-f', '--filterid', default=1, help='Filter ID')
diff --git a/msnoise/test/tests.py b/msnoise/test/tests.py
index 9179eee..da2a29f 100644
--- a/msnoise/test/tests.py
+++ b/msnoise/test/tests.py
@@ -441,6 +441,22 @@ def test_031_instrument_response(self):
     #     from ..s03compute_no_rotation import main
     #     main()
 
+    def test_032_wct(self):
+        from ..api import connect, read_db_inifile
+        from sqlalchemy import text
+        db = connect()
+        dbini = read_db_inifile()
+        prefix = (dbini.prefix + '_') if dbini.prefix != '' else ''
+        db.execute(text("INSERT INTO {prefix}jobs (pair, day, jobtype, flag) "
+                       "SELECT pair, day, '{right_type}', 'T' FROM {prefix}jobs "
+                       "WHERE jobtype='{left_type}' AND flag='D';"
+                       .format(prefix=prefix, right_type="WCT", left_type="STACK")))
+        db.commit()
+
+        from ..s08compute_wct import main
+        main()
+        db.close()
+    
     # PLOTS
 
     def test_100_plot_cctfime(self):
@@ -542,6 +558,13 @@ def test_105_db_dump(self):
         self.assertTrue(os.path.isfile("jobs.csv"))
         self.assertTrue(os.path.isfile("data_availability.csv"))
 
+    def test_106_plot_wct(self):
+        from ..plots.wct_dvv import main
+        main(filterid=1, components="ZZ", show=False, outfile="?.png")
+        fn = "wct ZZ-f1-dvv.png"
+        self.assertTrue(os.path.isfile(fn),
+                        msg="%s doesn't exist" % fn)
+  
     ### A few click CLI interface tests
 
     def test_201_config_get_unknown_param(self):