-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_rgb_simvp.py
49 lines (39 loc) · 1.3 KB
/
train_rgb_simvp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.strategies.ddp import DDPStrategy
from models import *
from data.data_classes import *
# ---------------------------------------------------------------------------
# Configs
# ---------------------------------------------------------------------------

# Optimisation / run-length settings.
epochs = 100
learning_rate = 1e-3
batch_size = 4

# Number of context (input) frames and target frames per sample; both are
# handed to the data module, and num_ctx_frames also enters input_shape.
num_ctx_frames, num_tgt_frames = 1, 9

# Forwarded to the data module as `split_ratio`.
# NOTE(review): presumably train/val/test fractions -- confirm in data_classes.
split_ratio = [0.4, 0.1, 0.5]

# SimVP hyperparameters, forwarded unchanged to the model constructor.
hid_s, hid_t = 64, 256
N_s, N_t = 4, 8
kernel_sizes = [3, 5, 7, 11]
groups = 4

# Per-frame geometry and the resulting model input shape, ordered as
# (channels, context frames, height, width).
channels, height, width = 3, 128, 128
input_shape = (channels, num_ctx_frames, height, width)
# Build the SimVP model from the hyperparameters configured above.
model = SimVP_1to9(
    input_shape=input_shape,
    hid_s=hid_s,
    hid_t=hid_t,
    N_s=N_s,
    N_t=N_t,
    kernel_sizes=kernel_sizes,
    groups=groups,
    learning_rate=learning_rate,
)

# Data module supplying the two-colour Moving MNIST batches.
moving_mnist = TwoColourMovingMNISTDataModule(
    batch_size,
    num_ctx_frames,
    num_tgt_frames,
    split_ratio=split_ratio,
)

# TensorBoard logger: save directory './logs', experiment name 'SimVP'.
logger = TensorBoardLogger('./logs', 'SimVP')

# Train on 4 GPUs with DDP.
# `accelerator='gpu', devices=4` replaces the `gpus=4` argument, which was
# deprecated in Lightning 1.7 and removed in 2.0; this spelling is accepted by
# every release that provides `pytorch_lightning.strategies`.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=4,
    # find_unused_parameters=False skips DDP's per-step search for unused
    # parameters; it requires every model parameter to receive a gradient.
    strategy=DDPStrategy(find_unused_parameters=False),
    max_epochs=epochs,
    # Logs the optimizer learning rate to the logger during training.
    callbacks=[LearningRateMonitor()],
    logger=logger,
)
trainer.fit(model, moving_mnist)