Skip to content

Commit

Permalink
fix issues from review
Browse files Browse the repository at this point in the history
  • Loading branch information
richardarsenault committed Jan 27, 2024
1 parent fd30354 commit ad4446f
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 16 deletions.
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- spotpy
- statsmodels
- xarray
- xclim >=0.45.0
- xclim >=0.47.0
- xscen >=0.7.1
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- spotpy
- statsmodels
- xarray
- xclim >=0.45.0
- xclim >=0.47.0
- xscen >=0.7.1
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"spotpy",
"statsmodels",
"xarray",
"xclim>=0.45.0",
"xclim>=0.47.0",
"xdatasets>=0.3.1",
"xscen>=0.7.1"
]
Expand Down
60 changes: 48 additions & 12 deletions xhydro/modelling/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@
Any comments are welcome!
"""

from copy import deepcopy

# Import packages
from typing import Optional

import numpy as np
import spotpy
import xarray as xr
Expand Down Expand Up @@ -135,7 +139,7 @@ def __init__(
take_negative: bool = False,
mask: np.array = None,
transform: str = None,
epsilon: float = None,
epsilon: float = 0.01,
):
"""
Initialize the SpotSetup object.
Expand Down Expand Up @@ -181,11 +185,11 @@ def __init__(
algorithm : str
The optimization algorithm to use. Currently, "DDS" and "SCEUA" are available, but more can be easily added.
take_negative : bool
Inidactor to take the negative of the objective function value in optimization to ensure convergence
Wether to take the negative of the objective function value in optimization to ensure convergence
in the right direction.
mask : np.array
mask : np.array, optional
A vector indicating which values to preserve/remove from the objective function computation. 0=remove, 1=preserve.
transform : str
transform : str, optional
The method to transform streamflow prior to computing the objective function. Can be one of:
Square root ('sqrt'), inverse ('inv'), or logarithmic ('log') transformation.
epsilon : float
Expand All @@ -199,7 +203,7 @@ def __init__(
"""
# Gather the model_config dictionary and obj_func string, and other
# optional arguments.
self.model_config = model_config
self.model_config = deepcopy(model_config)
self.obj_func = obj_func
self.mask = mask
self.transform = transform
Expand Down Expand Up @@ -305,9 +309,10 @@ def perform_calibration(
bounds_low: np.array,
evaluations: int,
algorithm: str = "DDS",
mask: np.array = None,
transform: str = None,
mask: Optional[np.array] = None,
transform: Optional[str] = None,
epsilon: float = 0.01,
sampler_kwargs: Optional[dict] = None,
):
"""Perform calibration using spotpy.
Expand Down Expand Up @@ -351,15 +356,21 @@ def perform_calibration(
Maximum number of model evaluations (calibration budget) to perform before stopping the calibration process.
algorithm : str
The optimization algorithm to use. Currently, "DDS" and "SCEUA" are available, but more can be easily added.
mask : np.array
mask : np.array, optional
A vector indicating which values to preserve/remove from the objective function computation. 0=remove, 1=preserve.
transform : str
transform : str, optional
The method to transform streamflow prior to computing the objective function. Can be one of:
Square root ('sqrt'), inverse ('inv'), or logarithmic ('log') transformation.
epsilon : scalar float
Used to add a small delta to observations for log and inverse transforms, to eliminate errors
caused by zero flow days (1/0 and log(0)). The added perturbation is equal to the mean observed streamflow
times this value of epsilon.
sampler_kwargs : dict
Contains the keywords and hyperparameter values for the optimization algorithm.
Keywords depend on the algorithm choice. Currently, SCEUA and DDS are supported with
the following default values:
- SCEUA: dict(ngs=7, kstop=3, peps=0.1, pcento=0.1)
- DDS: dict(trials=1)
Returns
-------
Expand All @@ -370,7 +381,7 @@ def perform_calibration(
bestobjf : float
The best objective function value.
"""
# Get objective function and algo optimal convregence direction. Necessary
# Get objective function and algo optimal convergence direction. Necessary
# to ensure that the algorithm is optimizing in the correct direction
# (maximizing or minimizing). This code determines the required direction
# for the objective function and the working direction of the algorithm.
Expand Down Expand Up @@ -401,13 +412,38 @@ def perform_calibration(
sampler = spotpy.algorithms.dds(
spotpy_setup, dbname="DDS_optim", dbformat="ram", save_sim=False
)
sampler.sample(evaluations, trials=1)

# If the user provided a custom sampler hyperparameter set.
if sampler_kwargs is not None:
if "trials" in sampler_kwargs:
sampler.sample(evaluations, *sampler_kwargs)
else:
raise ValueError(
'DDS optimizer hyperparameter keyword "trials" not found in sampler_kwargs.'
)

# If not, use the default.
else:
sampler.sample(evaluations, trials=1)

elif algorithm == "SCEUA":
sampler = spotpy.algorithms.sceua(
spotpy_setup, dbname="SCEUA_optim", dbformat="ram", save_sim=False
)
sampler.sample(evaluations, ngs=7, kstop=3, peps=0.1, pcento=0.1)

# If the user provided a custom sampler hyperparameter set.
if sampler_kwargs is not None:
if all(
item in sampler_kwargs for item in ["ngs", "kstop", "peps", "pcento"]
):
sampler.sample(evaluations, *sampler_kwargs)
else:
raise ValueError(
'SCEUA optimizer hyperparameter keywords "ngs", "kstop", "peps" or " pcento" not found in sampler_kwargs.'
)

else:
sampler.sample(evaluations, ngs=7, kstop=3, peps=0.1, pcento=0.1)

# Gather optimization results
results = sampler.getdata()
Expand Down
3 changes: 2 additions & 1 deletion xhydro/modelling/obj_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def get_objective_function(
}

# If we got a dataset, change to np.array
# FIXME: Implement a more flexible method
if isinstance(qsim, xr.Dataset):
qsim = qsim["qsim"]

Expand All @@ -165,7 +166,7 @@ def get_objective_function(

# All zero or one?
if not np.setdiff1d(np.unique(mask), np.array([0, 1])).size == 0:
raise ValueError("Mask contains values other 0 or 1. Please modify.")
raise ValueError("Mask contains values other than 0 or 1. Please modify.")

# Check that the objective function is in the list of available methods
if obj_func not in obj_func_dict:
Expand Down

0 comments on commit ad4446f

Please sign in to comment.