Finalize environment creation
kaseris committed Dec 12, 2023
1 parent 63a7ac3 commit 0457152
Showing 3 changed files with 52 additions and 16 deletions.
3 changes: 2 additions & 1 deletion configs/pvred.yaml
@@ -31,7 +31,7 @@ collate_fn:
 logger:
   name: TensorboardLogger
   args:
-    save_dir: runs
+    log_dir: runs
 
 optimizer:
   name: AdamW
@@ -63,3 +63,4 @@ runner:
     block_size: 8
     log_gradient_info: true
     device: cuda
+    n_epochs: 100
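Why the rename: torch.utils.tensorboard.SummaryWriter takes a log_dir keyword, so the config key now matches the underlying API. A minimal sketch of a wrapper consistent with that, assuming TensorboardLogger simply forwards the argument (the actual skelcast class may differ):

# Hypothetical sketch -- assumes TensorboardLogger forwards `log_dir`
# straight to torch.utils.tensorboard.SummaryWriter, whose own keyword
# is also `log_dir`. The real skelcast implementation may differ.
from torch.utils.tensorboard import SummaryWriter

class TensorboardLogger:
    def __init__(self, log_dir: str) -> None:
        self.writer = SummaryWriter(log_dir=log_dir)

    def add_scalar(self, tag: str, value: float, step: int) -> None:
        # Mirrors SummaryWriter.add_scalar(tag, scalar_value, global_step)
        self.writer.add_scalar(tag, value, global_step=step)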
64 changes: 49 additions & 15 deletions src/skelcast/core/environment.py
@@ -10,8 +10,11 @@
 
 from skelcast.models import MODELS
 from skelcast.data import DATASETS
-from skelcast.data import TRANSFORMS
+from skelcast.data import TRANSFORMS, COLLATE_FUNCS
 from skelcast.logger import LOGGERS
+from skelcast.losses import LOSSES
+from skelcast.losses.torch_losses import PYTORCH_LOSSES
+from skelcast.core.optimizers import PYTORCH_OPTIMIZERS
 
 from skelcast.experiments.runner import Runner
 from skelcast.core.config import read_config, build_object_from_config
@@ -50,17 +53,25 @@ class Environment:
     and configurations are properly set up before using this class.
     """
     def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd',
-                 checkpoint_dir = '/home/kaseris/Documents/checkpoints_forecasting') -> None:
+                 checkpoint_dir = '/home/kaseris/Documents/mount/checkpoints_forecasting',
+                 train_set_size = 0.8) -> None:
         self._experiment_name = randomname.get_name()
-        self.checkpoint_dir = checkpoint_dir
+        self.checkpoint_dir = os.path.join(checkpoint_dir, self._experiment_name)
+        os.mkdir(self.checkpoint_dir)
+        logging.info(f'Created checkpoint directory at {self.checkpoint_dir}')
         self.data_dir = data_dir
+        self.train_set_size = train_set_size
         self.config = None
         self._model = None
         self._dataset = None
+        self._train_dataset = None
+        self._val_dataset = None
         self._runner = None
         self._logger = None
+        self._loss = None
+        self._optimizer = None
+        self._collate_fn = None
 
 
     @property
     def experiment_name(self) -> str:
@@ -69,12 +80,42 @@ def experiment_name(self) -> str:
     def build_from_file(self, config_path: str) -> None:
         logging.log(logging.INFO, f'Building environment from {config_path}.')
         cfgs = read_config(config_path=config_path)
         # Build transforms first, because they are used in the dataset
         self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS)
         print(self._transforms)
-        # TODO: Add support for random splits
-        self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS)
+        # TODO: Add support for random splits. Maybe as an external parameter?
+        self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS, transforms=self._transforms)
+        logging.info(f'Loaded dataset from {self.data_dir}.')
+        # Build the loss first, because it is used in the model
+        loss_registry = LOSSES if cfgs.criterion_config.get('name') not in PYTORCH_LOSSES else PYTORCH_LOSSES
+        self._loss = build_object_from_config(cfgs.criterion_config, loss_registry)
+        logging.info(f'Loaded loss function {cfgs.criterion_config.get("name")}.')
+        self._model = build_object_from_config(cfgs.model_config, MODELS, loss_fn=self._loss)
+        logging.info(f'Loaded model {cfgs.model_config.get("name")}.')
+        # Build the optimizer
+        self._optimizer = build_object_from_config(cfgs.optimizer_config, PYTORCH_OPTIMIZERS, params=self._model.parameters())
+        logging.info(f'Loaded optimizer {cfgs.optimizer_config.get("name")}.')
+        # Build the logger
+        cfgs.logger_config.get('args').update({'log_dir': os.path.join(cfgs.logger_config.get('args').get('log_dir'), self._experiment_name)})
+        self._logger = build_object_from_config(cfgs.logger_config, LOGGERS)
+        logging.info(f'Created runs directory at {cfgs.logger_config.get("args").get("log_dir")}')
+        # Build the collate_fn
+        self._collate_fn = build_object_from_config(cfgs.collate_fn_config, COLLATE_FUNCS)
+        logging.info(f'Loaded collate function {cfgs.collate_fn_config.get("name")}.')
+        # Split the dataset into training and validation sets
+        train_size = int(self.train_set_size * len(self._dataset))
+        val_size = len(self._dataset) - train_size
+        self._train_dataset, self._val_dataset = random_split(self._dataset, [train_size, val_size])
+        # Build the runner
+        self._runner = Runner(model=self._model,
+                              optimizer=self._optimizer,
+                              logger=self._logger,
+                              collate_fn=self._collate_fn,
+                              train_set=self._train_dataset,
+                              val_set=self._val_dataset,
+                              checkpoint_dir=self.checkpoint_dir,
+                              **cfgs.runner_config.get('args'))
+        logging.info(f'Finished building environment from {config_path}.')
 
     def run(self) -> None:
         # Must check if there is a checkpoint directory
@@ -83,10 +124,3 @@ def run(self) -> None:
         # If there's no checkpoint directory, use the self._runner.fit() method
         # Otherwise, use the self._runner.resume(path_to_checkpoint) method
         return self._runner.fit()
-
-if __name__ == '__main__':
-    format = '%(asctime)s %(levelname)s: %(message)s'
-    logging.basicConfig(level=logging.DEBUG, format=format)
-    env = Environment()
-    env.build_from_file('configs/pvred.yaml')
-    # env.run()
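Every component above is resolved through build_object_from_config against a registry (MODELS, DATASETS, LOSSES, LOGGERS, COLLATE_FUNCS, PYTORCH_OPTIMIZERS). The helper itself is not shown in this diff; a minimal sketch of the pattern the call sites imply, with illustrative names only:

# Illustrative sketch of the registry lookup implied by the calls above;
# the real skelcast.core.config.build_object_from_config may differ.
def build_object_from_config(config: dict, registry: dict, **extra_kwargs):
    # config carries a class name plus its constructor arguments,
    # e.g. {'name': 'AdamW', 'args': {'lr': 1e-3}}.
    cls = registry[config.get('name')]
    kwargs = dict(config.get('args') or {})
    # Caller-supplied objects that cannot live in a YAML file,
    # e.g. params=self._model.parameters() for the optimizer.
    kwargs.update(extra_kwargs)
    return cls(**kwargs)

The loss_registry line then only decides which registry to consult: custom criteria resolve through LOSSES, while names present in PYTORCH_LOSSES presumably map to the built-in torch.nn loss classes.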
1 change: 1 addition & 0 deletions tools/train.py
@@ -7,6 +7,7 @@
 args.add_argument('--config', type=str, default='../configs/lstm_regressor_1024x1024.yaml')
 args.add_argument('--data_dir', type=str, default='data')
 args.add_argument('--checkpoint_dir', type=str, default='checkpoints')
+args.add_argument('--train_set_size', type=float, default=0.8, required=False)
 
 args = args.parse_args()
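The diff cuts off before the lines of tools/train.py that consume the new flag, and since this commit also deletes environment.py's __main__ block, the script is presumably now the sole entry point. A plausible continuation, assuming the CLI flags map one-to-one onto Environment's constructor as shown in the diff above:

# Plausible continuation (not shown in the diff) -- assumes the removed
# __main__ block's logging setup moved here and the flags feed Environment.
import logging

from skelcast.core.environment import Environment

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s: %(message)s')

env = Environment(data_dir=args.data_dir,
                  checkpoint_dir=args.checkpoint_dir,
                  train_set_size=args.train_set_size)
env.build_from_file(args.config)
env.run()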
