Skip to content

Commit

Permalink
fix a logger bug for continue_train
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Oct 29, 2024
1 parent ea96ed1 commit 4e502d5
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
27 changes: 26 additions & 1 deletion tests/test_train_camels_lstm.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
"""
Author: Wenyu Ouyang
Date: 2023-07-25 16:47:19
LastEditTime: 2024-04-10 21:00:10
LastEditTime: 2024-10-29 14:29:55
LastEditors: Wenyu Ouyang
Description: Test a full training and evaluating process
FilePath: \torchhydro\tests\test_train_camels_lstm.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""

import os
from torchhydro.configs.config import update_cfg
from torchhydro.trainers.trainer import train_and_evaluate


def test_train_evaluate(args, config_data):
    """Run one full training and evaluating process

    Parameters
    ----------
    args
        test arguments that are merged into ``config_data`` by ``update_cfg``
    config_data
        the config passed on to ``train_and_evaluate``
    """
    # update_cfg is called for its effect on config_data; its return value
    # is not used here
    update_cfg(config_data, args)
    train_and_evaluate(config_data)


def test_train_evaluate_continue(args, config_data):
    """Test training and evaluation when resuming from saved weights

    ``continue_train`` is set to 1, ``start_epoch`` to 2 and ``train_mode``
    to 1, and the weight file is pointed at ``model_Ep1.pth`` inside the
    configured ``test_path``. This pattern is useful for picking a training
    run back up after it has been interrupted.

    NOTE(review): ``model_Ep1.pth`` presumably has to exist already (e.g.
    left behind by an earlier training run) -- confirm how it is produced.

    Parameters
    ----------
    args
        basic args in conftest.py
    config_data
        default config data
    """
    # switch the run into "continue training" mode before merging the args
    args.continue_train = 1
    args.start_epoch = 2
    args.train_mode = 1
    update_cfg(config_data, args)

    # weight_path is set after update_cfg, directly on the merged config
    checkpoint_dir = config_data["data_cfgs"]["test_path"]
    checkpoint_file = os.path.join(checkpoint_dir, "model_Ep1.pth")
    config_data["model_cfgs"]["weight_path"] = checkpoint_file

    train_and_evaluate(config_data)
20 changes: 19 additions & 1 deletion torchhydro/trainers/train_logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2024-09-18 15:40:10
LastEditTime: 2024-10-29 16:07:08
LastEditors: Wenyu Ouyang
Description: Training function for DL models
FilePath: \torchhydro\torchhydro\trainers\train_logger.py
Expand All @@ -18,6 +18,8 @@
import torch
from torch.utils.tensorboard import SummaryWriter

from torchhydro.trainers.train_utils import get_lastest_logger_file_in_a_dir


def save_model(model, model_file, gpu_num=1):
try:
Expand Down Expand Up @@ -48,6 +50,22 @@ def __init__(self, model_filepath, params, opt):
self.train_time = []
# log loss for each epoch
self.epoch_loss = []
# reload previous logs if continue_train is True and weight_path is not None
if (
self.model_cfgs["continue_train"]
and self.model_cfgs["weight_path"] is not None
):
the_logger_file = get_lastest_logger_file_in_a_dir(self.training_save_dir)
if the_logger_file is not None:
with open(the_logger_file, "r") as f:
logs = json.load(f)
start_epoch = self.training_cfgs["start_epoch"]
# read the logs before start_epoch and load them to session_params, train_time, epoch_loss
for log in logs["run"]:
if log["epoch"] < start_epoch:
self.session_params.append(log)
self.train_time.append(log["train_time"])
self.epoch_loss.append(float(log["train_loss"]))

def save_session_param(
self, epoch, total_loss, n_iter_ep, valid_loss=None, valid_metrics=None
Expand Down
30 changes: 28 additions & 2 deletions torchhydro/trainers/train_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:16:26
LastEditTime: 2024-09-18 10:22:54
LastEditTime: 2024-10-29 15:47:51
LastEditors: Wenyu Ouyang
Description: Some basic functions for training
FilePath: \torchhydro\torchhydro\trainers\train_utils.py
Expand All @@ -25,7 +25,11 @@
from torch.utils.data import DataLoader

from hydroutils.hydro_stat import stat_error
from hydroutils.hydro_file import get_lastest_file_in_a_dir, unserialize_json
from hydroutils.hydro_file import (
get_lastest_file_in_a_dir,
unserialize_json,
get_latest_file_in_a_lst,
)

from torchhydro.models.crits import GaussianLoss

Expand Down Expand Up @@ -576,6 +580,28 @@ def read_pth_from_model_loader(model_loader, model_pth_dir):
return weight_path


def get_lastest_logger_file_in_a_dir(dir_path):
    """Get the latest logger file in a directory

    Only entries directly inside ``dir_path`` are inspected, and only those
    whose file name fully matches the logger naming pattern
    ``<1-2 digits>_<letters>_<6 digits>_<2 digits><AM|PM>.json``.

    Parameters
    ----------
    dir_path : str
        the directory to search

    Returns
    -------
    str or None
        the path of the logger file picked by ``get_latest_file_in_a_lst``;
        None when no file name in the directory matches the pattern
    """
    # compile once; the same pattern is applied to every directory entry
    pattern = re.compile(r"^\d{1,2}_[A-Za-z]+_\d{6}_\d{2}(AM|PM)\.json$")
    logger_files_lst = [
        os.path.join(dir_path, file)
        for file in os.listdir(dir_path)
        if pattern.match(file)
    ]
    if not logger_files_lst:
        # No logger file yet (e.g. a fresh output directory). Return None
        # explicitly so callers can test the result against None, instead of
        # handing an empty list to get_latest_file_in_a_lst.
        return None
    return get_latest_file_in_a_lst(logger_files_lst)


def cellstates_when_inference(seq_first, data_cfgs, pred):
"""get cell states when inference"""
cs_out = (
Expand Down

0 comments on commit 4e502d5

Please sign in to comment.