fix errors from _get_sampler()
OuyangWenyu committed Nov 4, 2024
1 parent 1fc92ff commit 5b4da2b
Showing 4 changed files with 111 additions and 117 deletions.
93 changes: 92 additions & 1 deletion tests/conftest.py
@@ -1,7 +1,8 @@
import os
import pytest

from torchhydro.configs.config import cmd, default_config_file
from tests.test_seq2seq import gage_id
from torchhydro.configs.config import cmd, default_config_file, update_cfg
from torchhydro import SETTING
import logging
import pandas as pd
@@ -454,3 +455,93 @@ def dpl_args():
opt="Adadelta",
which_first_tensor="sequence",
)


@pytest.fixture()
def seq2seq_config():
    project_name = os.path.join("train_with_gpm", "ex_test")
    config_data = default_config_file()
    args = cmd(
        sub=project_name,
        source_cfgs={
            "source": "HydroMean",
            "source_path": SETTING["local_data_path"]["datasets-interim"],
        },
        ctx=[0],
        model_name="Seq2Seq",
        model_hyperparam={
            "en_input_size": 17,
            "de_input_size": 18,
            "output_size": 2,
            "hidden_size": 256,
            "forecast_length": 56,
            "prec_window": 1,
            "teacher_forcing_ratio": 0.5,
        },
        model_loader={"load_way": "best"},
        gage_id=gage_id,
        # gage_id=["21400800", "21401550", "21401300", "21401900"],
        batch_size=128,
        forecast_history=240,
        forecast_length=56,
        min_time_unit="h",
        min_time_interval=3,
        var_t=[
            "precipitationCal",
            "sm_surface",
        ],
        var_c=[
            "area",  # basin area
            "ele_mt_smn",  # elevation (spatial mean)
            "slp_dg_sav",  # terrain slope (spatial mean)
            "sgr_dk_sav",  # stream gradient (mean)
            "for_pc_sse",  # forest cover fraction
            "glc_cl_smj",  # land cover class
            "run_mm_syr",  # land surface runoff (spatial mean of basin runoff)
            "inu_pc_slt",  # inundation extent (long-term maximum)
            "cmi_ix_syr",  # climate moisture index
            "aet_mm_syr",  # actual evapotranspiration (annual mean)
            "snw_pc_syr",  # snow cover extent (annual mean)
            "swc_pc_syr",  # soil water content
            "gwt_cm_sav",  # groundwater table depth
            "cly_pc_sav",  # clay, silt, and sand content of the soil
            "dor_pc_pva",  # degree of regulation
        ],
        var_out=["streamflow", "sm_surface"],
        dataset="Seq2SeqDataset",
        scaler="DapengScaler",
        train_epoch=2,
        save_epoch=1,
        train_mode=True,
        train_period=["2016-06-01-01", "2016-08-01-01"],
        test_period=["2015-06-01-01", "2015-08-01-01"],
        valid_period=["2015-06-01-01", "2015-08-01-01"],
        loss_func="MultiOutLoss",
        loss_param={
            "loss_funcs": "RMSESum",
            "data_gap": [0, 0],
            "device": [0],
            "item_weight": [0.8, 0.2],
        },
        opt="Adam",
        lr_scheduler={
            "lr": 0.0001,
            "lr_factor": 0.9,
        },
        which_first_tensor="batch",
        rolling=False,
        long_seq_pred=False,
        calc_metrics=False,
        early_stopping=True,
        # ensemble=True,
        # ensemble_items={
        #     "batch_sizes": [256, 512],
        # },
        patience=10,
        model_type="MTL",
    )

    # update the config data
    update_cfg(config_data, args)

    return config_data
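For reference, a minimal sketch of how a test module can pick up this shared fixture now that it lives in tests/conftest.py; pytest injects it by name, so no import of the fixture itself is needed. The test name and the train_and_evaluate call below are assumptions for illustration, not part of this commit.

# Illustrative usage only; pytest discovers seq2seq_config from tests/conftest.py.
from torchhydro.trainers.trainer import train_and_evaluate


def test_seq2seq_with_shared_config(seq2seq_config):
    # the fixture returns the fully updated config dict built above
    train_and_evaluate(seq2seq_config)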
6 changes: 3 additions & 3 deletions tests/test_deep_hydro.py
@@ -207,23 +207,23 @@ def test_get_scheduler_lambda_lr_with_epochs_show_lr(deep_hydro, dummy_train_cfg


def test_get_sampler_basin_batch_sampler(deep_hydro, dummy_train_cfgs):
    dummy_train_cfgs["data_cfgs"]["sampler"] = {"name": "BasinBatchSampler"}
    dummy_train_cfgs["data_cfgs"]["sampler"] = "BasinBatchSampler"
    sampler = deep_hydro._get_sampler(
        dummy_train_cfgs["data_cfgs"], deep_hydro.traindataset
    )
    assert isinstance(sampler, BasinBatchSampler)


def test_get_sampler_kuai_sampler(deep_hydro, dummy_train_cfgs):
    dummy_train_cfgs["data_cfgs"]["sampler"] = {"name": "KuaiSampler"}
    dummy_train_cfgs["data_cfgs"]["sampler"] = "KuaiSampler"
    sampler = deep_hydro._get_sampler(
        dummy_train_cfgs["data_cfgs"], deep_hydro.traindataset
    )
    assert isinstance(sampler, KuaiSampler)


def test_get_sampler_invalid_sampler(deep_hydro, dummy_train_cfgs):
    dummy_train_cfgs["data_cfgs"]["sampler"] = {"name": "InvalidSampler"}
    dummy_train_cfgs["data_cfgs"]["sampler"] = "InvalidSampler"
    with pytest.raises(
        NotImplementedError, match="Sampler InvalidSampler not implemented yet"
    ):
98 changes: 2 additions & 96 deletions tests/test_seq2seq.py
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:55:24
LastEditTime: 2024-11-02 21:47:38
LastEditTime: 2024-11-04 18:33:00
LastEditors: Wenyu Ouyang
Description: Test funcs for seq2seq model
FilePath: /torchhydro/tests/test_seq2seq.py
FilePath: \torchhydro\tests\test_seq2seq.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""

@@ -13,16 +13,13 @@
from torchhydro.models.seq2seq import GeneralSeq2Seq

import logging
import os.path
import pathlib

import pandas as pd
import hydrodatasource.configs.config as hdscc
import xarray as xr
import torch.multiprocessing as mp

from torchhydro import SETTING
from torchhydro.configs.config import cmd, default_config_file, update_cfg
from torchhydro.trainers.deep_hydro import train_worker
from torchhydro.trainers.trainer import train_and_evaluate

Expand All @@ -39,97 +36,6 @@
]


@pytest.fixture()
def seq2seq_config():
    project_name = os.path.join("train_with_gpm", "ex_test")
    config_data = default_config_file()
    args = cmd(
        sub=project_name,
        source_cfgs={
            "source": "HydroMean",
            "source_path": SETTING["local_data_path"]["datasets-interim"],
        },
        ctx=[0],
        model_name="Seq2Seq",
        model_hyperparam={
            "en_input_size": 17,
            "de_input_size": 18,
            "output_size": 2,
            "hidden_size": 256,
            "forecast_length": 56,
            "prec_window": 1,
            "teacher_forcing_ratio": 0.5,
        },
        model_loader={"load_way": "best"},
        gage_id=gage_id,
        # gage_id=["21400800", "21401550", "21401300", "21401900"],
        batch_size=128,
        forecast_history=240,
        forecast_length=56,
        min_time_unit="h",
        min_time_interval=3,
        var_t=[
            "precipitationCal",
            "sm_surface",
        ],
        var_c=[
            "area",  # basin area
            "ele_mt_smn",  # elevation (spatial mean)
            "slp_dg_sav",  # terrain slope (spatial mean)
            "sgr_dk_sav",  # stream gradient (mean)
            "for_pc_sse",  # forest cover fraction
            "glc_cl_smj",  # land cover class
            "run_mm_syr",  # land surface runoff (spatial mean of basin runoff)
            "inu_pc_slt",  # inundation extent (long-term maximum)
            "cmi_ix_syr",  # climate moisture index
            "aet_mm_syr",  # actual evapotranspiration (annual mean)
            "snw_pc_syr",  # snow cover extent (annual mean)
            "swc_pc_syr",  # soil water content
            "gwt_cm_sav",  # groundwater table depth
            "cly_pc_sav",  # clay, silt, and sand content of the soil
            "dor_pc_pva",  # degree of regulation
        ],
        var_out=["streamflow", "sm_surface"],
        dataset="Seq2SeqDataset",
        sampler="BasinBatchSampler",
        scaler="DapengScaler",
        train_epoch=2,
        save_epoch=1,
        train_mode=True,
        train_period=["2016-06-01-01", "2016-08-01-01"],
        test_period=["2015-06-01-01", "2015-08-01-01"],
        valid_period=["2015-06-01-01", "2015-08-01-01"],
        loss_func="MultiOutLoss",
        loss_param={
            "loss_funcs": "RMSESum",
            "data_gap": [0, 0],
            "device": [0],
            "item_weight": [0.8, 0.2],
        },
        opt="Adam",
        lr_scheduler={
            "lr": 0.0001,
            "lr_factor": 0.9,
        },
        which_first_tensor="batch",
        rolling=False,
        long_seq_pred=False,
        calc_metrics=False,
        early_stopping=True,
        # ensemble=True,
        # ensemble_items={
        #     "batch_sizes": [256, 512],
        # },
        patience=10,
        model_type="MTL",
    )

    # update the default config
    update_cfg(config_data, args)

    return config_data


def test_seq2seq(seq2seq_config):
    # world_size = len(config["training_cfgs"]["device"])
    # mp.spawn(train_worker, args=(world_size, config), nprocs=world_size, join=True)
31 changes: 14 additions & 17 deletions torchhydro/trainers/deep_hydro.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:15:48
LastEditTime: 2024-11-04 17:52:16
LastEditTime: 2024-11-04 18:40:05
LastEditors: Wenyu Ouyang
Description: HydroDL model class
FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py
@@ -494,9 +494,9 @@ def _get_dataloader(self, training_cfgs, data_cfgs):
print(f"Pin memory set to {str(pin_memory)}")
sampler = None
if data_cfgs["sampler"] is not None:
sampler = self._get_sampler(data_cfgs, self.train_dataset)
sampler = self._get_sampler(data_cfgs, self.traindataset)
data_loader = DataLoader(
self.train_dataset,
self.traindataset,
batch_size=training_cfgs["batch_size"],
shuffle=(sampler is None),
sampler=sampler,
@@ -505,10 +505,9 @@ def _get_dataloader(self, training_cfgs, data_cfgs):
            timeout=0,
        )
        if data_cfgs["t_range_valid"] is not None:
            batch_size_valid = training_cfgs["batch_size"]
            validation_data_loader = DataLoader(
                self.valid_dataset,
                batch_size=batch_size_valid,
                self.validdataset,
                batch_size=training_cfgs["batch_size"],
                shuffle=False,
                num_workers=worker_num,
                pin_memory=pin_memory,
@@ -554,21 +553,19 @@ def _get_sampler(self, data_cfgs, train_dataset):
        horizon = data_cfgs["forecast_length"]
        ngrid = train_dataset.ngrid
        nt = train_dataset.nt
        sampler_name = data_cfgs["sampler"]["name"]
        sampler_name = data_cfgs["sampler"]
        if sampler_name not in data_sampler_dict:
            raise NotImplementedError(f"Sampler {sampler_name} not implemented yet")
        sampler_class = data_sampler_dict[sampler_name]
        sampler_hyperparam = data_cfgs["sampler"].get("sampler_hyperparam", {})
        sampler_hyperparam = {}
        if sampler_name == "KuaiSampler":
            sampler_hyperparam.update(
                {
                    "batch_size": batch_size,
                    "warmup_length": warmup_length,
                    "rho_horizon": rho + horizon,
                    "ngrid": ngrid,
                    "nt": nt,
                }
            )
            sampler_hyperparam |= {
                "batch_size": batch_size,
                "warmup_length": warmup_length,
                "rho_horizon": rho + horizon,
                "ngrid": ngrid,
                "nt": nt,
            }
        return sampler_class(train_dataset, **sampler_hyperparam)
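For context, a self-contained sketch of the string-keyed lookup that _get_sampler now performs; the stand-in registry, classes, and default arguments below are assumptions for illustration, not torchhydro's actual data_sampler_dict. Note that the in-place dict merge (|=) used here and in the method above requires Python 3.9 or newer.

# Illustrative sketch only; stand-ins for torchhydro's data_sampler_dict and sampler classes.
class BasinBatchSampler:
    def __init__(self, dataset, **kwargs):
        self.dataset, self.kwargs = dataset, kwargs


class KuaiSampler:
    def __init__(self, dataset, **kwargs):
        self.dataset, self.kwargs = dataset, kwargs


SAMPLER_REGISTRY = {"BasinBatchSampler": BasinBatchSampler, "KuaiSampler": KuaiSampler}


def resolve_sampler(data_cfgs, train_dataset, batch_size=128, warmup_length=0,
                    rho_horizon=240 + 56, ngrid=1, nt=1):
    sampler_name = data_cfgs["sampler"]  # a plain string now, not {"name": ...}
    if sampler_name not in SAMPLER_REGISTRY:
        raise NotImplementedError(f"Sampler {sampler_name} not implemented yet")
    sampler_hyperparam = {}
    if sampler_name == "KuaiSampler":
        # only KuaiSampler needs the extra hyperparameters
        sampler_hyperparam |= {
            "batch_size": batch_size,
            "warmup_length": warmup_length,
            "rho_horizon": rho_horizon,
            "ngrid": ngrid,
            "nt": nt,
        }
    return SAMPLER_REGISTRY[sampler_name](train_dataset, **sampler_hyperparam)


sampler = resolve_sampler({"sampler": "BasinBatchSampler"}, train_dataset=[0, 1, 2])
# resolve_sampler({"sampler": "InvalidSampler"}, [])  # raises NotImplementedError

Passing the sampler name as a plain string keeps the config flat and matches the updated tests, which now set data_cfgs["sampler"] = "KuaiSampler" directly.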


