Skip to content

Commit

Permalink
Entrypoint and format fix in logging and checkpointing (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris authored Dec 1, 2023
1 parent 707732e commit 578a2b5
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 9 deletions.
42 changes: 42 additions & 0 deletions configs/lstm_regressor_2048x1024.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
dataset:
name: 'NTURGBDDataset'
args:
missing_files_dir: 'data/missing'
label_file: 'data/labels.txt'
max_context_window: 10
max_number_of_bodies: 1
transforms:
name: 'MinMaxScaleTransform'
args:
feature_scale: [0.0, 1.0]
max_duration: 300
n_joints: 25

# Set the train data percentage
train_data_percentage: 0.8

model:
name: 'SimpleLSTMRegressor'
args:
hidden_size: 2048
num_layers: 2
linear_out: 1024
reduction: 'mean'
batch_first: true
n_joints: 25
n_dims: 3

runner:
args:
val_batch_size: 32
train_batch_size: 32
block_size: 8
device: 'cuda'
logger:
name: 'TensorboardLogger'
args:
save_dir: 'runs'
checkpoint_dir: '/home/kaseris/Documents/checkpoints_forecasting'
n_epochs: 10
lr: 0.00001
log_gradient_info: true
7 changes: 6 additions & 1 deletion src/skelcast/callbacks/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import torch

from skelcast.callbacks.callback import Callback
Expand Down Expand Up @@ -38,6 +40,9 @@ def save_checkpoint(self, runner, epoch: int):
- runner (Runner): The experiment runner instance
- epoch (int): The current epoch
"""
now = datetime.now()
formatted_time = now.strftime("%Y-%m-%d_%H%M%S")

checkpoint = {
'epoch': epoch,
'model_state_dict': runner.model.state_dict(),
Expand All @@ -48,6 +53,6 @@ def save_checkpoint(self, runner, epoch: int):
'validation_loss_per_step': runner.validation_loss_per_step
}

checkpoint_path = f'{self.checkpoint_dir}/checkpoint_epoch_{epoch}.pt'
checkpoint_path = f'{self.checkpoint_dir}/checkpoint_epoch_{epoch}_{formatted_time}.pt'
torch.save(checkpoint, checkpoint_path)

5 changes: 4 additions & 1 deletion src/skelcast/callbacks/console.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import time
from datetime import datetime

from skelcast.callbacks.callback import Callback

Expand Down Expand Up @@ -47,8 +48,10 @@ def on_epoch_end(self, epoch, epoch_loss, phase):
print()

def _print_status(self):
now = datetime.now()
now_formatted = now.strftime("[%Y-%m-%d %H:%M:%S]")
clear_line = '\r' + ' ' * 80 # Create a line of 80 spaces
message = f"Epoch: {self.current_epoch + 1}/{self.final_epoch}, Batch: {self.current_batch}/{self.total_batches}, Train Loss: {self.latest_train_loss}, Val Loss: {self.latest_val_loss}"
message = f"{now_formatted} Epoch: {self.current_epoch + 1}/{self.final_epoch}, Batch: {self.current_batch}/{self.total_batches}, Train Loss: {self.latest_train_loss}, Val Loss: {self.latest_val_loss}"

# First, print the clear_line to overwrite the previous output, then print your message
print(f'{clear_line}\r{message}', end='')
Expand Down
11 changes: 4 additions & 7 deletions src/skelcast/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,16 @@ def _build_runner(self) -> None:
_args['train_set'] = self._train_dataset
_args['val_set'] = self._val_dataset
_args['checkpoint_dir'] = os.path.join(self.checkpoint_dir, self._experiment_name)
self._create_checkpoint_dir()
self._runner = Runner(**_args)
self._runner.setup()
logging.log(logging.INFO, 'Runner setup complete.')

def _create_checkpoint_dir(self) -> None:
if os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)):
raise ValueError(f'Checkpoint directory {self.checkpoint_dir} already exists.')
raise ValueError(f'Checkpoint directory {os.path.join(self.checkpoint_dir, self._experiment_name)} already exists.')
else:
logging.log(logging.INFO, f'Creating checkpoint directory: {self.checkpoint_dir}.')
logging.log(logging.INFO, f'Creating checkpoint directory: {os.path.join(self.checkpoint_dir, self._experiment_name)}.')
os.mkdir(os.path.join(self.checkpoint_dir, self._experiment_name))

def _parse_file(self, fname: str) -> None:
Expand All @@ -137,8 +138,4 @@ def run(self) -> None:
# Else, create a new checkpoint directory and start training
# If there's not a checkpoint directory, use the self._runner.fit() method
# Otherwise, use the self._runner.resume(path_to_checkpoint) method
if not os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)):
self._create_checkpoint_dir()
return self._runner.fit()
else:
return self._runner.resume(os.path.join(self.checkpoint_dir, self._experiment_name))
return self._runner.fit()
21 changes: 21 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import logging
from argparse import ArgumentParser

from skelcast.core.environment import Environment

args = ArgumentParser()
args.add_argument('--config', type=str, default='../configs/lstm_regressor_1024x1024.yaml')
args.add_argument('--data_dir', type=str, default='data')
args.add_argument('--checkpoint_dir', type=str, default='checkpoints')

args = args.parse_args()


if __name__ == '__main__':
log_format = '[%(asctime)s] %(levelname)s: %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)

env = Environment(data_dir=args.data_dir, checkpoint_dir=args.checkpoint_dir)
env.build_from_file(args.config)
env.run()

0 comments on commit 578a2b5

Please sign in to comment.