Skip to content

Commit

Permalink
added tests for objfuncs and changed calling method
Browse files Browse the repository at this point in the history
  • Loading branch information
richardarsenault committed Dec 29, 2023
1 parent fbe6b81 commit 6460a88
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 25 deletions.
68 changes: 68 additions & 0 deletions tests/test_objective_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Test suite for the objective functions in obj_funcs.py."""

import numpy as np

from xhydro.modelling.obj_funcs import get_objective_function


def test_obj_funcs():
"""
Series of tests to test all objective functions with fast test data
"""
Qobs = np.array([120, 130, 140, 150, 160, 170])
Qsim = np.array([120, 125, 145, 140, 140, 180])

# Test that the objective function is calculated correctly
objfun = get_objective_function(Qobs, Qsim, obj_func="abs_bias")
np.testing.assert_array_almost_equal(objfun, 3.3333333333333335, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="abs_pbias")
np.testing.assert_array_almost_equal(objfun, 2.2988505747126435, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="abs_volume_error")
np.testing.assert_array_almost_equal(objfun, 0.022988505747126436, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="agreement_index")
np.testing.assert_array_almost_equal(objfun, 0.9171974522292994, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="bias")
np.testing.assert_array_almost_equal(objfun, -3.3333333333333335, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="correlation_coeff")
np.testing.assert_array_almost_equal(objfun, 0.8599102447336393, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="kge")
np.testing.assert_array_almost_equal(objfun, 0.8077187696552522, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="kge_mod")
np.testing.assert_array_almost_equal(objfun, 0.7888769531580001, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="mae")
np.testing.assert_array_almost_equal(objfun, 8.333333333333334, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="mare")
np.testing.assert_array_almost_equal(objfun, 0.05747126436781609, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="mse")
np.testing.assert_array_almost_equal(objfun, 108.33333333333333, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="nse")
np.testing.assert_array_almost_equal(objfun, 0.6285714285714286, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="pbias")
np.testing.assert_array_almost_equal(objfun, -2.2988505747126435, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="r2")
np.testing.assert_array_almost_equal(objfun, 0.7394456289978675, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="rmse")
np.testing.assert_array_almost_equal(objfun, 10.408329997330663, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="rrmse")
np.testing.assert_array_almost_equal(objfun, 0.07178158618848733, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="rsr")
np.testing.assert_array_almost_equal(objfun, 0.6094494002200439, 8)

objfun = get_objective_function(Qobs, Qsim, obj_func="volume_error")
np.testing.assert_array_almost_equal(objfun, -0.022988505747126436, 8)
48 changes: 23 additions & 25 deletions xhydro/modelling/obj_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,26 +128,26 @@ def get_objective_function(
function (obj_fun).
"""
# List of available objective functions
obj_func_list = [
"abs_bias",
"abs_pbias",
"abs_volume_error",
"agreement_index",
"bias",
"correlation_coeff",
"kge",
"kge_mod",
"mae",
"mare",
"mse",
"nse",
"pbias",
"r2",
"rmse",
"rrmse",
"rsr",
"volume_error",
]
obj_func_dict = {
"abs_bias": abs_bias,
"abs_pbias": abs_pbias,
"abs_volume_error": abs_volume_error,
"agreement_index": agreement_index,
"bias": bias,
"correlation_coeff": correlation_coeff,
"kge": kge,
"kge_mod": kge_mod,
"mae": mae,
"mare": mare,
"mse": mse,
"nse": nse,
"pbias": pbias,
"r2": r2,
"rmse": rmse,
"rrmse": rrmse,
"rsr": rsr,
"volume_error": volume_error,
}

# Basic error checking
if Qobs.shape[0] != Qsim.shape[0]:
Expand All @@ -164,7 +164,7 @@ def get_objective_function(
sys.exit("Mask contains values other 0 or 1. Please modify.")

# Check that the objective function is in the list of available methods
if obj_func not in obj_func_list:
if obj_func not in obj_func_dict:
sys.exit(
"Selected objective function is currently unavailable. "
+ "Consider contributing to our project at: "
Expand All @@ -189,16 +189,14 @@ def get_objective_function(

# Compute objective function by switching to the correct algorithm. Ensure
# that the function name is the same as the obj_func tag or this will fail.
function_call = globals()[obj_func]
function_call = obj_func_dict[obj_func]
obj_fun_val = function_call(Qsim, Qobs)

# Take the negative value of the Objective Function to return to the
# optimizer.
if take_negative:
obj_fun_val = obj_fun_val * -1

print(obj_fun_val)

return obj_fun_val


Expand Down Expand Up @@ -384,7 +382,7 @@ def abs_pbias(Qsim, Qobs):
The abs_pbias should be MINIMIZED.
"""
return np.abs(bias(Qsim, Qobs))
return np.abs(pbias(Qsim, Qobs))


def abs_volume_error(Qsim, Qobs):
Expand Down

0 comments on commit 6460a88

Please sign in to comment.