Skip to content

Commit

Permalink
can train caravan
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Oct 19, 2023
1 parent e80236a commit 49a1d48
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
11 changes: 10 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
Author: Wenyu Ouyang
Date: 2023-07-31 08:40:43
LastEditTime: 2023-09-25 08:22:43
LastEditTime: 2023-10-19 08:37:56
LastEditors: Wenyu Ouyang
Description: Test some functions for dataset
FilePath: /torchhydro/tests/test_data.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""
import pytest
import hydrodataset as hds
from hydrodataset.caravan import Caravan
from torch.utils.data import Dataset

from datasets.sampler import KuaiSampler
Expand Down Expand Up @@ -66,3 +67,11 @@ def test_cache_file():
"""
camels_us = hds.Camels()
camels_us.cache_xrdataset()


def test_cache_caravan():
"""
Test whether the cache file is generated correctly
"""
caravan = Caravan(hds.ROOT_DIR.joinpath("caravan"))
caravan.cache_xrdataset()
8 changes: 3 additions & 5 deletions torchhydro/models/cudnnlstm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: MHPI group, Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2023-10-10 20:28:18
LastEditTime: 2023-10-11 11:38:48
LastEditors: Wenyu Ouyang
Description: LSTM with dropout implemented by Kuai Fang and more LSTMs using it
FilePath: \torchhydro\torchhydro\models\cudnnlstm.py
FilePath: /torchhydro/torchhydro/models/cudnnlstm.py
Copyright (c) 2021-2022 MHPI group, Wenyu Ouyang. All rights reserved.
"""

Expand Down Expand Up @@ -121,9 +121,7 @@ def forward(self, x, hidden, *, do_reset_mask=True, do_drop_mc=False):


class CpuLstmModel(nn.Module):
"""
Cpu version of CudnnLstmModel , ,
"""
"""Cpu version of CudnnLstmModel"""

def __init__(self, *, n_input_features, n_output_features, n_hidden_states, dr=0.5):
super(CpuLstmModel, self).__init__()
Expand Down

0 comments on commit 49a1d48

Please sign in to comment.