image_net_train.py
from modules.models import *
from modules.dataset import *
from modules.process import *
from torchvision import transforms
import numpy as np
import argparse, pickle, torch, os
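# Example invocation (illustrative only: the csv file name and the values below are
# hypothetical, but the keys match the ones read in main()):
#
#   python image_net_train.py --csv settings.csv
#
# where settings.csv holds comma-separated key,value pairs, e.g.:
#
#   modelPath,results/settings/pretrained.pt
#   learningRate,0.0001
#   imageSize,224
#   batchSize,32
#   epochs,100
#   deviceName,cuda:0
#   rootPath,data/imagenet/
#   classNamePath,data/imagenet/class_names.csv
#
# rootPath is concatenated directly with "train"/"val", so it should end with "/".
# modelPath may point to a non-existent file, in which case training starts from scratch.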

def PrintInputVariable(modelPath, learningRate, imageSize, batchSize, epochs, deviceName, rootPath, classNamePath):
    print("--- input variables ---")
    print("model path : %s"%(modelPath if os.path.isfile(modelPath) else "None"))
    print("learning rate : %f"%learningRate)
    print("image size : %d"%imageSize)
    print("batch size : %d"%batchSize)
    print("epochs : %d"%epochs)
    print("use cuda device : %s"%(deviceName if torch.cuda.is_available() else "cpu"))
    print("data root path : %s"%rootPath)
    print("class name path : %s"%classNamePath)
    print()
    return None

def main():
    parser = argparse.ArgumentParser(description="variable file select")
    parser.add_argument("--csv", type=str, help="variable csv file")
    args = parser.parse_args()
    variableFile = args.csv
    variableFileName = variableFile.split('.')[0]
    # read the input variables from the csv file
    inputValue = {}
    for key, value in np.loadtxt(variableFile, dtype=str, delimiter=','):
        inputValue[key] = value
    modelPath = str(inputValue["modelPath"])
    learningRate = float(inputValue["learningRate"])
    imageSize = int(inputValue["imageSize"])
    batchSize = int(inputValue["batchSize"])
    epochs = int(inputValue["epochs"])
    deviceName = str(inputValue["deviceName"])
    rootPath = str(inputValue["rootPath"])
    classNamePath = str(inputValue["classNamePath"])
    # model setup
    torch.multiprocessing.set_sharing_strategy('file_system')
    device = torch.device(deviceName if torch.cuda.is_available() else "cpu")
    if "cuda" in deviceName:
        os.environ['CUDA_LAUNCH_BLOCKING'] = deviceName.split(':')[-1]
    model = ResUNext()
    if os.path.isfile(modelPath):
        # resume from an existing checkpoint if one is given
        model.load_state_dict(torch.load(modelPath))
    model = model.to(device)
    PrintInputVariable(modelPath, learningRate, imageSize, batchSize, epochs, deviceName, rootPath, classNamePath)
    # prepare dataset
    classDic = VariableDumpSaveNLoad("classDic", GetClassDictionary, rootPath+"train")
    trainDic = VariableDumpSaveNLoad("trainDic", GetDataInfoDictionary, rootPath+"train", classDic, len(classDic))
    validDic = VariableDumpSaveNLoad("validDic", GetDataInfoDictionary, rootPath+"val", classDic, len(classDic))
    # ImageNet channel-wise mean/std used for input normalization
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    dataTransform = transforms.Compose([
        transforms.Resize((imageSize, imageSize)),
        transforms.ToTensor(),
        normalize,
    ])
    # train the model
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)
    processRecord = {"loss" : {}, "accuracy" : {}}
    for epoch in range(1, epochs+1):
        if not os.path.isdir("results/"):
            os.mkdir("results/")
        if not os.path.isdir("results/"+variableFileName):
            os.mkdir("results/"+variableFileName)
        trainList = MakeEqualDatasetList(trainDic, dataShuffle=True)
        validList = MakeEqualDatasetList(validDic, dataShuffle=True)
        trainDataLoader = MakeDataLoader(trainList, batchSize=batchSize*2, dataTransform=dataTransform)
        validDataLoader = MakeDataLoader(validList, batchSize=batchSize, dataTransform=dataTransform)
        processRecord["loss"][epoch] = {}
        processRecord["accuracy"][epoch] = {}
        # training pass
        _loss, _acc, _dic = DeeplearningProcessing(model, device, trainDataLoader, criterion, optimizer, epoch, training=True, announceBatchStep=200)
        processRecord["loss"][epoch]["train"] = _loss
        processRecord["accuracy"][epoch]["train"] = _acc
        pickle.dump(_dic, open("results/%s/%s_loss_%.4f_acc_%.2f%%_epoch_%d.dump"%(variableFileName, "train", _loss, _acc, epoch), "wb"))
        # validation pass
        _loss, _acc, _dic = DeeplearningProcessing(model, device, validDataLoader, criterion, optimizer, epoch, training=False, announceBatchStep=50)
        processRecord["loss"][epoch]["valid"] = _loss
        processRecord["accuracy"][epoch]["valid"] = _acc
        pickle.dump(_dic, open("results/%s/%s_loss_%.4f_acc_%.2f%%_epoch_%d.dump"%(variableFileName, "valid", _loss, _acc, epoch), "wb"))
        # checkpoint and record the results every epoch
        torch.save(model.state_dict(), "results/%s/loss_%.4f_acc_%.2f%%_epoch_%d.pt"%(variableFileName, _loss, _acc, epoch))
        RecordCSV(variableFileName, processRecord, imageSize, classNamePath)
        PlotProcess(processRecord)
        pickle.dump(processRecord, open("results/%s/processRecord.dump"%variableFileName, "wb"))
    return None

if __name__ == "__main__":
    main()
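
# Sketch (assumption, not part of the training flow): a checkpoint written above can be
# reloaded the same way main() restores modelPath; the file name below is hypothetical.
#
#   model = ResUNext()
#   model.load_state_dict(torch.load("results/settings/loss_0.1234_acc_85.00%_epoch_10.pt"))
#   model.eval()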