From 81359796274031f6f6f2e53eae3f5e3a627a9e1c Mon Sep 17 00:00:00 2001 From: richardarsenault Date: Tue, 23 Jan 2024 21:37:53 -0500 Subject: [PATCH] added model_config helper function and tests --- tests/test_hydrological_modelling.py | 21 +++++++++++++- xhydro/modelling/hydrological_modelling.py | 32 ++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/test_hydrological_modelling.py b/tests/test_hydrological_modelling.py index 1840d5f3..2bccb4a5 100644 --- a/tests/test_hydrological_modelling.py +++ b/tests/test_hydrological_modelling.py @@ -2,7 +2,10 @@ import numpy as np import pytest -from xhydro.modelling.hydrological_modelling import run_hydrological_model +from xhydro.modelling.hydrological_modelling import ( + get_hydrological_model_inputs, + run_hydrological_model, +) def test_hydrological_modelling(): @@ -31,3 +34,19 @@ def test_import_unknown_model(): model_config = {"model_name": "fake_model"} _ = run_hydrological_model(model_config) assert pytest_wrapped_e.type == NotImplementedError + + +def test_get_unknown_model_requirements(): + """Test for required inputs for models with unknown name""" + with pytest.raises(NotImplementedError) as pytest_wrapped_e: + model_name = "fake_model" + _ = get_hydrological_model_inputs(model_name) + assert pytest_wrapped_e.type == NotImplementedError + + +def test_get_model_requirements(): + """Test for required inputs for models""" + model_name = "Dummy" + required_config = get_hydrological_model_inputs(model_name) + print(required_config.keys()) + assert len(required_config.keys()) == 4 diff --git a/xhydro/modelling/hydrological_modelling.py b/xhydro/modelling/hydrological_modelling.py index be9af7d0..a4baab0b 100644 --- a/xhydro/modelling/hydrological_modelling.py +++ b/xhydro/modelling/hydrological_modelling.py @@ -74,6 +74,38 @@ def run_hydrological_model(model_config: dict): return qsim +def get_hydrological_model_inputs(model_name: str): + """Required hydrological model inputs for model_config objects. + + Parameters + ---------- + model_name : str + Model name that must be one of the models in the list of possible + models. + + Returns + ------- + dict + Elements that must be found in the model_config object. + """ + if model_name == "Dummy": + required_config = dict( + precip="Daily precipitation in mm.", + temperature="Daily average air temperature in °C", + drainage_area="Drainage area of the catchment", + parameters="Model parameters, length 3", + ) + + elif model_name == "ADD_OTHER_HERE": + # ADD OTHER MODELS HERE + required_config = {} + + else: + raise NotImplementedError(f"The model '{model_name}' is not recognized.") + + return required_config + + def _dummy_model(model_config: dict): """Dummy model.