Skip to content

Commit

Permalink
Merge pull request #38 from MeteoSwiss/MPC-58-Integrate-training-with…
Browse files Browse the repository at this point in the history
…-MLflow

MPC-58 integrate training with mlflow
  • Loading branch information
wolfidan authored Dec 19, 2024
2 parents 727f45a + 852b775 commit b42330b
Show file tree
Hide file tree
Showing 8 changed files with 517 additions and 178 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_rainforest_master.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
os: [ "ubuntu-latest" ]
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10","3.11", "3.12"]
max-parallel: 6

defaults:
Expand Down
46 changes: 46 additions & 0 deletions rainforest/common/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict
import os

REFCOLORS = OrderedDict()
REFCOLORS['RZC'] = 'k'
Expand Down Expand Up @@ -374,6 +375,51 @@ def qpe_scatterplot( qpe_est, ref, title_prefix = '', figsize = (10,7.5)):
cax = fig.add_axes([0.18, 0.15, 0.7, 0.03])
fig.colorbar(pl, cax, orientation = 'horizontal', label = 'Counts')

def plot_fit_metrics(metrics, output_folder):
    """
    Plots the metrics for a given fit as performed in
    the rf.py module

    One figure is created per (aggregation, fraction) pair. Every figure is
    a grid of bar charts with one row per metric and one column per
    intensity range; the bars of each panel are the precipitation types.

    Parameters
    ----------
    metrics : dict
        dictionary containing the result statistics as obtained in the
        rf.py:fit_models function, indexed as
        metrics[aggregation][fraction][precip_type][intensity_range][metric].
        The keys found in the first entry of each level are reused for all
        other entries, so the nesting is expected to be regular.
    output_folder : str
        where to store the plots

    Returns
    -------
    list_figures : dict
        maps the name of every written file
        ('metrics_<aggregation>_<fraction>.png') to the corresponding
        figure. Figures are not closed here so that the caller can still
        use them (e.g. for mlflow). Empty if `metrics` is empty.
    """
    # Nothing to plot: return early instead of failing on aggregations[0]
    if not metrics:
        return {}

    # The layout of the plots is derived from the first entry of each level
    aggregations = list(metrics)
    by_fraction = metrics[aggregations[0]]
    fractions = list(by_fraction)
    by_precip = by_fraction[fractions[0]]
    precip_types = list(by_precip)
    by_range = by_precip[precip_types[0]]
    intensity_ranges = list(by_range)
    metric_names = list(by_range[intensity_ranges[0]])

    n_rows = len(metric_names)
    n_cols = len(intensity_ranges)
    last_row = n_rows - 1

    list_figures = {}  # for mlflow
    for ag in aggregations:
        for f in fractions:
            # squeeze=False guarantees a 2D array of axes, so ax[i, j] also
            # works when there is a single metric or a single intensity range
            fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 14),
                                   squeeze=False)
            for i, metric in enumerate(metric_names):
                for j, ir in enumerate(intensity_ranges):
                    values = [metrics[ag][f][p][ir][metric]
                              for p in precip_types]
                    ax[i, j].bar(precip_types, values, color='C0')
                    ax[i, j].grid()
                    if i != last_row:
                        # only the bottom row keeps its x tick labels
                        ax[i, j].set(xticklabels=[])
                # the first column carries the name of the metric
                ax[i, 0].set_ylabel(metric)
                ax[i, 0].tick_params(bottom=False)  # remove the ticks

            for j, ir in enumerate(intensity_ranges):
                ax[-1, j].set_xlabel(ir)
            fig.suptitle(f"fraction={f}, aggregation={ag}")
            nfile = f'metrics_{ag}_{f}.png'
            list_figures[nfile] = fig
            # save through the figure itself rather than pyplot's implicit
            # "current figure"
            fig.savefig(os.path.join(output_folder, nfile), dpi=300,
                        bbox_inches='tight')
    return list_figures



def plot_crossval_stats(stats, output_folder):
"""
Expand Down
59 changes: 32 additions & 27 deletions rainforest/ml/default_config.yml
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
n_jobs: 4
RFO:
FILTERING: # conditions to remove some observations
STA_TO_REMOVE : ['TIT','GSB','GRH','PIL','SAE','AUB']
CONSTRAINT_MIN_ZH : [0.5,5] # min 5 dBZ if R > 0.5 mm/h
CONSTRAINT_MAX_ZH : [0,20] # max 20 dBZ if R = 0 mm/h
RANDOMFOREST_REGRESSOR: # parameters to sklearn's class
max_depth : 20
n_estimators : 15
max_features : 7
VERT_AGG:
BETA : -0.5 # weighting factor to use in the exponential weighting
VISIB_WEIGHTING : 1 # whether to weigh or not observations by their visib
BIAS_CORR : 'spline' # type of bias correction 'raw', 'cdf' or 'spline'
RFO_hpol:
FILTERING: # conditions to remove some observations
STA_TO_REMOVE : ['TIT','GSB','GRH','PIL','SAE','AUB']
CONSTRAINT_MIN_ZH : [0.5,5] # min 5 dBZ if R > 0.5 mm/h
CONSTRAINT_MAX_ZH : [0,20] # max 20 dBZ if R = 0 mm/h
RANDOMFOREST_REGRESSOR: # parameters to sklearn's class
max_depth : 20
n_estimators : 15
max_features : 7
VERT_AGG:
BETA : -0.5 # weighting factor to use in the exponential weighting
VISIB_WEIGHTING : 1 # whether to weigh or not observations by their visib
BIAS_CORR : 'spline' # type of bias correction 'raw', 'cdf' or 'spline'
PARAMETERS:
n_jobs: 4
cv_precip_bounds: [0,2,10,100] # precip bounds used in cross-validation
MLFLOW:
cv_scores_to_log: ['60min,test,all,all,RMSE', '60min,test,all,all,ED', '60min,test,all,all,logBias', '60min,test,all,all,scatter']
MODELS:
RFO:
FILTERING: # conditions to remove some observations
sta_to_remove : ['TIT','GSB','GRH','PIL','SAE','AUB']
constraint_min_zh : [0.5,5] # min 5 dBZ if R > 0.5 mm/h
constraint_max_zh : [0,20] # max 20 dBZ if R = 0 mm/h
RANDOMFOREST_REGRESSOR: # parameters to sklearn's class
max_depth : 20
n_estimators : 15
max_features : 7
VERT_AGG:
beta : -0.5 # weighting factor to use in the exponential weighting
visib_weighting : 1 # whether or not to weight observations by their visibility
bias_corr : 'spline' # type of bias correction 'raw', 'cdf' or 'spline'
RFO_hpol:
FILTERING: # conditions to remove some observations
sta_to_remove : ['TIT','GSB','GRH','PIL','SAE','AUB']
constraint_min_zh : [0.5,5] # min 5 dBZ if R > 0.5 mm/h
constraint_max_zh : [0,20] # max 20 dBZ if R = 0 mm/h
RANDOMFOREST_REGRESSOR: # parameters to sklearn's class
max_depth : 20
n_estimators : 15
max_features : 7
VERT_AGG:
beta : -0.5 # weighting factor to use in the exponential weighting
visib_weighting : 1 # whether or not to weight observations by their visibility
bias_corr : 'spline' # type of bias correction 'raw', 'cdf' or 'spline'
Loading

0 comments on commit b42330b

Please sign in to comment.