-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding Wavelet-based DVV to the default workflow (#362)
* Update default.csv: wct param * Update msnoise.py: WCT job * Create s08compute_wct.py * Update s08compute_wct.py: add wavelet_type param and def * Update default.csv: add wavelet_type * Update msnoise.py * Update default.csv * Update s08compute_wct.py * Add files via upload * Update tests.py: add compute_wct and plot tests * Update msnoise.py with plot wct * Update wct_dvv.py * Update s08compute_wct.py * Update wct_dvv.py * Update msnoise.py * Update wct_dvv.py: mov_stack/rolling: if mov_stack is in the params list, use it; otherwise rolling = mov_stack is applied to the plot * Update api.py: xr_save_wct in api * Update s08compute_wct.py: move xr_save_wct in api I don't know why tests don't build for PRs, but it's OK, we'll see later :-) (unsafe comment, don't remember it)
- Loading branch information
1 parent
8afdaff
commit 34dea83
Showing
6 changed files
with
887 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
""" | ||
This plot shows the final dv/v output of MSNoise computed with the wavelet-based (WCT) workflow. | ||
Example: | ||
``msnoise cc dvv plot wct`` | ||
""" | ||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
from matplotlib.colors import Normalize | ||
from matplotlib.lines import Line2D | ||
import matplotlib.dates as mdates | ||
import pandas as pd | ||
from ..api import * | ||
from datetime import datetime, timedelta | ||
|
||
def plot_dvv_heatmap(data_type, dvv_df, pair, rolling, start, end, low, high, logger, mincoh=0.5):
    """Plot a time/frequency heatmap of the WCT dv/v or coherence values.

    :param data_type: which quantity to draw, ``'dvv'`` or ``'coh'``.
    :param dvv_df: DataFrame indexed by time; ``dvv_df['dvv']`` and
        ``dvv_df['coh']`` must each yield one column per frequency.
    :param pair: sequence whose first element is used in the title and
        in the returned file name.
    :param rolling: rolling-mean window (number of rows); cast with ``int``.
    :param start: lower bound used to slice ``dvv_df`` and set the x limits.
    :param end: upper bound used to slice ``dvv_df`` and set the x limits.
    :param low: lower filter frequency, only used in the file name.
    :param high: upper filter frequency, only used in the file name.
    :param logger: object providing an ``error`` method.
    :param mincoh: lower bound of the colour scale for coherence plots.
    :returns: ``(fig, save_name)``; ``(None, None)`` when there is no data
        in the requested period or ``data_type`` is unknown.
    """
    dvv_df = dvv_df.loc[start:end]
    if dvv_df.empty:
        logger.error(f"No data available for {pair} between {start} and {end}. Exiting function.")
        return None, None

    # Validate before creating the figure so nothing leaks on bad input and
    # every failure path returns the same 2-tuple shape callers unpack.
    if data_type not in ('dvv', 'coh'):
        logger.error("Unknown data type: %s, write 'dvv' or 'coh'? " % data_type)
        return None, None

    rolling_window = int(rolling)
    values = dvv_df[data_type].rolling(window=rolling_window, min_periods=1).mean()

    if data_type == 'dvv':
        # Clip the colour scale to the 1st-99th percentiles so that a few
        # outliers do not wash out the rest of the map.
        cmap = mpl.cm.seismic
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        color_bar_label = 'dv/v (%)'
    else:
        cmap = 'RdYlGn'
        vmin, vmax = mincoh, 1
        color_bar_label = 'Coherence value'
    save_name = f"{pair[0]}_{low}_{high}_Hz_m{rolling_window}_{data_type}_heatmap"

    fig, ax = plt.subplots(figsize=(16, 10))
    mesh = ax.pcolormesh(np.asarray(values.index), np.asarray(values.columns), values.T,
                         cmap=cmap, edgecolors='none', vmin=vmin, vmax=vmax)

    ax.set_xlim(pd.to_datetime(start), pd.to_datetime(end))
    ax.set_ylabel('Frequency (Hz)', fontsize=18)
    ax.set_title(save_name, fontsize=22)

    cbar = fig.colorbar(mesh, ax=ax, pad=0.02)
    cbar.set_label(color_bar_label, fontsize=18)

    ax.tick_params(axis='both', which='both', labelsize=16, width=2, length=5)
    ax.xaxis.set_minor_locator(mdates.MonthLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
    fig.autofmt_xdate()

    fig.subplots_adjust(right=0.85)
    fig.tight_layout()
    return fig, save_name
|
||
def plot_dvv_scatter(dvv_df, pair, rolling, start, end, ranges, logger):
    """Plot mean dv/v versus time for several frequency bands.

    One scatter series is drawn per band; each point is the mean dv/v over
    the band's frequencies and its shade encodes the mean coherence.

    :param dvv_df: DataFrame indexed by time; ``dvv_df['dvv']`` and
        ``dvv_df['coh']`` must each yield one column per frequency.
    :param pair: sequence whose first element is used in the title and
        in the returned file name.
    :param rolling: rolling-mean window (number of rows); cast with ``int``.
    :param start: lower bound used to slice ``dvv_df`` and set the x limits.
    :param end: upper bound used to slice ``dvv_df`` and set the x limits.
    :param ranges: frequency bands as text, e.g.
        ``"[0.5, 1.0], [1.0, 2.0]"``.
    :param logger: object providing an ``error`` method.
    :returns: ``(fig, save_name)``; ``(None, None)`` when there is no data
        in the requested period.
    """
    import re  # local import: only needed to split the ``ranges`` string

    dvv_df = dvv_df.loc[start:end]
    if dvv_df.empty:
        logger.error(f"No data available for {pair} between {start} and {end}. Exiting function.")
        return None, None
    rolling_window = int(rolling)

    cmaps = ['Blues', 'Reds', 'Greens', 'Greys']
    marker_colors = ['blue', 'red', 'green', 'grey']

    # Accept "], [" with any (or no) whitespace around the comma.
    ranges_list = [
        [float(value) for value in chunk.strip().strip('[]').split(',')]
        for chunk in re.split(r'\]\s*,\s*\[', ranges)
    ]

    coh_norm = Normalize(vmin=0, vmax=1)
    freqs = np.asarray(dvv_df['dvv'].columns)
    legend_handles = []

    fig, ax = plt.subplots(figsize=(16, 10))
    for i, freqrange in enumerate(ranges_list):
        freq_name = f"{freqrange[0]}-{freqrange[1]} Hz"
        in_band = freqs[(freqs >= freqrange[0]) & (freqs <= freqrange[1])].tolist()

        dvv_freq = dvv_df['dvv'][in_band].rolling(window=rolling_window, min_periods=1).mean()
        coh_freq = dvv_df['coh'][in_band].rolling(window=rolling_window, min_periods=1).mean()

        # Cycle through the palettes so more than four bands do not fail.
        slot = i % len(cmaps)
        ax.scatter(dvv_freq.index, dvv_freq.mean(axis=1), c=coh_freq.mean(axis=1),
                   cmap=cmaps[slot], norm=coh_norm, label=freq_name)
        legend_handles.append(Line2D([0], [0], marker='o', color=marker_colors[slot],
                                     markerfacecolor=marker_colors[slot], markersize=10,
                                     label=freq_name))

    ax.set_xlim(pd.to_datetime(start), pd.to_datetime(end))
    ax.set_ylabel('dv/v (%)', fontsize=18)
    ax.set_title(f"{pair[0]} dv/v scatter plot", fontsize=22)
    ax.legend(handles=legend_handles, fontsize=22, loc='upper left')

    # A neutral grey colorbar stands for "coherence" across all bands; it is
    # built from a standalone mappable rather than from dummy data points.
    coh_mappable = mpl.cm.ScalarMappable(norm=coh_norm, cmap=cmaps[-1])
    cbar = fig.colorbar(coh_mappable, ax=ax, pad=0.02)
    cbar.set_label('Coherence value \n(darkness of the point)', fontsize=18)

    ax.tick_params(axis='both', which='both', labelsize=16, width=2, length=5)
    ax.xaxis.set_minor_locator(mdates.MonthLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
    fig.autofmt_xdate()

    fig.subplots_adjust(right=0.85)
    fig.tight_layout()

    return fig, f"{pair[0]}_{np.min(ranges_list)}_{np.max(ranges_list)}_Hz_m{rolling_window}_wctscatter"
|
||
|
||
def save_figure(fig, filename, logger, mov_stack, components,filterid, visualize, plot_all_period=False, start=None, end=None, outfile=None):
    """Write ``fig`` to disk as a 300 dpi transparent image.

    Without ``outfile`` the figure goes to ``Figures`` (when
    ``plot_all_period`` is true) or ``Figures/Zooms`` as ``<filename>.png``,
    with the first ten characters of ``start`` and ``end`` appended to the
    name when both are given.  With ``outfile`` the figure is written to
    ``"wct " + outfile``; a leading ``?`` in ``outfile`` is replaced by a
    name built from the components, filter id, plot type and, when exactly
    one moving stack was requested, that stack.

    :param fig: figure object providing ``savefig``.
    :param filename: base name of the default output file.
    :param logger: object providing an ``info`` method.
    :param mov_stack: sequence of moving stacks, each indexable as
        ``stack[0]`` / ``stack[1]``.
    :param components: component label used in the ``?`` template.
    :param filterid: integer filter id used in the ``?`` template.
    :param visualize: plot type label used in the ``?`` template.
    :param plot_all_period: selects the default output folder.
    :param start: optional start of the plotted period.
    :param end: optional end of the plotted period.
    :param outfile: optional explicit output name or ``?`` template.
    """
    if outfile:
        if outfile.startswith("?"):
            # Test the number of stacks on the full sequence; indexing first
            # would compare the length of one stack tuple instead.
            if len(mov_stack) == 1:
                only_stack = mov_stack[0]
                outfile = outfile.replace('?', '%s-f%i-m%s_%s-%s' % (components,
                                                                      filterid,
                                                                      only_stack[0],
                                                                      only_stack[1],
                                                                      visualize))
            else:
                outfile = outfile.replace('?', '%s-f%i-%s' % (components,
                                                               filterid,
                                                               visualize))
        filepath = "wct " + outfile
        # Report the path that is really written, prefix included.
        logger.info("output to: %s" % filepath)
    else:
        fig_path = 'Figures' if plot_all_period else 'Figures/Zooms'
        create_folder(fig_path, logger)
        if start and end:
            filename = f'{filename}_{str(start)[:10]}_{str(end)[:10]}'
        filepath = os.path.join(fig_path, f'{filename}.png')

    fig.savefig(filepath, dpi=300, bbox_inches='tight', transparent=True)
|
||
def create_folder(folder_path, logger):
    """Create ``folder_path`` (and missing parents) if it does not exist.

    :param folder_path: directory to create.
    :param logger: object providing an ``info`` method; called only when the
        directory was actually created.
    :raises FileExistsError: if ``folder_path`` exists but is not a
        directory, since nothing could be saved inside it afterwards.
    """
    try:
        os.makedirs(folder_path)
    except FileExistsError:
        # An already existing directory is fine; anything else is not.
        if not os.path.isdir(folder_path):
            raise
        return
    logger.info(f"Folder '{folder_path}' created successfully.")
|
||
def xr_get_wct_pair(pair, components, filterid, mov_stack, logger):
    """Load the WCT result of one station pair as a DataFrame.

    The file read is
    ``WCT/<filterid>/<mov_stack[0]>_<mov_stack[1]>/<components>/<pair>.nc``.

    :param pair: pair name, or a sequence of names of which the first is
        used.
    :param components: component label (folder name).
    :param filterid: integer filter id (zero-padded to two digits).
    :param mov_stack: moving stack, indexable as ``[0]`` and ``[1]``.
    :param logger: object providing an ``error`` method.
    :returns: the dataset converted to a DataFrame with the ``frequency``
        level unstacked into columns.
    :raises FileNotFoundError: if the file does not exist.
    """
    # "%s" % seq would format a whole list (or fail on a longer tuple), so
    # reduce a sequence to its first name explicitly.
    pair_name = pair if isinstance(pair, str) else pair[0]
    fn = os.path.join("WCT", "%02i" % filterid,
                      "%s_%s" % (mov_stack[0], mov_stack[1]),
                      "%s" % components, "%s.nc" % pair_name)
    if not os.path.isfile(fn):
        logger.error("FILE DOES NOT EXIST: %s, skipping" % fn)
        # Really skip: let the caller decide instead of opening a missing file.
        raise FileNotFoundError(fn)

    data = xr_create_or_open(fn, name="WCT")
    return data.to_dataframe().unstack(level='frequency')
|
||
def xr_get_wct(components, filterid, mov_stack, logger):
    """Load every WCT result for a configuration and average them per time.

    All ``*.nc`` files under
    ``WCT/<filterid>/<mov_stack[0]>_<mov_stack[1]>/<components>/`` are read,
    concatenated and averaged over identical index values.

    :param components: component label (folder name).
    :param filterid: integer filter id (zero-padded to two digits).
    :param mov_stack: moving stack, indexable as ``[0]`` and ``[1]``.
    :param logger: object providing an ``error`` method.
    :returns: DataFrame holding the mean of all files for each index value.
    :raises FileNotFoundError: if no file matches the pattern.
    """
    fn = os.path.join("WCT", "%02i" % filterid,
                      "%s_%s" % (mov_stack[0], mov_stack[1]),
                      "%s" % components, "*.nc")
    matching_files = sorted(glob.glob(fn))
    if not matching_files:
        logger.error(f"No files found matching pattern: {fn}")
        # Fail with a meaningful error rather than inside pd.concat([]).
        raise FileNotFoundError(fn)

    all_wct = [
        xr_create_or_open(fil, name="WCT").to_dataframe().unstack(level='frequency')
        for fil in matching_files
    ]
    combined = pd.concat(all_wct, axis=0)
    return combined.groupby(combined.index).mean()
|
||
def validate_and_adjust_date(date_string, end_date, logger):
    """Turn a start specification into a ``datetime``.

    ``date_string`` is either an absolute ``YYYY-MM-DD`` date or an integer
    number of days added to ``end_date`` (so a negative value means "days
    before the end").

    :param date_string: absolute date or integer day offset, as text.
    :param end_date: reference for relative offsets; a ``YYYY-MM-DD``
        string, a ``datetime``, or any object with ``year``, ``month`` and
        ``day`` attributes.
    :param logger: object providing an ``error`` method.
    :returns: the resolved ``datetime``, or ``None`` if either value cannot
        be interpreted (an error is logged in that case).
    """
    try:
        return datetime.strptime(date_string, '%Y-%m-%d')
    except ValueError:
        pass  # not an absolute date; try a relative offset below

    try:
        days_delta = int(date_string)
    except ValueError:
        logger.error(f"Invalid start string: {date_string}")
        return None

    if isinstance(end_date, datetime):
        reference = end_date
    elif isinstance(end_date, str):
        try:
            reference = datetime.strptime(end_date, '%Y-%m-%d')
        except ValueError:
            # Blame the right argument instead of the start string.
            logger.error(f"Invalid end string: {end_date}")
            return None
    else:
        # Date-like object: promote to a datetime at midnight.
        reference = datetime(end_date.year, end_date.month, end_date.day)
    return reference + timedelta(days=days_delta)
|
||
def main(mov_stackid=None, components='ZZ', filterid=1,
         pairs=None, showALL=False, start="1970-01-01", end="2100-01-01",
         visualize='dvv', ranges="[0.5, 1.0], [1.0, 2.0], [2.0, 4.0]",
         show=True, outfile=None, loglevel="INFO"):
    """Plot the WCT dv/v results as heatmaps or as a per-band scatter plot.

    :param mov_stackid: 1-based index into the configured moving stacks.
        When it is not a valid index, the first configured stack is plotted
        and the value is used as an additional rolling-mean window.  When
        empty, all configured stacks are plotted.
    :param components: component label, or several separated by commas.
    :param filterid: id of the filter whose results are plotted.
    :param pairs: pair name or sequence of pair names; when empty, the mean
        over all available files is plotted.
    :param showALL: unused here; kept for interface compatibility.
    :param start: ``YYYY-MM-DD`` date or integer day offset relative to
        ``end``; the default value selects the configured start date.
    :param end: end date; the default value selects the configured end date.
    :param visualize: ``'dvv'`` or ``'coh'`` for heatmaps, ``'curve'`` for
        the scatter plot.
    :param ranges: frequency bands for the scatter plot, as text.
    :param show: whether to call ``plt.show()`` for each figure.
    :param outfile: optional output name forwarded to ``save_figure``.
    :param loglevel: level passed to ``get_logger``.
    """
    logger = get_logger('msnoise.cc_dvv_plot_dvv', loglevel,
                        with_pid=True)

    # Reject an unknown plot type before doing any work.
    if visualize not in ('dvv', 'coh', 'curve'):
        logger.error("PLOT TYPE DOES NOT EXIST: %s" % visualize)
        return

    db = connect()
    params = get_params(db)
    mincoh = params.dtt_mincoh

    # Resolve the end first: a relative start is measured from it.
    if end == "2100-01-01":
        end = params.enddate
    if start == "1970-01-01":
        start = params.startdate
    else:
        start = validate_and_adjust_date(start, str(end)[:10], logger)
        if start is None:
            # The helper already logged why the start could not be parsed.
            return

    if mov_stackid:
        try:
            mov_stacks = [params.mov_stack[mov_stackid - 1], ]
            rolling = 1  # keep the moving-stack result as is
        except (IndexError, TypeError):
            # Not one of the configured stacks: smooth the first stack with
            # a rolling window of that size instead.
            mov_stacks = [params.mov_stack[0], ]
            rolling = mov_stackid
    else:
        mov_stacks = params.mov_stack
        # NOTE(review): this reads only the first character of the first
        # stack's window (e.g. '1' from '1d'); kept as in the original,
        # confirm that this is the intended rolling window.
        rolling = int(params.mov_stack[0][0][0])

    components = components.split(",")

    selected_filter = get_filters(db, ref=filterid)
    low = float(selected_filter.low)
    high = float(selected_filter.high)

    # ``None`` stands for "average of all stations"; the caller's ``pairs``
    # is never rebound, so every loop iteration sees the same request.
    if not pairs:
        pair_list = [None, ]
    elif isinstance(pairs, str):
        pair_list = [pairs, ]
    else:
        pair_list = list(pairs)

    for mov_stack in mov_stacks:
        for comps in components:
            for pair in pair_list:
                # Get the data
                try:
                    if pair is None:
                        label = ["all stations", ]
                        dvv = xr_get_wct(comps, filterid, mov_stack, logger)
                    else:
                        label = [pair, ]
                        dvv = xr_get_wct_pair(pair, comps, filterid, mov_stack, logger)
                except FileNotFoundError as fullpath:
                    logger.error("FILE DOES NOT EXIST: %s, skipping" % fullpath)
                    continue

                # Plotting
                if visualize == 'curve':
                    result = plot_dvv_scatter(dvv, label, rolling, start, end, ranges, logger)
                else:
                    result = plot_dvv_heatmap(visualize, dvv, label, rolling, start, end,
                                              low, high, logger, mincoh)
                # The plot helpers signal "nothing to plot" with None values.
                if result is None or result[0] is None:
                    continue
                fig, savename = result[0], result[1]

                # Save and show the figure
                save_figure(fig, savename, logger, mov_stacks, comps, filterid, visualize,
                            plot_all_period=False, start=start, end=end, outfile=outfile)
                if show:
                    plt.show()
                else:
                    plt.close(fig)  # do not accumulate open figures
|
||
# Allow running this plot module directly as a script.
if __name__ == "__main__":
    main()
|
Oops, something went wrong.