-
Notifications
You must be signed in to change notification settings - Fork 40
/
dataPipeline.py
78 lines (64 loc) · 2.96 KB
/
dataPipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from DeepJetCore.compiled.c_trainDataGenerator import trainDataGenerator
import numpy as np
class TrainDataGenerator(trainDataGenerator):
    """Python-side extension of the compiled ``trainDataGenerator``.

    Adds generator methods that feed out the batches obtained from
    ``getBatch()``, either as the batch objects themselves
    (``feedTrainData``) or converted to numpy-based tuples
    (``feedNumpyData``).
    """

    def __init__(self,
                 pad_rowsplits=False,
                 fake_truth=None,
                 dict_output=False,
                 cast_to=None):
        """Store the options that control how batches are fed out.

        Args:
            pad_rowsplits: Passed unchanged to the ``transfer*ListToNumpy``
                calls made in ``feedNumpyData``.
            fake_truth: If not None, ``feedNumpyData`` yields this placeholder
                instead of the truth taken from the batch. An int ``n`` gives
                a list of ``n`` float32 arrays ``[0]``; a list of strings
                gives a dict mapping each string to such an array.
            dict_output: If true, ``feedNumpyData`` pairs the feature, truth
                and weight arrays with their array names and yields dicts.
            cast_to: If not None, ``feedTrainData`` assigns it to the
                ``__class__`` of every batch before yielding it.

        Raises:
            ValueError: If ``fake_truth`` is neither None, an int, nor a list
                containing only strings.
        """
        trainDataGenerator.__init__(self)
        self.pad_rowsplits = pad_rowsplits
        self.dict_output = dict_output
        self.fake_truth = None
        self.cast_to = cast_to

        if fake_truth is None:
            return

        if isinstance(fake_truth, int):
            self.fake_truth = [np.array([0], dtype='float32')
                               for _ in range(fake_truth)]
        elif (isinstance(fake_truth, list)
              and all(isinstance(e, str) for e in fake_truth)):
            self.fake_truth = {e: np.array([0], dtype='float32')
                               for e in fake_truth}
        else:
            # Covers both a list with non-string entries and any other
            # unsupported type; the latter used to be ignored silently,
            # leaving the real truth in place without any warning.
            raise ValueError("TrainDataGenerator: only accepts an int or list of strings to extend truth list")

    def feedTrainData(self):
        """Yield each batch returned by ``getBatch()``.

        ``getNBatches()`` batches are yielded. When ``cast_to`` was given,
        the ``__class__`` of each batch is reassigned to it first.
        """
        for _ in range(self.getNBatches()):
            td = self.getBatch()
            if self.cast_to is not None:
                td.__class__ = self.cast_to
            yield td

    def feedNumpyData(self):
        """Yield each batch as ``(x, y)`` or, if weights exist, ``(x, y, w)``.

        ``x``, ``y`` and ``w`` are the results of the batch's
        ``transfer*ListToNumpy`` calls, turned into name-keyed dicts when
        ``dict_output`` is set. ``y`` is replaced by ``fake_truth`` when that
        was configured. ``w`` is only included when it is non-empty.

        Raises:
            Exception: Any exception raised while building a batch is
                reported on stdout together with the batch index and then
                re-raised unchanged.
        """
        fnames = []
        tnames = []
        wnames = []

        for b in range(self.getNBatches()):
            # Only batch construction is guarded; the yield sits outside the
            # try so that an exception thrown into the generator by its
            # consumer is not misreported as a batch failure.
            try:
                data = self.getBatch()

                # Array names are looked up only until a non-empty feature
                # name list has been obtained.
                if not len(fnames):
                    fnames = data.getNumpyFeatureArrayNames()
                    tnames = data.getNumpyTruthArrayNames()
                    wnames = data.getNumpyWeightArrayNames()

                # These calls will transfer data to numpy and delete the
                # respective SimpleArray instances for efficiency.
                # Therefore extracting names etc. needs to happen before!
                xout = data.transferFeatureListToNumpy(self.pad_rowsplits)
                yout = data.transferTruthListToNumpy(self.pad_rowsplits)
                wout = data.transferWeightListToNumpy(self.pad_rowsplits)

                if self.dict_output:
                    xout = dict(zip(fnames, xout))
                    yout = dict(zip(tnames, yout))
                    wout = dict(zip(wnames, wout))

                if self.fake_truth is not None:
                    yout = self.fake_truth

                out = (xout, yout)
                if len(wout) > 0:
                    out = (xout, yout, wout)
            except Exception as e:
                print("TrainDataGenerator: an exception was raised in batch",b," out of ", self.getNBatches(),', expection: ', e)
                # Bare raise re-raises the active exception as-is.
                raise

            yield out

    def feedTorchTensors(self):
        """Placeholder with no implementation; does nothing, returns None."""
        pass