-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_resnet50_script.py
26 lines (23 loc) · 798 Bytes
/
train_resnet50_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from data import CrowdDataSet
from data import default_train_transforms, default_val_transforms
import numpy as np
from models import ResNetTransfer
from trainer import train
import torch.optim as optim
import torch.nn as nn
import torch
import os

# Experiment settings and output locations.
NUM_EPOCHS = 200
MODEL_PATH = 'saved_models/resnet50_den_map'
# np.save appends ".npy" to this name when writing.
LOSSES_PATH = 'loss_experiments/resnet50_denmap/resnet50_den_losses'

# Create the output directories before training: neither torch.save nor
# np.save creates missing parent directories, so a missing folder would
# otherwise only fail after all NUM_EPOCHS of training had been spent.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
os.makedirs(os.path.dirname(LOSSES_PATH), exist_ok=True)

# Train / validation splits of part_A, each with its own transform pipeline.
# NOTE(review): the "val" entry reads 'part_A/test_data', i.e. validation is
# run on the test split. These are CrowdDataSet objects rather than
# torch DataLoaders -- presumably `train` handles batching; confirm.
loaders = {
    "train": CrowdDataSet(
        'part_A/train_data', default_train_transforms()
    ),
    "val": CrowdDataSet(
        'part_A/test_data', default_val_transforms()
    ),
}

model = ResNetTransfer()
# Mean-squared error between model output and target; presumably density
# maps given the output names ("den_map") -- TODO confirm against the model.
criterion = nn.MSELoss()
lr = 1e-6
optimizer = optim.Adam(model.parameters(), lr=lr)

# `train` returns per-run loss and R^2 histories for both splits.
train_losses, train_r2, val_losses, val_r2 = train(
    model, loaders, criterion, optimizer, NUM_EPOCHS
)

# Pickles the whole module object (not just its state_dict); loading it back
# requires ResNetTransfer to be importable under the same module path.
torch.save(model, MODEL_PATH)
np.save(LOSSES_PATH, (train_losses, train_r2, val_losses, val_r2))