-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_rgb_predrnn.py
46 lines (37 loc) · 1.37 KB
/
train_rgb_predrnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.strategies.ddp import DDPStrategy
from models import *
from data.data_classes import *
# Configs
batch_size = 4
learning_rate = 1e-3
epochs = 100
input_channels=3
num_hidden=[64, 64, 64]
kernel_size=5
stride=1
learning_rate=1e-3
num_ctx_frames=1
num_tgt_frames=9
split_ratio=[0.4, 0.1, 0.5]
model = PredRNN(input_channels=input_channels,
num_hidden=num_hidden,
num_ctx_frames=num_ctx_frames,
num_tgt_frames=num_tgt_frames,
kernel_size=kernel_size,
stride=stride,
learning_rate=learning_rate)
moving_mnist = TwoColourMovingMNISTDataModule(batch_size,
num_ctx_frames,
num_tgt_frames,
split_ratio=split_ratio)
logger = TensorBoardLogger('./logs', 'PredRNN_RGB')
trainer = pl.Trainer(gpus=4,
strategy=DDPStrategy(find_unused_parameters=False),
num_sanity_val_steps=0,
max_epochs= epochs,
callbacks=LearningRateMonitor(),
logger=logger)
trainer.fit(model, moving_mnist)