-
Notifications
You must be signed in to change notification settings - Fork 40
/
TrainData.py
121 lines (95 loc) · 4.12 KB
/
TrainData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
'''
Created on 20 Feb 2017
@author: jkiesele
New (post equals 2.1) version
'''
import os
import numpy as np
import logging
from DeepJetCore.compiled.c_trainData import trainData
from DeepJetCore.SimpleArray import SimpleArray
import time
def fileTimeOut(fileName, timeOut):
    '''
    Simple wait function in case the file system has a glitch.

    Waits until the directory the file should be stored in / read from is
    accessible again, or until the timeout is reached. Only the directory
    part of ``fileName`` is checked (the current directory if ``fileName``
    has no directory part); the file itself does not have to exist.

    Parameters:
        fileName: path of the file that is about to be read or written.
        timeOut: maximum time to wait, in seconds.

    Raises:
        Exception: if the directory is still not accessible after
            ``timeOut`` seconds of waiting.
    '''
    filepath = os.path.dirname(fileName) or '.'
    if os.path.isdir(filepath):
        return

    print('file I/O problems... waiting for filesystem to become available for '+fileName)
    waited = 0
    while not os.path.isdir(filepath):
        # '>=' (not '>') so that at most timeOut seconds are spent waiting,
        # which is what the error message promises
        if waited >= timeOut:
            raise Exception('...file could not be opened within '+str(timeOut)+ ' seconds')
        time.sleep(1)
        waited += 1
#inherit from cpp class, just slim wrapper
class TrainData(trainData):
    '''
    Base class for batch-wise training of the DNN.

    Slim Python wrapper that inherits from the compiled ``trainData`` class.
    Users derive from this class and override the hooks in the
    "functions to be defined by the user" section below.
    '''

    def __init__(self):
        trainData.__init__(self)

    def getInputShapes(self):
        '''Deprecated: prints a notice and forwards to getNumpyFeatureShapes.'''
        print('TrainData:getInputShapes: Deprecated, use getNumpyFeatureShapes instead')
        return self.getNumpyFeatureShapes()

    def readIn(self, fileprefix, shapesOnly=False):
        '''Deprecated: prints a notice and forwards to readFromFile.'''
        print('TrainData:readIn deprecated, use readFromFile')
        self.readFromFile(fileprefix, shapesOnly)

    def _convertToCppType(self, a, helptext):
        '''
        Return the ``.sa`` object of ``a`` after checking it for NaN/inf.

        Parameters:
            a: a SimpleArray (its ``.sa`` attribute is used directly) or a
                numpy array (wrapped in a SimpleArray first).
            helptext: short label ("feature", "truth", "weight") that is
                inserted into the NaN/inf error message.

        Raises:
            ValueError: if ``a`` is neither a SimpleArray nor a numpy array,
                or if the converted array reports NaN or inf entries.
        '''
        # isinstance instead of comparing str(type(a)) against hard-coded
        # repr strings, which depends on the interpreter's repr format
        if isinstance(a, SimpleArray):
            saout = a.sa
        elif isinstance(a, np.ndarray):
            # empty second argument, as in the original conversion
            # (presumably "no row splits" - confirm against SimpleArray)
            wrapped = SimpleArray(a, np.array([]))
            saout = wrapped.sa
        else:
            raise ValueError("TrainData._convertToCppType MUST produce either a list of numpy arrays or a list of DeepJetCore simpleArrays!")

        if saout.hasNanOrInf():
            raise ValueError("TrainData._convertToCppType: the "+helptext+" array "+saout.name()+" has NaN or inf entries")
        return saout

    def _store(self, x, y, w):
        '''
        Convert and store feature (x), truth (y) and weight (w) arrays.

        Each argument is an iterable of numpy arrays or SimpleArrays; every
        entry is passed through _convertToCppType and handed to
        storeFeatureArray / storeTruthArray / storeWeightArray respectively.
        '''
        for xa in x:
            self.storeFeatureArray(self._convertToCppType(xa, "feature"))
        for ya in y:
            self.storeTruthArray(self._convertToCppType(ya, "truth"))
        for wa in w:
            self.storeWeightArray(self._convertToCppType(wa, "weight"))

    def readFromSourceFile(self, filename, weighterobjects=None, istraining=False, **kwargs):
        '''
        Convert one source file and store the resulting arrays in this object.

        Parameters:
            filename: source file, passed on to convertFromSourceFile.
            weighterobjects: dictionary passed on to convertFromSourceFile;
                a fresh empty dict is used when omitted.
            istraining: passed on to convertFromSourceFile.
            **kwargs: forwarded unchanged to convertFromSourceFile.
        '''
        # None sentinel: a '{}' default would be one dict shared by all calls,
        # so a conversion that modifies it would leak state into later calls
        if weighterobjects is None:
            weighterobjects = {}
        x, y, w = self.convertFromSourceFile(filename, weighterobjects, istraining, **kwargs)
        self._store(x, y, w)

    ################# functions to be defined by the user

    def createWeighterObjects(self, allsourcefiles):
        '''
        Will be called on the full list of source files once.
        Can be used to create weighter objects or similar that can
        then be applied to each individual conversion.
        Should return a dictionary
        '''
        return {}

    def fileIsValid(self, filename):
        '''
        Perform a simple and quick check that the file is not corrupt.
        Can be called in advance of the conversion.
        Return False if the file is corrupt; this default always returns True.
        '''
        return True

    ### either of the following need to be defined

    def writeFromSourceFile(self, filename, weighterobjects, istraining, outname):
        '''
        Override this if direct writeout is useful.
        The default reads the source file via readFromSourceFile and then
        writes this object to ``outname`` with writeToFile.
        '''
        self.readFromSourceFile(filename, weighterobjects, istraining)
        self.writeToFile(outname)

    def convertFromSourceFile(self, filename, weighterobjects, istraining, **kwargs):
        '''
        Otherwise only define the conversion rule here.

        Must return three lists (features, truth, weights) of numpy arrays
        OR SimpleArrays (mandatory for ragged tensors). This default returns
        three empty lists. ``**kwargs`` is accepted because
        readFromSourceFile forwards its extra keyword arguments here; this
        default ignores them.
        '''
        return [], [], []

    def writeOutPrediction(self, predicted, features, truth, weights, outfilename, inputfile):
        '''
        Defines how to write out the prediction.
        Must not use any of the stored arrays, only the inputs.
        Optionally it can return the output file name to be added to a list
        of output files; this default returns None.
        '''
        return None