# NativeOp.py (forked from rwth-i6/returnn)
"""
Generic interface which automatically creates:
* CPU and GPU op
* inplace and not inplace
* grad variants
"""
import copy
import os
import sys
import numpy
import theano
import theano.sandbox.cuda
import theano.tensor as T
from theano.compile import optdb
from theano import gof
from theano.gof.opt import OpSub
from Util import make_hashable, make_dll_name, escape_c_str
from TheanoUtil import try_register_gpu_opt, make_var_tuple, softmax
PY3 = sys.version_info[0] >= 3
if PY3:
unicode = str
long = int
class NativeOpBaseMixin(object):
"""
The purpose of having this as a separate base class is to make it independent of any Theano-specific
functionality, so that this base can also be used, e.g., for TensorFlow.
"""
def __init__(self, in_info, out_info,
c_fw_code, c_bw_code=None, c_extra_support_code=None, code_version=None, cpu_support=True,
grad_input_map=None, name=None):
"""
:param list[dict(str)] in_info: each dict describes one input var.
attribs in the dict:
int ndim: the ndim.
tuple shape: tuple and can contain None for specific dimensions.
optional attribs:
str dtype: "float32" by default.
bool need_contiguous: false by default.
int want_inplace: -1 by default. if >= 0, try to optimize such that this input is destroyed (reused) for the output with that index.
"dummy_out" is a special value which will add another output.
bool is_inplace: false by default. whether the optimization was applied.
str gradient: can be "disconnected". see grad().
bool bw_input: True by default. add this param to the bw input.
other attribs are just ignored.
:param list[dict(str)] out_info: like in_info.
slightly different behavior for:
shape: we also allow refs to the in_info in the form (in-idx,dim). see infer_shape().
need_contiguous/want_inplace: used for the bw op, in case bw_input == True.
:param str c_fw_code: C code for forward pass
:param str|dict[str] c_extra_support_code: C support code (for c_support_code)
:param str|None c_bw_code: C code for backward pass (for gradient)
:param tuple[int] code_version: will be returned by c_code_cache_version.
:param bool cpu_support:
:param tuple[int]|callable grad_input_map: selection of grad inputs.
by default, we get all inputs + all outputs + all grad outputs.
:param str name: name
"""
assert isinstance(in_info, (list, tuple))
assert isinstance(out_info, (list, tuple))
in_info, out_info, num_dummy_outs = self._resolve_want_inplace_dummy(in_info, out_info)
self.in_info = make_hashable(in_info)
self.out_info = make_hashable(out_info)
self.num_dummy_outs = num_dummy_outs
self.c_fw_code = c_fw_code
self.c_bw_code = c_bw_code
self.c_extra_support_code = self._reduce_c_extra_support_code(c_extra_support_code)
self.code_version = code_version or ()
self.cpu_support = cpu_support
self.name = name or "<anonNativeOp>"
self.grad_input_map = self._convert_grad_input_map(grad_input_map, len(in_info) + len(out_info) * 2)
self.destroy_map = self._construct_destroy_map(in_info)
@classmethod
def _resolve_want_inplace_dummy(cls, in_info, out_info):
in_info = [dict(info) for info in in_info] # copy each dict, don't modify the originals
out_info = list(out_info) # copying list is enough here
num_dummy_outs = 0
for in_idx, info in enumerate(in_info):
if info.get("want_inplace", None) == "dummy_out":
num_dummy_outs += 1
dummy_out_idx = len(out_info)
dummy_out = {"ndim": info["ndim"],
"shape": [(in_idx, i) for i in range(info["ndim"])],
"dtype": info.get("dtype", "float32"),
"name": "dummy_out_%i" % num_dummy_outs}
out_info += [dummy_out]
info["want_inplace"] = dummy_out_idx
return in_info, out_info, num_dummy_outs
@classmethod
def _reduce_c_extra_support_code(cls, c):
if c is None:
return ""
if isinstance(c, dict):
c = [v for (k, v) in sorted(c.items())]
if isinstance(c, (list, tuple)):
c = "\n".join([v + "\n\n" for v in c])
assert isinstance(c, (str, unicode))
return c
@classmethod
def _construct_destroy_map(cls, in_info):
destroy_map = {}
for in_idx, info in enumerate(in_info):
want_inplace = info.get("want_inplace", -1)
assert isinstance(want_inplace, (int, long))
if want_inplace >= 0 and info.get("is_inplace", False):
out_idx = want_inplace
# http://deeplearning.net/software/theano/extending/inplace.html
# https://github.com/Theano/Theano/issues/3506
# It's strange that we must mark which output operates on which input -
# I would expect that it must only know which inputs are destroyed.
assert out_idx not in destroy_map, "Theano cannot handle that yet"
destroy_map[out_idx] = [in_idx]
return destroy_map
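# E.g. if input 0 has want_inplace=1 and the inplace optimization was applied (is_inplace=True),
# destroy_map becomes {1: [0]}, i.e. output 1 overwrites input 0.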
@classmethod
def _convert_grad_input_map(cls, gi_map, num_params):
"""
:param gi_map: see grad_input_map argument for self.__init__
:param int num_params:
:return: tuple of int
:rtype: tuple[int]
"""
if gi_map is None:
gi_map = tuple(range(num_params))
if callable(gi_map):
gi_map = gi_map(*range(num_params))
if isinstance(gi_map, list):
gi_map = tuple(gi_map)
assert isinstance(gi_map, tuple)
return gi_map
def _filter_grad_inputs(self, inputs):
"""
:param list[T] inputs: inputs + outputs + output_grads. can be either symbolic tensors or info dicts
:return: filtered list, via self.grad_input_map
:rtype: list[T]
"""
assert len(inputs) == len(self.in_info) + len(self.out_info) * 2
return [inputs[i] for i in self.grad_input_map]
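# E.g. for LstmGenericBase below, the unfiltered list is
#   inputs (Z, V_h, c, i) + outputs (Y, H, d) + output grads (DY, DH, Dd),
# and its grad_input_map selects (V_h, c, i, Y, H, DY, Dd) as the inputs of the grad op.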
def infer_shape(self, node, input_shapes):
assert len(input_shapes) == len(self.in_info)
out_shapes = []
for info in self.out_info:
out_shape = list(info["shape"])
for idx, s in enumerate(out_shape):
if isinstance(s, tuple): # we interpret this as a reference to input shapes
assert len(s) == 2, "dim %r invalid in info %r" % (s, info)
assert 0 <= s[0] < len(input_shapes), "dim %r invalid in info %r" % (s, info)
assert 0 <= s[1] < self.in_info[s[0]]["ndim"], "dim idx %r invalid in input %i %r, info %r" % (s[1], s[0], self.in_info[s[0]], info)
out_shape[idx] = input_shapes[s[0]][s[1]]
assert not any([s is None for s in out_shape]), "out_shape %r, out_info %r" % (out_shape, self.out_info)
out_shapes += [tuple(out_shape)]
return out_shapes
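# E.g. with out_info shape ((0, 0), (0, 1), (1, 0)) and input_shapes [(T, B, 4 * D), (D, 4 * D)],
# the inferred output shape is (T, B, D).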
@classmethod
def _bw_in_var_info(cls, info):
"""
:param dict[str] info:
:return: updated info dict for the gradient (bwd) as input
:rtype: dict[str]
"""
if "bw_in_var" in info:
info = dict(info)
info.update(info.pop("bw_in_var"))
return info
@classmethod
def _bw_grad_var_info(cls, info):
"""
:param dict[str] info: backward gradient input for one of our outputs
:return: updated info dict for the gradient (bwd) as input
:rtype: dict[str]
"""
info = dict(info)
if "bw_grad_var" in info:
info.update(info.pop("bw_grad_var"))
if "name" in info:
info["name"] = "D_" + info["name"]
return info
def kwargs_for_grad_op(self):
"""
:returns: the kwargs for creating a NativeOp for the gradient op. e.g. includes in_info, out_info, etc
:rtype: dict[str]
Note: The inputs of the gradient are by default: fwd_op.inputs + fwd_op.outputs + output_grads.
We filter them via self._filter_grad_inputs.
"""
# Inputs: inputs + outputs + output_grads, where outputs = op(inputs),
# i.e. we might reuse some of the calculation.
in_info = [self._bw_in_var_info(info) for info in self.in_info]
in_info += [self._bw_in_var_info(info) for info in self.out_info]
in_info += [self._bw_grad_var_info(info) for info in self.out_info]
in_info = self._filter_grad_inputs(in_info)
in_idx_rev = {v: k for (k, v) in enumerate(self.grad_input_map)}
# Outputs: All like original inputs. Filter out the disconnected.
out_info = [info.copy() for info in self.in_info]
for idx, info in enumerate(out_info):
info.pop("shape")
if "bw_out_var" in info:
info.update(info["bw_out_var"])
if "shape" not in info:
# Refer to input shapes. See infer_shape().
info["shape"] = [(in_idx_rev[idx], i) for i in range(info["ndim"])]
out_info = [info for info in out_info if info.get("gradient", "") != "disconnected"]
return dict(
name="GradOf%s" % self.name,
in_info=in_info,
out_info=out_info,
c_fw_code=self.c_bw_code,
c_extra_support_code=self.c_extra_support_code,
code_version=self.code_version,
cpu_support=self.cpu_support
)
def make_results_of_gradient(self, grad_op_outputs, disconnected_type=None):
"""
:param list[T] grad_op_outputs: this is already with dummy outputs removed
:param S disconnected_type:
:return: gradient for each input of our op
:rtype: list[T|S]
"""
if disconnected_type is None:
disconnected_type = lambda: None
grad_op_outputs = list(grad_op_outputs)
results = []
for info in self.in_info:
if info.get("gradient", "") == "disconnected":
results += [disconnected_type()]
else:
results += grad_op_outputs[:1]
grad_op_outputs = grad_op_outputs[1:]
assert len(grad_op_outputs) == 0
assert len(results) == len(self.in_info)
return results
class NativeOp(theano.Op, NativeOpBaseMixin):
"""
We wrap some C code which can define a forward pass
and optionally a backward pass (for gradient calculation).
The C code should be Numpy and CUDA compatible. See NativeOp.cpp.
We also support inplace operations, i.e. we can operate inplace on some inputs.
You can define in a flexible way all the inputs and the outputs.
See __init__() for the details.
All output variables are created automatically with the right shape
but their content is not initialized,
except when it is used by some input variable as the inplace output
- in that case, it is either the input variable or it has a copy of its data.
"""
__props__ = ("in_info", "out_info",
"c_fw_code", "c_bw_code", "c_extra_support_code", "code_version",
"grad_input_map", "name",
"custom_grad")
def __init__(self, custom_grad=None, **kwargs):
"""
:param function custom_grad: if given, will use this instead for self.grad
:param dict[str] kwargs: all passed to NativeOpBaseMixin
"""
theano.Op.__init__(self)
NativeOpBaseMixin.__init__(self, **kwargs)
self.custom_grad = custom_grad
def __str__(self):
return "%s{%s,%s}" % (
self.__class__.__name__,
self.name,
"inplace" if self.destroy_map else "no_inplace")
@classmethod
def as_tensor_var(cls, v):
return theano.tensor.as_tensor_variable(v)
@classmethod
def tensor_type(cls, dtype, ndim):
return T.TensorType(dtype=dtype, broadcastable=(False,) * ndim)
@classmethod
def contiguous(cls, v):
from TheanoUtil import Contiguous
assert isinstance(v, theano.Variable)
if getattr(v, 'owner', None):
assert isinstance(v.owner, theano.Apply)
if isinstance(v.owner.op, Contiguous.__base__):
return v
return Contiguous()(v)
def _convert_input_var(self, v, info):
v = self.as_tensor_var(v)
dtype = "float32" # Theano on GPU only supports float32 ... # info.get("dtype", "float32")
if v.dtype != dtype:
v = T.cast(v, dtype)
if v.ndim != info["ndim"]:
raise TypeError("input var ndim %i does not match with info %r" % (v.ndim, info))
if info.get("need_contiguous", False):
v = self.contiguous(v)
return v
def grad(self, inputs, output_grads):
if self.custom_grad:
return self.custom_grad(self, inputs, output_grads)
if not self.c_bw_code:
# Unknown how to calculate gradient.
return [T.DisconnectedType()() for inp in inputs]
assert len(self.in_info) == len(inputs)
assert len(self.out_info) == len(output_grads)
# Some of output_grads might be of disconnected type.
out_shapes = self.infer_shape(None, [v.shape for v in inputs])
assert len(out_shapes) == len(output_grads)
for i, out_grad in enumerate(output_grads):
if isinstance(out_grad.type, T.DisconnectedType):
output_grads[i] = T.zeros(out_shapes[i], dtype="float32")
kwargs_for_grad = self.kwargs_for_grad_op()
grad_op = self.__class__(**kwargs_for_grad)
grad_inputs = inputs + list(make_var_tuple(self(*inputs))) + output_grads
grad_inputs = self._filter_grad_inputs(grad_inputs)
assert len(grad_op.in_info) == len(grad_inputs)
grad_outputs = make_var_tuple(grad_op(*grad_inputs))
assert len(grad_op.out_info) == len(grad_outputs)
if grad_op.num_dummy_outs > 0:
grad_outputs = grad_outputs[:-grad_op.num_dummy_outs] # remove any dummy outputs
def print_fn(op, x):
import numpy
first = x[(0,) * x.ndim]
stats = (first, x.shape, numpy.min(x), numpy.max(x), numpy.mean(x), numpy.std(x),
numpy.isinf(x).any(), numpy.isnan(x).any())
print(op.message, "first/shape/min/max/mean/std/any-inf/any-nan:", stats)
#input_grads = [theano.printing.Print("in grad %i" % i, global_fn=print_fn)(v)
# for (i, v) in enumerate(input_grads)]
return self.make_results_of_gradient(grad_outputs, disconnected_type=T.DisconnectedType())
def connection_pattern(self, node):
assert len(node.inputs) == len(self.in_info)
pattern = [[info.get("gradient", "") != "disconnected"] * len(self.out_info)
for info in self.in_info]
return pattern
def make_node(self, *args):
assert len(args) == len(self.in_info)
args = [self._convert_input_var(arg, info) for arg, info in zip(args, self.in_info)]
outputs = [self.tensor_type(dtype=info.get("dtype", "float32"), ndim=info["ndim"])()
for info in self.out_info]
return theano.Apply(self, args, outputs)
def perform(self, node, inputs, output_storage):
raise NotImplementedError("NativeOp: no pure Python implementation, only C implementation")
def c_code_cache_version(self):
return self.code_version
def c_support_code(self):
base_src = open(os.path.dirname(__file__) + "/NativeOp.cpp").read()
return "\n\n".join([
T.blas.blas_header_text(),
"#define CUDA 0",
base_src,
self.c_extra_support_code])
def c_libraries(self):
return T.blas.ldflags()
def c_compile_args(self):
return T.blas.ldflags(libs=False, flags=True)
def c_lib_dirs(self):
return T.blas.ldflags(libs=False, libs_dir=True)
def c_header_dirs(self):
return T.blas.ldflags(libs=False, include_dir=True)
def c_code(self, node, name, inputs, outputs, sub):
assert len(inputs) == len(self.in_info)
assert len(outputs) == len(self.out_info)
return """
{
int n_inputs = %(n_inputs)i, n_outputs = %(n_outputs)i;
Ndarray* inputs[] = {%(input_var_names_str)s};
Ndarray** outputs[] = {%(output_var_names_str)s};
int in_ndims[] = {%(input_ndims_str)s};
int out_ndims[] = {%(output_ndims_str)s};
Ndarray_DIM_Type output_shapes_flat[] = {%(output_shapes_flat_str)s};
int in_want_inplace[] = {%(input_want_inplace_str)s};
bool in_is_inplace[] = {%(input_is_inplace_str)s};
// Check if we can reuse any preallocated output.
// Reset those which we cannot reuse.
{
int out_shape_idx = 0;
for(int i = 0; i < n_outputs; ++i) {
assert_cmp(out_shape_idx + out_ndims[i], <=, ARRAY_LEN(output_shapes_flat));
if(*outputs[i]) {
bool can_reuse = true;
for(int j = 0; j < out_ndims[i]; ++j)
if(output_shapes_flat[out_shape_idx + j] != Ndarray_DIMS(*outputs[i])[j]) {
can_reuse = false;
break;
}
if(!can_reuse)
Py_CLEAR(*outputs[i]);
}
out_shape_idx += out_ndims[i];
}
assert_cmp(out_shape_idx, ==, ARRAY_LEN(output_shapes_flat));
}
// Maybe reuse or otherwise copy input into output vars.
for(int i = 0; i < n_inputs; ++i)
if(in_want_inplace[i] >= 0) {
assert_cmp(in_want_inplace[i], <, n_outputs);
Py_XDECREF(*outputs[in_want_inplace[i]]);
if(in_is_inplace[i]) {
*(outputs[in_want_inplace[i]]) = inputs[i];
Py_INCREF(inputs[i]);
} else {
*(outputs[in_want_inplace[i]]) = (Ndarray*) Ndarray_Copy(inputs[i]);
if(!*(outputs[in_want_inplace[i]])) %(fail)s;
inputs[i] = *(outputs[in_want_inplace[i]]); // reset with copy
}
}
// Init the remaining output vars. Note that their content is not initialized (arbitrary memory)!
{
int out_shape_idx = 0;
for(int i = 0; i < n_outputs; ++i) {
assert(out_shape_idx + out_ndims[i] <= ARRAY_LEN(output_shapes_flat));
if(*(outputs[i])) {
for(int j = 0; j < out_ndims[i]; ++j)
// If this fails, we maybe have reused an input which has an invalid shape.
assert_cmp(output_shapes_flat[out_shape_idx + j], ==, Ndarray_DIMS(*outputs[i])[j]);
}
else {
*(outputs[i]) = (Ndarray*) Ndarray_NewDims(out_ndims[i], &output_shapes_flat[out_shape_idx]);
if(!*(outputs[i])) %(fail)s;
}
out_shape_idx += out_ndims[i];
}
assert_cmp(out_shape_idx, ==, ARRAY_LEN(output_shapes_flat));
}
// And the user C code starts here.
// --------------------------------
%(c_code)s;
}
""" % {
'name': name, 'fail': sub['fail'],
'op_name': escape_c_str(self.name),
'c_code': self.c_fw_code % {'fail': sub['fail']},
'n_inputs': len(inputs), 'n_outputs': len(outputs),
'input_var_names_str': ", ".join(["%s" % inp for inp in inputs]),
'output_var_names_str': ", ".join(["&%s" % out for out in outputs]),
'input_ndims_str': ', '.join(["%i" % info["ndim"] for info in self.in_info]),
'output_ndims_str': ', '.join(["%i" % info["ndim"] for info in self.out_info]),
'output_shapes_flat_str':
', '.join([(("%i" % s) if isinstance(s, (int, long))
else "Ndarray_DIMS(inputs[%i])[%i]" % s)
for info in self.out_info for s in info["shape"]]),
"input_want_inplace_str": ", ".join([str(int(info.get("want_inplace", -1)))
for info in self.in_info]),
"input_is_inplace_str": ", ".join([str(int(info.get("is_inplace", False)))
for info in self.in_info])
}
class GpuNativeOp(NativeOp, theano.sandbox.cuda.GpuOp):
@classmethod
def as_tensor_var(cls, v):
from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable
return as_cuda_ndarray_variable(v)
@classmethod
def tensor_type(cls, dtype, ndim):
from theano.sandbox.cuda import CudaNdarrayType
if dtype != "float32":
print("%s: WARNING: cannot handle type %r, will use float32 instead" % ("GpuNativeOp", dtype))
dtype = "float32"
return CudaNdarrayType(dtype=dtype, broadcastable=(False,) * ndim)
@classmethod
def contiguous(cls, v):
from theano.sandbox.cuda.basic_ops import gpu_contiguous
assert isinstance(v, (theano.sandbox.cuda.CudaNdarrayVariable, theano.sandbox.cuda.CudaNdarrayConstant))
if getattr(v, 'owner', None):
assert isinstance(v.owner, theano.Apply)
if v.owner.op == gpu_contiguous:
return v
return gpu_contiguous(v)
def c_support_code(self):
src = open(os.path.dirname(__file__) + "/NativeOp.cpp").read()
return "\n\n".join([
"#define CUDA 1",
src,
self.c_extra_support_code,
"// end of c_support_code\n\n\n"])
@gof.local_optimizer([NativeOp], inplace=True)
def inplace_NativeOp(node):
if isinstance(node.op, NativeOp) and not node.op.destroy_map:
kwargs = {k: getattr(node.op, k) for k in node.op.__props__}
# TODO: We could try to make each input inplace individually.
# What we do now is just to try to make all inplace.
kwargs["in_info"] = [dict(info) for info in node.op.in_info]
any_inplace = False
for info in kwargs["in_info"]:
if info.get("want_inplace", -1) >= 0:
any_inplace = True
info["is_inplace"] = True
if not any_inplace:
return False
new_op = node.op.__class__(**kwargs)
from TheanoUtil import make_var_tuple
new_v = make_var_tuple(new_op(*node.inputs))
return new_v
return False
try:
optdb.register('inplace_NativeOp',
gof.TopoOptimizer(inplace_NativeOp
, failure_callback=gof.TopoOptimizer.warn_inplace
),
60, 'fast_run', 'inplace')
except ValueError: # can happen if it was already registered before, e.g. when we reload the module
pass
@try_register_gpu_opt(NativeOp)
def local_gpu_NativeOp(node):
if isinstance(node.op, NativeOp):
# see also: https://github.com/Theano/Theano/blob/master/theano/sandbox/cuda/opt.py
from theano.sandbox.cuda import host_from_gpu, gpu_from_host, as_cuda_ndarray_variable
args = node.inputs
if any([(x.owner and x.owner.op == host_from_gpu) for x in args]):
gpu_op = GpuNativeOp(**{key: getattr(node.op, key) for key in node.op.__props__})
args = [x.owner.inputs[0] if (x.owner and x.owner.op == host_from_gpu) else x
for x in args]
from TheanoUtil import make_var_tuple
outputs = make_var_tuple(gpu_op(*args))
return [host_from_gpu(out) for out in outputs]
class NativeOpGenBase:
"""
Base interface for op generation.
See NativeOp.__init__() for attribs.
"""
in_info = None # type: tuple[dict[str]]
out_info = None # type: tuple[dict[str]]
c_fw_code = None # type: str
c_bw_code = None # type: str
c_extra_support_code = None # type: dict[str,str]
code_version = None # type: tuple[int]|int
grad_input_map = None
custom_grad = None
cpu_support = True
def make_op(self):
assert self.in_info is not None
assert self.out_info is not None
assert self.c_fw_code is not None
return NativeOp(in_info=self.in_info, out_info=self.out_info,
c_fw_code=self.c_fw_code, c_bw_code=self.c_bw_code,
c_extra_support_code=self.c_extra_support_code,
grad_input_map=self.grad_input_map,
name=self.__class__.__name__,
custom_grad=self.custom_grad)
@classmethod
def map_layer_inputs_to_op(cls, *inputs):
return inputs
@classmethod
def map_layer_output_from_op(cls, *outputs):
return outputs[0]
class LstmGenericBase(NativeOpGenBase):
"""
inputs:
:param Z: {input,output,forget} gate + cell state. 3d (time,batch,dim*4)
:param V_h: recurrent matrix. 2d (dim,dim*4)
:param c: initial cell state. 2d (batch,dim)
:param i: index. 2d (time,batch) -> 0 or 1
outputs:
:param Y: output. 3d (time,batch,dim)
:param H: gates and cell state. 3d (time,batch,dim*4)
:param d: final cell state. 2d (batch,dim)
"""
in_info = (
{"name": "Z", "ndim": 3, "shape": (None, None, None), "need_contiguous": True,
"want_inplace": 1,
"bw_out_var": {"shape": ((2, 0), (2, 1), (0, 1))}}, # see grad_input_map() for indices
{"name": "V_h", "ndim": 2, "shape": (None, None), "need_contiguous": True},
{"name": "c", "ndim": 2, "shape": (None, None), "need_contiguous": True},
{"name": "i", "ndim": 2, "shape": (None, None), "need_contiguous": True,
"gradient": "disconnected"}
)
out_info = (
{"name": "Y", "ndim": 3, "shape": ((0, 0), (0, 1), (1, 0)), "need_contiguous": True,
"bw_grad_var": {"want_inplace": "dummy_out"}},
{"name": "H", "ndim": 3, "shape": ((0, 0), (0, 1), (0, 2)), "need_contiguous": True,
"bw_in_var": {"want_inplace": 0}},
{"name": "d", "ndim": 2, "shape": ((2, 0), (2, 1)), "need_contiguous": True}
)
@classmethod
def grad_input_map(cls, Z, V_h, c, i, Y, H, d, DY, DH, Dd):
return (V_h, c, i, Y, H, DY, Dd)
@classmethod
def map_layer_inputs_to_op(cls, Z, V_h, i):
assert Z.ndim == 3
assert V_h.ndim == 2
assert i.ndim == 2
n_batch = Z.shape[1]
n_out = V_h.shape[0]
c = T.zeros((n_batch, n_out), dtype="float32")
return Z, V_h, c, i
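# Rough usage sketch (illustrative), given symbolic Z (time,batch,dim*4), V_h (dim,dim*4), i (time,batch):
#   op = LstmGenericBase().make_op()
#   Y, H, d = op(*LstmGenericBase.map_layer_inputs_to_op(Z, V_h, i))
#   out = LstmGenericBase.map_layer_output_from_op(Y, H, d)  # -> Y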
c_extra_support_code = {
"lstm_kernel": """
DEF_KERNEL
void lstm_kernel(float* data, const float* old_state, bool old_state_strided,
float* output, float* state_out, int n_cells, int n_batch, const float* i) {
//layout:
//data[0*n_cells..1*n_cells-1] : cell state
//data[1*n_cells..2*n_cells-1] : input gate
//data[2*n_cells..3*n_cells-1] : forget gate
//data[3*n_cells..4*n_cells-1] : output gate
//output[0*n_cells..1*n_cells-1]: cell output
//repeated for every mini-batch
int idx = threadIdx.x + blockDim.x * blockIdx.x;
while (idx < n_cells * n_batch) {
int batch_idx = idx / n_cells;
int start = batch_idx * 4 * n_cells + idx % n_cells;
float i_batch = i[batch_idx];
//input, forget and output gates
float inpGate = 1.f / (1.f + expf(-data[start + n_cells]));
float fgtGate = 1.f / (1.f + expf(-data[start + 2 * n_cells]));
float outGate = 1.f / (1.f + expf(-data[start + 3 * n_cells]));
float state = inpGate * tanhf(data[start]);
float old_state_batch = old_state_strided ? old_state[start] : old_state[idx];
state += fgtGate * old_state_batch;
state = state * i_batch + old_state_batch * (1.f - i_batch);
//cell output
output[idx] = outGate * tanhf(state) * i_batch;
data[start] = state;
data[start + n_cells] = inpGate;
data[start + 2 * n_cells] = fgtGate;
data[start + 3 * n_cells] = outGate;
if(state_out)
state_out[idx] = state;
idx += gridDim.x * blockDim.x;
}
}
""",
"lstm_bwd_kernel": """
DEF_KERNEL
void lstm_bwd_kernel(
float* delta, float* epsilon, const float* next_epsilon, const float* old_state,
bool old_state_strided, const float* Y, int n_cells, int n_batch, const float* i) {
//layout:
//delta[0*n_cells..1*n_cells-1] : input gate
//delta[1*n_cells..2*n_cells-1] : forget gate
//delta[2*n_cells..3*n_cells-1] : output gate
//delta[3*n_cells..4*n_cells-1] : cell state
//epsilon[0*n_cells..1*n_cells-1]: cell output derivative (later overwritten, see below)
//next_epsilon[0*n_cells..1*n_cells-1]: cell state derivative * forget_gate (of next timestep)
//repeated for every mini-batch
int idx = threadIdx.x + blockDim.x * blockIdx.x;
while (idx < n_cells * n_batch) {
int batch_idx = idx / n_cells;
int batch_offset = batch_idx * 4 * n_cells;
int cell_offset = idx % n_cells;
int start = batch_offset + cell_offset;
float i_batch = i[batch_idx];
float inpGate = delta[start + n_cells];
float fgtGate = delta[start + 2 * n_cells];
float outGate = delta[start + 3 * n_cells];
float oldState = old_state_strided ? old_state[start] : old_state[idx];
float state = delta[start];
float eps = epsilon[idx];
//avoid division by 0
float gc = tanhf(state); //g(c(t))
float gzc = (state - fgtGate * oldState) / fmaxf(inpGate, float(1e-16)); //g(z_c(t))
//delta_output
delta[start + 3 * n_cells] = outGate * (1.f - outGate) * gc * eps * i_batch;
//epsilon_c
float epsilon_c = (1.f - (gc * gc)) * outGate * eps;
epsilon_c += next_epsilon[idx];
epsilon[idx] = epsilon_c * fgtGate * i_batch + next_epsilon[idx] * (1.f - i_batch);
//delta_cell
delta[start] = inpGate * (1.f - (gzc * gzc)) * epsilon_c * i_batch;
//delta_forget
delta[start + 2 * n_cells] = fgtGate * (1.f - fgtGate) * oldState * epsilon_c * i_batch;
//delta_input
delta[start + n_cells] = inpGate * (1.f - inpGate) * gzc * epsilon_c * i_batch;
idx += gridDim.x * blockDim.x;
}
}
"""
}
c_fw_code = """
// Z*, V_h, c, i = input_names (*: inplace)
// Y, H, d = output_names
assert(n_inputs == 4);
assert(n_outputs == 3);
Ndarray* V_h = inputs[1];
Ndarray* c = inputs[2];
Ndarray* i = inputs[3];
Ndarray* Y = *outputs[0];
Ndarray* H = *outputs[1]; // inplace on Z
Ndarray* d = *outputs[2];
long T = Ndarray_DIMS(i)[0];
int n_batch = Ndarray_DIMS(i)[1];
assert(Ndarray_DIMS(H)[2] %% 4 == 0); // 3 gates + cell
int n_cells = Ndarray_DIMS(H)[2] / 4;
assert(T > 0);
for(int x = 0; x < T; ++x) {
if(x > 0) {
//H += Y[x-1]*V_h
affine_y_x(x-1, Y, x, V_h, x, H);
}
start_dev_kernel(lstm_kernel, (
data_ptr(H, x),
x > 0 ? data_ptr(H, x - 1) : Ndarray_DEV_DATA(c),
x > 0,
data_ptr(Y, x),
(x == T - 1) ? Ndarray_DEV_DATA(d) : 0,
n_cells,
n_batch,
Ndarray_DEV_DATA(i) + x * n_batch
));
}
"""
c_bw_code = """
// V_h, c, i, Y, H*, DY*, Dd = input_names (*: inplace)
// DZ, DV_h, Dc, tmpDc = output_names
assert(n_inputs == 7);
assert(n_outputs == 4);
Ndarray* V_h = inputs[0];
Ndarray* c = inputs[1];
Ndarray* i = inputs[2];
Ndarray* Y = inputs[3];
Ndarray* Dd = inputs[6];
Ndarray* DZ = *outputs[0]; // inplace on H
Ndarray* DV_h = *outputs[1];
Ndarray* Dc = *outputs[2];
Ndarray* tmpDc = *outputs[3]; // (old DY), inplace buffer
long T = Ndarray_DIMS(i)[0];
int n_batch = Ndarray_DIMS(i)[1];
assert(Ndarray_DIMS(DZ)[2] %% 4 == 0); // 3 gates + cell
int n_cells = Ndarray_DIMS(DZ)[2] / 4;
assert(T > 0);
for(int x = T - 1; x >= 0; --x) {
// add recurrent
bool rightBorder = (x == T - 1);
if(!rightBorder)
affine_y_x(x+1, DZ, x, V_h, x, tmpDc, false, true);
start_dev_kernel(lstm_bwd_kernel, (
data_ptr(DZ, x),
data_ptr(tmpDc, x),
rightBorder ? Ndarray_DEV_DATA(Dd) : data_ptr(tmpDc, x + 1),
x > 0 ? data_ptr(DZ, x - 1) : Ndarray_DEV_DATA(c),
x > 0,
data_ptr(Y, x),
n_cells,
n_batch,
Ndarray_DEV_DATA(i) + x * n_batch
));
}
//DV_h = Y[0..end-1]^T * DZ[1..end]
affine_global(Y, DZ, DV_h, true, false, 1, 0.0f);
Ndarray_DIMS_Type Dc_dim = Ndarray_HOST_DIMS(Dc);
Ndarray_memcpy(
Ndarray_DEV_DATA(Dc), Ndarray_DEV_DATA(tmpDc),
Dc_dim[0] * Dc_dim[1] * sizeof(float));
"""
code_version = ()
class LstmLowMem(NativeOpGenBase):
"""
This is designed to require minimal memory during training.
It only stores the outputs and the cell states,
i.e. it requires time * batch * cells * 2 floats of memory in total.
inputs:
:param X: (time,batch,in_dim)
:param W: forward+recurrent matrix. 2d (in_dim+dim,dim*4)
:param b: bias. 1d (dim*4,)
:param y0: initial output|hidden state. 2d (batch,dim)
:param c0: initial cell state. 2d (batch,dim)
:param i: index. 2d (time,batch) -> 0 or 1
:param start: where to start. must be >=0, default is usually 0. dtype int, scalar.
:param step: +1 for fwd, -1 for bwd direction. can also be |step|>1 for wider steps. dtype int, scalar.
for bwd (<0), will start at T-start-1.
outputs:
:param Y: output. 3d (time,batch,dim)
:param C: cell states. 3d (time,batch,dim). gradient ignored!
:param d: final cell state. 2d (batch,dim)
"""
in_info = (
{"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True},
{"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True},
{"name": "b", "ndim": 1, "shape": (None,), "need_contiguous": True},
{"name": "y0", "ndim": 2, "shape": (None, None), "need_contiguous": True},
{"name": "c0", "ndim": 2, "shape": (None, None), "need_contiguous": True},
{"name": "i", "ndim": 2, "shape": (None, None), "need_contiguous": True, "gradient": "disconnected"},
{"name": "start", "ndim": 0, "shape": (), "gradient": "disconnected", "dtype": "int32", "host_memory": True},
{"name": "step", "ndim": 0, "shape": (), "gradient": "disconnected", "dtype": "int32", "host_memory": True},
)
out_info = (
{"name": "Y", "ndim": 3, "shape": ((0, 0), (0, 1), (4, 1)), "need_contiguous": True},
{"name": "C", "ndim": 3, "shape": ((0, 0), (0, 1), (4, 1)), "need_contiguous": True},
{"name": "d", "ndim": 2, "shape": ((0, 1), (4, 1)), "need_contiguous": True}
)
@classmethod
def grad_input_map(cls, X, W, b, y0, c0, i, start, step, Y, C, d, DY, DC, Dd):
return (X, W, b, y0, c0, i, start, step, Y, C, DY, Dd)
c_extra_support_code = {
"lstm_kernel": """
DEF_KERNEL
void lstm_kernel(
int n_batch, int n_cells, const float* mask,
float* intern,
float* prev_c,
float* y,
float* c)
{
int idx = threadIdx.x + blockDim.x * blockIdx.x;
while (idx < n_cells * n_batch) {
int batch_idx = idx / n_cells;
int cell_idx = idx % n_cells;
int intern_offset = batch_idx * 4 * n_cells + cell_idx;
float prev_c_b = prev_c[idx];
float mask_b = mask[batch_idx];
// cell-in + input, forget and output gates
float cellIn = tanhf(intern[intern_offset]);
float inpGate = 1.f / (1.f + expf(-intern[intern_offset + n_cells]));
float fgtGate = 1.f / (1.f + expf(-intern[intern_offset + 2 * n_cells]));
float outGate = 1.f / (1.f + expf(-intern[intern_offset + 3 * n_cells]));
float c_b = (prev_c_b * fgtGate + cellIn * inpGate) * mask_b
+ prev_c_b * (1.f - mask_b);
c[idx] = c_b;
y[idx] = tanhf(c_b) * outGate * mask_b;
idx += gridDim.x * blockDim.x;
}
}
""",
"lstm_bwd_kernel": """
DEF_KERNEL
void lstm_bwd_kernel(
int n_batch, int n_in, int n_cells, const float* mask,
float* x_h,
float* intern,
float* prev_c,
float* y,
float* c,
float* d_y,
float* d_h,
float* d_c,
float* d_intern,
float* d_b)
{
int idx = threadIdx.x + blockDim.x * blockIdx.x;
while (idx < n_cells * n_batch) {
int batch_idx = idx / n_cells;
int cell_idx = idx % n_cells;
int intern_offset = batch_idx * 4 * n_cells + cell_idx;
float mask_b = mask[batch_idx];
float d_y_b = d_y[idx] * mask_b + d_h[idx];
float d_c_b = d_c[idx] * mask_b;
float prev_c_b = prev_c[idx];
// cell-in + input, forget and output gates
float cellIn = tanhf(intern[intern_offset]);
float inpGate = 1.f / (1.f + expf(-intern[intern_offset + n_cells]));
float fgtGate = 1.f / (1.f + expf(-intern[intern_offset + 2 * n_cells]));
float outGate = 1.f / (1.f + expf(-intern[intern_offset + 3 * n_cells]));
float c_b = prev_c_b * fgtGate + cellIn * inpGate;
float gc = tanhf(c_b);
float d_outGate_in = (1.f - outGate) * outGate * gc * d_y_b;
float d_c2 = d_c_b + outGate * d_y_b * (1.f - gc * gc);
float d_cellIn_in = (1.f - cellIn * cellIn) * inpGate * d_c2;
float d_inpGate_in = (1.f - inpGate) * inpGate * cellIn * d_c2;
float d_fgtGate_in = (1.f - fgtGate) * fgtGate * prev_c_b * d_c2;
d_c[idx] = fgtGate * d_c2 + d_c[idx] * (1.f - mask_b);
d_intern[intern_offset] = d_cellIn_in;
d_intern[intern_offset + n_cells] = d_inpGate_in;
d_intern[intern_offset + 2 * n_cells] = d_fgtGate_in;
d_intern[intern_offset + 3 * n_cells] = d_outGate_in;
elem_atomic_add(&d_b[cell_idx], d_cellIn_in);
elem_atomic_add(&d_b[cell_idx + n_cells], d_inpGate_in);
elem_atomic_add(&d_b[cell_idx + 2 * n_cells], d_fgtGate_in);
elem_atomic_add(&d_b[cell_idx + 3 * n_cells], d_outGate_in);
idx += gridDim.x * blockDim.x;
}
}
""",
"add_bias_kernel": """
DEF_KERNEL
void add_bias_kernel(int n_batch, int n_dim, float* x, float* b) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
while (idx < n_batch * n_dim) {
int dim_idx = idx % n_dim;
x[idx] += b[dim_idx];
idx += gridDim.x * blockDim.x;
}
}
""",
"copy_x_h_kernel": """
DEF_KERNEL
void copy_x_h_kernel(
int n_batch, int n_in, int n_cells,
float* x_h,
float* x,
float* h)
{
int n_total_in = n_in + n_cells;
int idx = threadIdx.x + blockDim.x * blockIdx.x;
while (idx < n_batch * n_total_in) {
int batch_idx = idx / n_total_in;
int in_dim_idx = idx % n_total_in;
if(in_dim_idx < n_in)
x_h[idx] = x[batch_idx * n_in + in_dim_idx];
else
x_h[idx] = h[batch_idx * n_cells + in_dim_idx - n_in];
idx += gridDim.x * blockDim.x;
}
}
""",
"inv_copy_x_h_kernel": """
DEF_KERNEL
void inv_copy_x_h_kernel(
int n_batch, int n_in, int n_cells,
float* x_h,
float* x,