# batchprocessing.py
from functools import wraps
import glob
import inspect
import json
import os
import shutil
from typing import Optional, Union

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm


class BatchProcessor:
    """Run a predict-style method in batches with checkpointing.

    The input ``X`` is split into ``n_batches`` batches; each batch result
    is written to a compressed CSV checkpoint so interrupted runs can be
    resumed, and batches can optionally be processed in parallel via joblib.
    """

    def __init__(self,
                 n_batches: int = 1,
                 checkpoint_path: str = '',
                 do_load_cp: bool = True,
                 n_jobs: int = 1,
                 progress_bar: bool = True) -> None:
        self._n_batches = n_batches
        self._checkpoint_path = checkpoint_path
        self._do_load_cp = do_load_cp
        self._n_jobs = n_jobs
        self._progress_bar = progress_bar

    def _batch_predict_self(self, method):
        # If the decorator receives a bound method, forward its instance;
        # otherwise wrap a plain function.
        if inspect.ismethod(method):
            @wraps(method)
            def _wrapper(predictor_self, *args, **kwargs):
                return self._batch_predict_func(
                    predictor_self=predictor_self,
                    method=method,
                    args=args,
                    kwargs=kwargs)
        else:
            @wraps(method)
            def _wrapper(*args, **kwargs):
                return self._batch_predict_func(predictor_self=None,
                                                method=method,
                                                args=args,
                                                kwargs=kwargs)
        return _wrapper

    def _batch_predict_func(self, predictor_self, method, args, kwargs):
        if self._n_batches is None or self._n_batches == 1:
            # No batching requested: execute the method normally.
            if predictor_self is not None:
                output = method(predictor_self, *args, **kwargs)
            else:
                output = method(*args, **kwargs)
        else:
            # Batch processing:
            batches, first_iter = self._get_remaining_batches_and_iterator(
                kwargs)
            self._check_makedir()
            # Validate that any existing checkpoint was written with the
            # same parameters (raises ValueError on a mismatch).
            self._get_last_iter(kwargs)
            other_kwargs = self._get_other_kwargs(kwargs)
            # Execute the method for each remaining batch.
            iterable = tqdm(batches) if self._progress_bar else batches
            if self._n_jobs is not None and self._n_jobs > 1:
                Parallel(n_jobs=self._n_jobs)(
                    delayed(self._iterfunc)(
                        predictor_self=predictor_self,
                        x=x,
                        method=method,
                        args=args,
                        other_kwargs=other_kwargs,
                        i=i + first_iter,
                    ) for i, x in enumerate(iterable))
            else:
                for i, x in enumerate(iterable):
                    self._iterfunc(predictor_self=predictor_self,
                                   x=x,
                                   method=method,
                                   args=args,
                                   other_kwargs=other_kwargs,
                                   i=i + first_iter)
            # Combine the individual batch results into one DataFrame.
            # TODO: implement a way to combine individual results when the
            # method returns multiple values (a tuple).
            last_iter = self._get_last_iter(kwargs)
            # all_=True is safe here because this point is only reached
            # once every iteration has run.
            results = self._load_result_checkpoints(last_iter=last_iter,
                                                    all_=True)
            output = pd.concat(results, axis=0, ignore_index=True)
            if len(output) == len(kwargs['X']):
                self._cleanup_checkpoints()
            else:
                raise ValueError('Output size is different from input size.')
        return output

    def _get_remaining_batches_and_iterator(self, kwargs):
        batches = self._get_batches(kwargs, n_batches=self._n_batches)
        other_kwargs = self._get_other_kwargs(kwargs)
        if self._do_load_cp:
            last_iter = self._get_last_iter(parameter_dict=other_kwargs)
        else:
            last_iter = None
        if last_iter is None:
            first_iter = 0
        else:
            first_iter = last_iter + 1
            batches = self._get_unprocessed_batches(last_iter, batches)
        return batches, first_iter

    def _iterfunc(self,
                  predictor_self,
                  x,
                  method,
                  args,
                  other_kwargs,
                  i,
                  ):
        # Run the wrapped method on one batch and checkpoint its result.
        if predictor_self is not None:
            iter_output = method(predictor_self, *args, X=x, **other_kwargs)
        else:
            iter_output = method(*args, X=x, **other_kwargs)
        self._save_checkpoints(iteration=i,
                               df=iter_output,
                               parameter_dict=other_kwargs)

    @staticmethod
    def _get_unprocessed_batches(last_iter, batches):
        return batches[last_iter + 1:]

    @staticmethod
    def _get_batches(kwargs: dict, n_batches: int) -> list:
        data = kwargs['X']
        return np.array_split(data, n_batches)

    @staticmethod
    def _get_other_kwargs(kwargs):
        return {key: value for key, value in kwargs.items() if key != 'X'}

    def _get_last_iter(self, parameter_dict):
        checkpoint = {}
        if self._checkpoint_path is not None:
            try:
                with open(os.path.join(self._checkpoint_path,
                                       'checkpoint.json'),
                          'r', encoding='utf8') as f:
                    checkpoint = json.load(f)
                if isinstance(checkpoint.get('last_iter'), int):
                    cp_found = True
                else:
                    print(f'Incompatible checkpoint value {checkpoint}')
                    print('Starting from beginning')
                    cp_found = False
            except FileNotFoundError as e:
                print(e)
                print('Checkpoint file not found, starting from beginning')
                cp_found = False
        else:
            cp_found = False
        if cp_found:
            # Check whether different parameters were used previously:
            if parameter_dict is not None:
                for key, value in parameter_dict.items():
                    if key in checkpoint and value != checkpoint[key]:
                        raise ValueError(
                            f'Attempting to continue with different '
                            f'parameters. Loaded value of {key} is '
                            f'{checkpoint[key]} but you passed {value}. '
                            f'Aborting. Manually delete the checkpoint '
                            f'directory or adjust the parameters to '
                            f'continue.')
        else:
            checkpoint = {'last_iter': None}
        return checkpoint['last_iter']

    def _load_result_checkpoints(self, last_iter, all_=False):
        df_list = []
        if all_:
            # Sort so batches are concatenated in the order they were
            # written (file names are zero-padded).
            files = sorted(
                glob.glob(os.path.join(self._checkpoint_path, '*.csv.gz')))
            for f in files:
                df_list.append(pd.read_csv(f, index_col=0))
        elif last_iter is not None:
            for iteration in range(last_iter + 1):
                padded_iteration = self._get_padded_iterator(iteration)
                df_list.append(pd.read_csv(
                    os.path.join(self._checkpoint_path,
                                 f'cp_{padded_iteration}.csv.gz'),
                    index_col=0))
        return df_list

    def _get_padded_iterator(self, iteration: int) -> str:
        if not isinstance(self._n_batches, int):
            raise ValueError(f'Integer expected for the number of batches '
                             f'but received {type(self._n_batches)} instead.')
        digits = len(str(self._n_batches))
        return str(iteration).zfill(digits)

    def _save_checkpoints(self,
                          iteration: int,
                          df: Union[pd.DataFrame, np.ndarray],
                          parameter_dict: Optional[dict] = None,
                          ) -> None:
        padded_iteration = self._get_padded_iterator(iteration)
        checkpoint = {'last_iter': iteration}
        if parameter_dict is not None:
            checkpoint.update(parameter_dict)
        if not isinstance(df, pd.DataFrame):
            df = pd.DataFrame(df)
        df.to_csv(os.path.join(self._checkpoint_path,
                               f'cp_{padded_iteration}.csv.gz'))
        # Export the checkpoint metadata last, so the result file already
        # exists whenever the checkpoint references it.
        with open(os.path.join(self._checkpoint_path, 'checkpoint.json'),
                  'w',
                  encoding='utf8') as f:
            f.write(json.dumps(checkpoint))

    def _check_makedir(self) -> None:
        """Create the checkpoint directory if it does not exist yet."""
        if not os.path.isdir(self._checkpoint_path):
            os.makedirs(self._checkpoint_path, exist_ok=True)

    def _cleanup_checkpoints(self) -> None:
        """Delete the checkpoint directory and all of its contents."""
        if os.path.isdir(self._checkpoint_path):
            shutil.rmtree(self._checkpoint_path)

    # -------------------------------------------------------------------------
    # classmethods:

    @classmethod
    def batch_predict_auto(cls, method):
        @wraps(method)
        def _wrapper(predictor_self, *args, **kwargs):
            # Pop the batching options so they are not forwarded to the
            # wrapped method as unexpected keyword arguments.
            checkpoint_path = kwargs.pop('checkpoint_path', None)
            n_batches = kwargs.pop('n_batches', None)
            n_jobs = kwargs.pop('n_jobs', None)
            do_load_cp = kwargs.pop('do_load_cp', None)
            instance = cls(checkpoint_path=checkpoint_path,
                           n_batches=n_batches,
                           n_jobs=n_jobs,
                           do_load_cp=do_load_cp)
            return instance._batch_predict_func(predictor_self,
                                                method,
                                                args,
                                                kwargs)
        return _wrapper

    @classmethod
    def batch_predict(cls,
                      checkpoint_path: str,
                      n_batches: int = 10,
                      n_jobs: int = 1,
                      do_load_cp: bool = False):
        def decorator(method):
            instance = cls(n_jobs=n_jobs,
                           n_batches=n_batches,
                           checkpoint_path=checkpoint_path,
                           do_load_cp=do_load_cp)
            return instance._batch_predict_self(method)
        return decorator
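

# A minimal usage sketch, kept out of module import via the __main__ guard.
# The Predictor class, heavy_transform function, and 'demo_checkpoints*'
# paths below are hypothetical examples, not part of the original module.
# It shows both entry points: batch_predict_auto reads the batching options
# from the call's keyword arguments, while batch_predict fixes them at
# decoration time. The input must be passed as the keyword argument X so
# the processor can split it into batches.
if __name__ == '__main__':

    class Predictor:
        @BatchProcessor.batch_predict_auto
        def predict(self, X):
            # Stand-in for an expensive per-batch computation.
            return X * 2

    @BatchProcessor.batch_predict(checkpoint_path='demo_checkpoints_fn',
                                  n_batches=4)
    def heavy_transform(X):
        return X + 1

    data = pd.DataFrame({'a': range(100)})
    predictions = Predictor().predict(X=data,
                                      n_batches=5,
                                      checkpoint_path='demo_checkpoints',
                                      do_load_cp=True)
    transformed = heavy_transform(X=data)
    print(predictions.shape, transformed.shape)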