Skip to content

Commit

Permalink
selfmadecamels
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Nov 19, 2023
1 parent 69db841 commit 46a8975
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 59 deletions.
36 changes: 19 additions & 17 deletions tests/test_tl_opendata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2023-10-05 16:16:48
LastEditTime: 2023-10-20 19:59:38
LastEditTime: 2023-11-19 22:01:36
LastEditors: Wenyu Ouyang
Description: Transfer learning for local basins with hydro_opendata
FilePath: \torchhydro\tests\test_tl_opendata.py
Expand All @@ -18,12 +18,12 @@
@pytest.fixture()
def var_c_target():
return [
"elev_mean",
"slope_mean",
"area_gages2",
"frac_forest",
"lai_max",
"lai_diff",
"p_mean",
"pet_mean",
"Area",
"geol_class_1st",
"elev",
"SNDPPT",
]


Expand Down Expand Up @@ -52,7 +52,8 @@ def var_c_source():

@pytest.fixture()
def var_t_target():
return ["dayl", "prcp", "srad"]
# mainly from ERA5LAND
return ["total_precipitation", "potential_evaporation", "temperature_2m"]


@pytest.fixture()
Expand All @@ -70,12 +71,14 @@ def test_transfer_gages_lstm_model(
"exp1",
)
weight_path = get_lastest_file_in_a_dir(weight_dir)
project_name = "test_caravan/exp6"
project_name = "test_camels/exptl4cc"
args = cmd(
sub=project_name,
source="Caravan",
source_path=os.path.join(hds.ROOT_DIR, "caravan"),
source_region="Global",
source="SelfMadeCAMELS",
# cc means China continent
source_path=os.path.join(
hds.ROOT_DIR, "waterism", "datasets-interim", "camels_cc"
),
download=0,
ctx=[0],
model_type="TransLearn",
Expand All @@ -91,8 +94,8 @@ def test_transfer_gages_lstm_model(
batch_size=5,
rho=20,
rs=1234,
train_period=["2010-10-01", "2011-10-01"],
test_period=["2011-10-01", "2012-10-01"],
train_period=["2014-10-01", "2019-10-01"],
test_period=["2019-10-01", "2021-10-01"],
scaler="DapengScaler",
sampler="KuaiSampler",
dataset="StreamflowDataset",
Expand All @@ -108,9 +111,8 @@ def test_transfer_gages_lstm_model(
var_c=var_c_target,
var_out=["streamflow"],
gage_id=[
"01055000",
"01057000",
"01170100",
"61561",
"62618",
],
)
cfg = default_config_file()
Expand Down
4 changes: 3 additions & 1 deletion torchhydro/datasets/data_dict.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
Author: Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2023-10-17 14:00:32
LastEditTime: 2023-11-19 11:15:32
LastEditors: Wenyu Ouyang
Description: A dict used for data source and data loader
FilePath: \torchhydro\torchhydro\datasets\data_dict.py
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
"""
from hydrodataset import Camels
from hydrodataset.caravan import Caravan
from datasets.data_sources import SelfMadeCamels

from torchhydro.datasets.data_sets import (
BaseDataset,
Expand All @@ -19,6 +20,7 @@
data_sources_dict = {
"CAMELS": Camels,
"Caravan": Caravan,
"SelfMadeCAMELS": SelfMadeCamels
}

datasets_dict = {
Expand Down
Loading

0 comments on commit 46a8975

Please sign in to comment.