diff --git a/configs/sttf_human36m.yaml b/configs/sttf_human36m.yaml new file mode 100644 index 0000000..d7796fe --- /dev/null +++ b/configs/sttf_human36m.yaml @@ -0,0 +1,48 @@ +dataset: + name: Human36MDataset + args: + path: /media/kaseris/FastData/Human3.6M-DMGNN/h36m.npz + +loss: + name: SmoothL1Loss + args: + reduction: mean + beta: 0.01 + +collate_fn: + name: Human36MCollateFnWithRandomSampledContextWindow + args: + block_size: 10 + +logger: + name: TensorboardLogger + args: + log_dir: runs + +optimizer: + name: AdamW + args: + lr: 0.0001 + weight_decay: 0.0001 + +model: + name: SpatioTemporalTransformer + args: + n_joints: 32 + input_dim: 4 + d_model: 256 + n_blocks: 3 + n_heads: 8 + d_head: 16 + mlp_dim: 512 + dropout: 0.5 + +runner: + name: Runner + args: + train_batch_size: 16 + val_batch_size: 16 + block_size: 8 + log_gradient_info: false + device: cuda + n_epochs: 100