Skip to content

Commit

Permalink
Adding Wavelet-based DVV to the default workflow (#362)
Browse files Browse the repository at this point in the history
* Update default.csv: wct param

* Update msnoise.py: WCT job

* Create s08compute_wct.py

* Update s08compute_wct.py: add wavelet_type param and def

* Update default.csv: add wavelet_type

* Update msnoise.py

* Update default.csv

* Update s08compute_wct.py

* Add files via upload

* Update tests.py: add compute_wct and plot tests

* Update msnoise.py with plot wct

* Update wct_dvv.py

* Update s08compute_wct.py

* Update wct_dvv.py

* Update msnoise.py

* Update wct_dvv.py: mov_stack/rolling

if mov_stack in params list use it, else rolling = mov_stack and applyied to the plot

* Update api.py: xr_save_wct in api

* Update s08compute_wct.py: move xr_save_wct in api

I don't know why tests don't build for PRs, but it's OK, we'll see later :-) (unsafe comment, don't remember it)
  • Loading branch information
LaureBrenot authored Sep 18, 2024
1 parent 8afdaff commit 34dea83
Show file tree
Hide file tree
Showing 6 changed files with 887 additions and 3 deletions.
55 changes: 55 additions & 0 deletions msnoise/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2556,3 +2556,58 @@ def compute_dvv(session, filterid, mov_stack, pairs=None, components=None, param
stats[(c, "trimmed_mean")], stats[(c, "trimmed_std")] = trim(all, c, kwargs.get("limits", None))

return stats.sort_index(axis=1)

def xr_save_wct(station1, station2, components, filterid, mov_stack, taxis, dvv_df, err_df, coh_df):
"""
Save the Wavelet Coherence Transform (WCT) results as a NetCDF file.
Parameters
----------
station1 : str
The first station in the pair.
station2 : str
The second station in the pair.
components : str
The components (e.g., Z, N, E) being analyzed.
filterid : int
Filter ID used in the analysis.
mov_stack : tuple
Tuple of (start, end) representing the moving stack window.
taxis : array-like
Time axis corresponding to the WCT data.
dvv_df : pandas.DataFrame
DataFrame containing dvv data (2D).
err_df : pandas.DataFrame
DataFrame containing err data (2D).
coh_df : pandas.DataFrame
DataFrame containing coh data (2D).
Returns
-------
None
"""
# Construct the file path
fn = os.path.join("WCT", f"{filterid:02d}", f"{mov_stack[0]}_{mov_stack[1]}",
components, f"{station1}_{station2}.nc")

# Ensure the directory exists
os.makedirs(os.path.dirname(fn), exist_ok=True)

# Convert DataFrames to xarray.DataArrays
dvv_da = xr.DataArray(dvv_df.values, coords=[dvv_df.index, dvv_df.columns], dims=['times', 'frequency'])
err_da = xr.DataArray(err_df.values, coords=[err_df.index, err_df.columns], dims=['times', 'frequency'])
coh_da = xr.DataArray(coh_df.values, coords=[coh_df.index, coh_df.columns], dims=['times', 'frequency'])

# Combine into a single xarray.Dataset
ds = xr.Dataset({
'dvv': dvv_da,
'err': err_da,
'coh': coh_da
})

# Save the dataset to a NetCDF file
ds.to_netcdf(fn)

logger.debug(f"Saved WCT data to {fn}")
# Clean up
del dvv_da, err_da, coh_da, ds
10 changes: 9 additions & 1 deletion msnoise/default.csv
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ dtt_sides,both,Which sides to use,str,both/left/right,,
dtt_mincoh,0.65,"Minimum coherence on dt measurement, MWCS points with values lower than that will not be used in the WLS, [0:1]",float,,,
dtt_maxerr,0.1,"Maximum error on dt measurement, MWCS points with values larger than that will not be used in the WLS [0:1]",float,,,
dtt_maxdt,0.1,"Maximum dt values, MWCS points with values larger than that will not be used in the WLS (in seconds)",float,,,
wct_ns,5,"smoothing parameter in frequency",float,,,
wct_nt,5,"smoothing parameter in time",float,,,
wct_vpo,20,"spacing param between discrete scales",float,,,
wct_nptsfreq,300,"number of freq points between min and max",float,,,
dtt_codacycles,20,"number of cycles of period (1/freq) between lag_min and lag_max",int,,,
dvv_min_nonzero,0.25,"percentage of data points with non-zero weighting required for regression otherwise nan (0 to 1)",float,,,
wct_norm,Y,"Is the REF and CCF are normalized before computing wavelet? [Y]/N",bool,Y/N,,
wavelet_type,"('Morlet',6.)","Wavelet type and optional associated parameter",eval,Morlet/Paul/DOG/MexicanHat,,
plugins,,Comma separated list of plugin names. Plugins names should be importable Python modules.,str,,,
hpc,N,Is MSNoise going to run on an HPC?,bool,Y/N,,
stretching_max,0.01,"Maximum stretching coefficient, e.g. 0.5 = 50%, 0.01 = 1%",float,,,
Expand All @@ -62,4 +70,4 @@ qc_ppsd_period_step_octaves,0.0125,Step length on frequency axis in fraction of
qc_ppsd_period_limits,"(0.01,100)","Set custom lower and upper end of period range (e.g. ``(0.01, 100)`` seconds). The specified lower end of period range will be set as the central period of the first bin (geometric mean of left/right edges of smoothing interval). At the upper end of the specified period range, no more additional bins will be added after the bin whose center frequency exceeds the given upper end for the first time.",eval,,,
qc_ppsd_db_bins,"(-200, -50, 1.)",Specify the lower and upper boundary and the width of the db bins. The bin width might get adjusted to fit a number of equally spaced bins in between the given boundaries.,eval,,,
qc_rms_frequency_ranges,"[(1.0, 20.0), (4.0, 14.0), (4.0, 40.0), (4.0, 9.0)]",Specify the frequency bounds (in Hz) to compute the RMS from PSDs,eval,,,
qc_rms_type,DISP,What units do you want for the exported RMS,str,DISP/VEL/ACC,,
qc_rms_type,DISP,What units do you want for the exported RMS,str,DISP/VEL/ACC,,
286 changes: 286 additions & 0 deletions msnoise/plots/wct_dvv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
"""
This plot shows the final output of MSNoise using the wavelet.
Example:
``msnoise cc dvv plot wct``
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D
import matplotlib.dates as mdates
import pandas as pd
from ..api import *
from datetime import datetime, timedelta

def plot_dvv_heatmap(data_type, dvv_df, pair, rolling, start, end, low, high, logger, mincoh=0.5):
# Extracting relevant data from dvv_df
dvv_df = dvv_df.loc[start:end]
if dvv_df.empty:
logger.error(f"No data available for {pair} between {start} and {end}. Exiting function.")
return
rolling_window = int(rolling)

dvv_freq = dvv_df['dvv']
coh_freq = dvv_df['coh']

dvv_freq = dvv_freq.rolling(window=rolling_window, min_periods=1).mean()
coh_freq = coh_freq.rolling(window=rolling_window, min_periods=1).mean()

fig, ax = plt.subplots(figsize=(16, 10))
# Scatter plot of dv/v data
#norm1 = plt.Normalize(vmin=np.min(dvv_freq.T), vmax=np.max(dvv_freq.T))

if data_type == 'dvv':
low_per = np.nanpercentile(dvv_freq, 1)
high_per = np.nanpercentile(dvv_freq, 99)

ax.pcolormesh(np.asarray(dvv_freq.index), np.asarray(dvv_freq.columns), dvv_freq.T,
cmap=mpl.cm.seismic, edgecolors='none', vmin=low_per, vmax=high_per)
save_name = f"{pair[0]}_{low}_{high}_Hz_m{rolling_window}_dvv_heatmap"
color_bar_label = 'dv/v (%)'

elif data_type == 'coh':
ax.pcolormesh(np.asarray(coh_freq.index), np.asarray(coh_freq.columns), coh_freq.T,
cmap='RdYlGn', edgecolors='none', vmin=mincoh, vmax=1)
save_name = f"{pair[0]}_{low}_{high}_Hz_m{rolling_window}_coh_heatmap"
color_bar_label = 'Coherence value'
else:
logger.error("Unknown data type: %s, write 'dvv' or 'coh'? " % data_type)
return None, None, None

#if current_config.get('plot_event', False):
# plot_events(ax, current_config['event_list'], start, end)

ax.set_xlim(pd.to_datetime(start), pd.to_datetime(end))

ax.set_ylabel('Frequency (Hz)', fontsize=18)
ax.set_title(save_name, fontsize=22)

cbar1 = plt.colorbar(ax.collections[0], ax=ax, pad=0.02)
cbar1.set_label(color_bar_label, fontsize=18)
#norm1 = Normalize(vmin=0, vmax=1)

ax.tick_params(axis='both', which='both', labelsize=16, width=2, length=5)
ax.xaxis.set_minor_locator(mdates.MonthLocator())
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
fig.autofmt_xdate()

# Adjust the layout if necessary
fig.subplots_adjust(right=0.85)
fig.tight_layout()
return fig, save_name

def plot_dvv_scatter(dvv_df, pair, rolling, start, end, ranges, logger):
# Extracting relevant data from dvv_df
dvv_df = dvv_df.loc[start:end]
if dvv_df.empty:
logger.error(f"No data available for {pair} between {start} and {end}. Exiting function.")
return
rolling_window = int(rolling)

color = ['Blues', 'Reds','Greens','Greys'] #'Purples'
color2 = ['blue', 'red', 'green', 'grey'] # Colors for different frequency ranges
freq_names = []
legend_handles = []
ranges_list = [list(map(float, r.strip().strip('[]').split(','))) for r in ranges.split('], [')]

fig, ax = plt.subplots(figsize=(16, 10))
# Loop through the frequency ranges specified in current_config
for i, freqrange in enumerate(ranges_list):#[[0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]):#, [0.5, 2.0]]):
freq_name = f"{freqrange[0]}-{freqrange[1]} Hz"
freq_names.append(freq_name)

freqs = np.asarray(dvv_df['dvv'].columns)
filtered_freqs = freqs[(freqs >= freqrange[0]) & (freqs <= freqrange[1])].tolist()
dvv_freq = dvv_df['dvv'][filtered_freqs]
coh_freq = dvv_df['coh'][filtered_freqs]

dvv_freq = dvv_freq.rolling(window=rolling_window, min_periods=1).mean()
coh_freq = coh_freq.rolling(window=rolling_window, min_periods=1).mean()

# Scatter plot of dv/v data
norm1 = plt.Normalize(vmin=0, vmax=1)
ax.scatter([0,1], [0,1], c=[0,1], cmap=color[-1])

sc = ax.scatter(dvv_freq.index, dvv_freq.mean(axis=1), c=coh_freq.mean(axis=1), cmap=color[i], norm=norm1, label=freq_name)
legend_handles.append(Line2D([0], [0], marker='o', color=color2[i], markerfacecolor=color2[i], markersize=10, label=freq_name))

#if current_config.get('plot_event', False):
# plot_events(ax, current_config['event_list'], start, end)

ax.set_xlim(pd.to_datetime(start), pd.to_datetime(end))
#if current_config.get('same_dvv_scale', False):
# ax.set_ylim(current_config['dvv_min'], current_config['dvv_max'])

ax.set_ylabel('dv/v (%)', fontsize=18)
ax.set_title(f"{pair[0]} dv/v scatter plot", fontsize=22)

legend1 = ax.legend(handles=legend_handles, fontsize=22, loc='upper left')
ax.add_artist(legend1)

cbar1 = plt.colorbar(ax.collections[0], ax=ax, pad=0.02)
cbar1.set_label('Coherence value \n(darkness of the point)', fontsize=18)
norm1 = Normalize(vmin=0, vmax=1)

ax.tick_params(axis='both', which='both', labelsize=16, width=2, length=5)
ax.xaxis.set_minor_locator(mdates.MonthLocator())
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
fig.autofmt_xdate()

# Adjust the layout if necessary
fig.subplots_adjust(right=0.85)
fig.tight_layout()

return fig, f"{pair[0]}_{ np.min(ranges_list)}_{np.max(ranges_list)}_Hz_m{rolling_window}_wctscatter"


def save_figure(fig, filename, logger, mov_stack, components,filterid, visualize, plot_all_period=False, start=None, end=None, outfile=None):
fig_path = os.path.join('Figures' if plot_all_period else 'Figures/Zooms')
create_folder(fig_path, logger)
mov_stack= mov_stack[0]
if start and end:
filename = f'{filename}_{str(start)[:10]}_{str(end)[:10]}'
filepath = os.path.join(fig_path, f'{filename}.png')

if outfile:
if outfile.startswith("?"):
if len(mov_stack) == 1:
outfile = outfile.replace('?', '%s-f%i-m%s_%s-%s' % (components,
filterid,
mov_stack[0],
mov_stack[1],
visualize))
else:
outfile = outfile.replace('?', '%s-f%i-%s' % (components,
filterid,
visualize))
filepath = "wct " + outfile
logger.info("output to: %s" % outfile)

fig.savefig(filepath, dpi=300, bbox_inches='tight', transparent=True)

def create_folder(folder_path, logger):
try:
os.makedirs(folder_path)
logger.info(f"Folder '{folder_path}' created successfully.")
except FileExistsError:
pass

def xr_get_wct_pair(pair, components, filterid, mov_stack, logger):
fn = os.path.join("WCT", "%02i" % filterid,
"%s_%s" % (mov_stack[0], mov_stack[1]),
"%s" % components, "%s.nc" % (pair))
if not os.path.isfile(fn):
logger.error("FILE DOES NOT EXIST: %s, skipping" % fn)

data = xr_create_or_open(fn, name="WCT")
data = data.to_dataframe().unstack(level='frequency')
return data

def xr_get_wct(components, filterid, mov_stack, logger):
fn = os.path.join("WCT", "%02i" % filterid,
"%s_%s" % (mov_stack[0], mov_stack[1]),
"%s" % components, "*.nc" )
matching_files = glob.glob(fn)
if not matching_files:
logger.error(f"No files found matching pattern: {fn}")

all_wct = []
for fil in matching_files:
data = xr_create_or_open(fil, name="WCT")
data = data.to_dataframe().unstack(level='frequency')
all_wct.append(data)
combined = pd.concat(all_wct, axis=0)
dvv = combined.groupby(combined.index).mean()
return dvv

def validate_and_adjust_date(date_string, end_date, logger):
try:
start_date = datetime.strptime(date_string, '%Y-%m-%d')
except ValueError:
try:
days_delta = int(date_string)
start_date = datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=days_delta)
except ValueError:
logger.error(f"Invalid start string: {date_string}")
return None
return start_date

def main(mov_stackid=None, components='ZZ', filterid=1,
pairs=[], showALL=False, start="1970-01-01", end="2100-01-01", visualize='dvv', ranges="[0.5, 1.0], [1.0, 2.0], [2.0, 4.0]", show=True,outfile=None, loglevel="INFO"):
logger = get_logger('msnoise.cc_dvv_plot_dvv', loglevel,
with_pid=True)
db = connect()
params = get_params(db)
mincoh = params.dtt_mincoh

# Check start and end dates
if start == "1970-01-01":
start= params.startdate
else:
start = validate_and_adjust_date(start, end, logger)
if end == "2100-01-01":
end = params.enddate

# TODO clearer mov_stackid to additionnal rolling
if mov_stackid and mov_stackid != "": #if mov_stackid given
try:
mov_stack = params.mov_stack[mov_stackid - 1]
if mov_stack in params.mov_stack: # Check if mov_stack is in params.mov_stack
mov_stacks = [mov_stack, ]
rolling = 1
else:
rolling = mov_stackid # Assign mov_stack to rolling
except:
mov_stack = params.mov_stack[0]
if mov_stack in params.mov_stack: # Check if mov_stack is in params.mov_stack new format
mov_stacks = [mov_stack, ]
rolling = 1 # Keeping the mov_stack result
else:
rolling = mov_stack # Assign mov_stack to rolling
else:
mov_stacks = params.mov_stack
rolling = int(params.mov_stack[0][0][0])

if components.count(","):
components = components.split(",")
else:
components = [components, ]

filter = get_filters(db, ref=filterid)
low = float(filter.low)
high = float(filter.high)

for i, mov_stack in enumerate(mov_stacks):
for comps in components:
# Get the data
if not pairs:
dvv = xr_get_wct(comps, filterid, mov_stack, logger)
pairs = ["all stations",]
else:
try:
dvv = xr_get_wct_pair(pairs, comps, filterid, mov_stack, logger)
except FileNotFoundError as fullpath:
logger.error("FILE DOES NOT EXIST: %s, skipping" % fullpath)
continue
# Plotting
if visualize == 'dvv':
fig, savename = plot_dvv_heatmap('dvv', dvv, pairs, rolling, start, end, low, high, logger, mincoh)
elif visualize == 'coh':
fig, savename = plot_dvv_heatmap('coh', dvv, pairs, rolling, start, end, low, high, logger, mincoh)
elif visualize == 'curve':
fig, savename = plot_dvv_scatter(dvv, pairs, rolling, start, end, ranges, logger)
else:
looger.error("PLOT TYPE DOES NOT EXIST: %s" % visualize)
# Save and show the figure
save_figure(fig, savename, logger, mov_stacks, comps, filterid, visualize, plot_all_period=False, start=start, end=end, outfile=outfile)
if show:
plt.show()

if __name__ == "__main__":
main()

Loading

0 comments on commit 34dea83

Please sign in to comment.