Commit
Merge pull request #48 from kaseris/fix/environment-build
Fix/environment build
Showing 7 changed files with 229 additions and 66 deletions.
@@ -0,0 +1,65 @@
dataset:
  name: NTURGBDDataset
  args:
    data_directory: /home/kaseris/Documents/mount/data_ntu_rgbd
    label_file: /home/kaseris/Documents/dev/skelcast/data/labels.txt
    missing_files_dir: /home/kaseris/Documents/dev/skelcast/data/missing
    max_context_window: 10
    max_number_of_bodies: 1
    max_duration: 300
    n_joints: 25
    cache_file: /home/kaseris/Documents/mount/dataset_cache.pkl

transforms:
  - name: MinMaxScaleTransform
    args:
      feature_scale: [0.0, 1.0]
  - name: SomeOtherTransform
    args:
      some_arg: some_value

loss:
  name: MSELoss
  args:
    reduction: mean

collate_fn:
  name: NTURGBDCollateFnWithRandomSampledContextWindow
  args:
    block_size: 25

logger:
  name: TensorboardLogger
  args:
    save_dir: runs

optimizer:
  name: AdamW
  args:
    lr: 0.0001
    weight_decay: 0.0001

model:
  name: PositionalVelocityRecurrentEncoderDecoder
  args:
    input_dim: 75
    enc_hidden_dim: 64
    dec_hidden_dim: 64
    enc_type: lstm
    dec_type: lstm
    include_velocity: false
    pos_enc: add
    batch_first: true
    std_thresh: 0.0001
    use_std_mask: false
    use_padded_len_mask: true
    observe_until: 20

runner:
  name: Runner
  args:
    train_batch_size: 32
    val_batch_size: 32
    block_size: 8
    log_gradient_info: true
    device: cuda
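
For orientation: each top-level key above (dataset, transforms, loss, collate_fn, logger, optimizer, model, runner) is mapped to a Config subclass by the CONFIG_MAPPING dict in the config module added below, and transforms is parsed as a list of name/args entries. A minimal loading sketch, assuming the YAML above is saved at a placeholder path:

# Sketch only: parsing the YAML above with read_config (added later in this PR).
env = read_config('configs/pvred_ntu.yaml')   # placeholder path, not part of this diff
print(env.dataset_config)                     # e.g. DatasetConfig( name: NTURGBDDataset ... )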
@@ -0,0 +1,143 @@
import re
import logging
import yaml

from collections import OrderedDict
from typing import Any, List

from skelcast.core.registry import Registry


class Config:
    """Ordered name/args container for a single section of the YAML config."""

    def __init__(self):
        self._config = OrderedDict()
        self._config['name'] = None
        self._config['args'] = {}

    def get(self, key):
        return self._config[key]

    def set(self, key, value):
        if isinstance(value, list):
            self._config[key] = []
            for v in value:
                self._config[key].append(v)
        else:
            self._config[key] = value

    def __str__(self) -> str:
        s = self.__class__.__name__ + '(\n'
        for key, val in self._config.items():
            if isinstance(val, dict):
                s += f'\t{key}: \n'
                for k, v in val.items():
                    s += f'\t\t{k}: {v}\n'
                s += '\t\n'
            else:
                s += f'\t{key}: {val}\n'
        s += ')'
        return s


class ModelConfig(Config):
    def __init__(self):
        super(ModelConfig, self).__init__()


class DatasetConfig(Config):
    def __init__(self):
        super(DatasetConfig, self).__init__()


class TransformsConfig(Config):
    def __init__(self):
        super(TransformsConfig, self).__init__()


class LoggerConfig(Config):
    def __init__(self):
        super(LoggerConfig, self).__init__()


class OptimizerConfig(Config):
    def __init__(self):
        super(OptimizerConfig, self).__init__()


class SchedulerConfig(Config):
    def __init__(self):
        super(SchedulerConfig, self).__init__()


class CriterionConfig(Config):
    def __init__(self):
        super(CriterionConfig, self).__init__()


class CollateFnConfig(Config):
    def __init__(self):
        super(CollateFnConfig, self).__init__()


class RunnerConfig(Config):
    def __init__(self):
        super(RunnerConfig, self).__init__()


class EnvironmentConfig:
    """Stores each section config as a snake_case attribute, e.g. ModelConfig -> self.model_config."""

    def __init__(self, *args) -> None:
        for arg in args:
            name = arg.__class__.__name__
            split_name = re.findall('[A-Z][^A-Z]*', name)
            name = '_'.join([s.lower() for s in split_name])
            setattr(self, name, arg)

    def __str__(self) -> str:
        s = self.__class__.__name__ + '(\n'
        for key, val in self.__dict__.items():
            s += f'\t{key}: {val}\n'
        s += ')'
        return s


def build_object_from_config(config: Config, registry: Registry, **kwargs):
    """Instantiate the registered class named in `config`, merging its args with extra kwargs."""
    _name = config.get('name')
    _args = config.get('args')
    _args.update(kwargs)
    return registry.get_module(_name)(**_args)


def summarize_config(configs: List[Config]):
    """Write the string form of each config to a plain-text summary file."""
    with open(f'/home/kaseris/Documents/mount/config.txt', 'w') as f:
        for config in configs:
            f.write(str(config))
            f.write('\n\n')


CONFIG_MAPPING = {
    'model': ModelConfig,
    'dataset': DatasetConfig,
    'transforms': TransformsConfig,
    'logger': LoggerConfig,
    'optimizer': OptimizerConfig,
    'scheduler': SchedulerConfig,
    'loss': CriterionConfig,
    'collate_fn': CollateFnConfig,
    'runner': RunnerConfig
}


def read_config(config_path: str):
    """Parse a YAML file and return an EnvironmentConfig with one Config per top-level key."""
    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)
    cfgs = []
    for key in data:
        config = CONFIG_MAPPING[key]()
        if key == 'transforms':
            # Transforms come as a list of {name, args} entries; each is stored as name -> args.
            for element in data[key]:
                logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
                config.set(element['name'], element['args'])
        else:
            logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
            config.set('name', data[key]['name'])
            config.set('args', data[key]['args'])
        logging.debug(config)
        cfgs.append(config)
    return EnvironmentConfig(*cfgs)
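
A quick usage sketch of the new helpers, assuming a populated Registry instance; MODELS and OPTIMIZERS below are illustrative names, not part of this diff:

# Sketch only: instantiating objects from the parsed config via a Registry.
# MODELS / OPTIMIZERS are hypothetical Registry instances registered elsewhere in skelcast.
env = read_config('configs/pvred_ntu.yaml')                     # placeholder path
model = build_object_from_config(env.model_config, MODELS)      # resolves 'PositionalVelocityRecurrentEncoderDecoder'
optimizer = build_object_from_config(env.optimizer_config, OPTIMIZERS,
                                     params=model.parameters())  # extra kwargs merged into the YAML args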
@@ -0,0 +1,4 @@
import torch.optim as optim


PYTORCH_OPTIMIZERS = {name: getattr(optim, name) for name in dir(optim) if isinstance(getattr(optim, name), type) and issubclass(getattr(optim, name), optim.Optimizer)}
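
The dict keys are the optimizer class names exposed by torch.optim, so the 'AdamW' entry in the YAML above resolves directly; a minimal sketch, where model is a placeholder nn.Module:

# Sketch only: looking up the optimizer named in the config.
optimizer_cls = PYTORCH_OPTIMIZERS['AdamW']                                    # -> torch.optim.AdamW
optimizer = optimizer_cls(model.parameters(), lr=0.0001, weight_decay=0.0001)  # args from the YAML above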
@@ -0,0 +1,3 @@
import torch.nn as nn

PYTORCH_LOSSES = {name: getattr(nn, name) for name in dir(nn) if isinstance(getattr(nn, name), type) and issubclass(getattr(nn, name), nn.Module) and 'Loss' in name}
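
Same pattern for criteria: the loss named in the YAML resolves to the corresponding torch.nn class; a minimal sketch, where the prediction/target tensors are placeholders:

# Sketch only: looking up the loss named in the config.
criterion_cls = PYTORCH_LOSSES['MSELoss']     # -> torch.nn.MSELoss
criterion = criterion_cls(reduction='mean')   # args from the YAML above
# loss = criterion(prediction, target)        # placeholder tensors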