-
Notifications
You must be signed in to change notification settings - Fork 40
/
DJCLayers.py
163 lines (104 loc) · 4.57 KB
/
DJCLayers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Registry of the custom layer classes defined in this module, keyed by the
# class name string. Presumably handed to Keras as ``custom_objects`` when
# (de)serializing models -- confirm against callers.
djc_global_layers_list = dict()
from keras.layers import Layer
import tensorflow as tf
class StopGradient(Layer):
    """Identity layer that blocks back-propagation through its input.

    The forward pass returns ``tf.stop_gradient(inputs)``: values pass
    through unchanged, but no gradient flows back to upstream layers.
    """

    def __init__(self, **kwargs):
        super(StopGradient, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        # Element-wise pass-through: the shape is unchanged.
        return input_shape

    def call(self, inputs):
        blocked = tf.stop_gradient(inputs)
        return blocked


djc_global_layers_list['StopGradient'] = StopGradient
class SelectFeatures(Layer):
    """Slice the last axis of the input to ``[index_left:index_right]``.

    Args:
        index_left: start index of the slice along the last axis.
        index_right: stop index (exclusive) of the slice along the last axis.
    """

    def __init__(self, index_left, index_right, **kwargs):
        super(SelectFeatures, self).__init__(**kwargs)
        self.index_left = index_left
        self.index_right = index_right

    def compute_output_shape(self, input_shape):
        # NOTE(review): the width is taken as index_right - index_left, which
        # matches the real slice only when both indices have the same sign
        # and lie inside the last dimension; confirm callers respect this.
        n_selected = self.index_right - self.index_left
        return input_shape[:-1] + (n_selected,)

    def call(self, inputs):
        lo, hi = self.index_left, self.index_right
        return inputs[..., lo:hi]

    def get_config(self):
        # Base config first; this layer's own arguments override on clash.
        merged = dict(super(SelectFeatures, self).get_config())
        merged.update({'index_left': self.index_left,
                       'index_right': self.index_right})
        return merged


djc_global_layers_list['SelectFeatures'] = SelectFeatures
class ScalarMultiply(Layer):
    """Multiply the input element-wise by a constant ``factor``.

    Args:
        factor: the multiplier applied as ``inputs * factor``.
    """

    def __init__(self, factor, **kwargs):
        super(ScalarMultiply, self).__init__(**kwargs)
        self.factor = factor

    def compute_output_shape(self, input_shape):
        # Element-wise scaling leaves the shape unchanged.
        return input_shape

    def call(self, inputs):
        scaled = inputs * self.factor
        return scaled

    def get_config(self):
        merged = dict(super(ScalarMultiply, self).get_config())
        merged.update({'factor': self.factor})
        return merged


djc_global_layers_list['ScalarMultiply'] = ScalarMultiply
class Print(Layer):
    """Pass-through layer that prints its input tensor when it is evaluated.

    Wraps ``tf.Print`` with ``message`` as the printed prefix and
    ``summarize=300`` (at most 300 entries of the tensor are printed). The
    returned tensor carries the same values as ``inputs``.

    NOTE(review): ``tf.Print`` is a TensorFlow 1.x API; TensorFlow 2.x only
    offers it as ``tf.compat.v1.Print`` (or the different ``tf.print``).

    Args:
        message: string prefix for the printed output.
    """

    def __init__(self, message, **kwargs):
        super(Print, self).__init__(**kwargs)
        self.message = message

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        to_show = [inputs]
        return tf.Print(inputs, to_show, self.message, summarize=300)

    def get_config(self):
        merged = dict(super(Print, self).get_config())
        merged.update({'message': self.message})
        return merged


djc_global_layers_list['Print'] = Print
### The following layers should be moved to DeepJetCore
class ReplaceByNoise(Layer):
    """Replace the input with standard-normal noise of the same shape.

    Only the dynamic shape of ``inputs`` is used; its values are discarded.
    The noise is always drawn as float32 with mean 0.0 and stddev 1.0.

    NOTE(review): ``tf.random_normal`` is the TensorFlow 1.x spelling
    (``tf.random.normal`` in 2.x).
    """

    def __init__(self, **kwargs):
        super(ReplaceByNoise, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        noise_shape = tf.shape(inputs)
        return tf.random_normal(shape=noise_shape,
                                mean=0.0,
                                stddev=1.0,
                                dtype='float32')

    def get_config(self):
        # No constructor arguments of its own; return a copy of the base config.
        return dict(super(ReplaceByNoise, self).get_config())


djc_global_layers_list['ReplaceByNoise'] = ReplaceByNoise
class FeedForward(Layer):
    """Pass the input through unchanged, computed as ``1. * inputs``.

    NOTE(review): the multiplication presumably exists to produce a new
    tensor distinct from ``inputs``; confirm the intent before replacing it
    with ``tf.identity``.
    """

    def __init__(self, **kwargs):
        super(FeedForward, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        unit = 1.
        return unit * inputs

    def get_config(self):
        # No constructor arguments of its own; return a copy of the base config.
        return dict(super(FeedForward, self).get_config())


djc_global_layers_list['FeedForward'] = FeedForward
class Clip(Layer):
    """Clip every input element into ``[min, max]`` via ``tf.clip_by_value``.

    Args:
        min: lower clipping bound.
        max: upper clipping bound.

    The argument names shadow the ``min``/``max`` builtins inside
    ``__init__``; they are kept because they are also the config keys.
    """

    def __init__(self, min, max, **kwargs):
        super(Clip, self).__init__(**kwargs)
        self.min = min
        self.max = max

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        lower, upper = self.min, self.max
        return tf.clip_by_value(inputs, lower, upper)

    def get_config(self):
        merged = dict(super(Clip, self).get_config())
        merged.update({'min': self.min, 'max': self.max})
        return merged


djc_global_layers_list['Clip'] = Clip
class ReduceSumEntirely(Layer):
    """Sum over every axis except axis 0, giving an output of shape ``(batch, 1)``.

    The static rank of ``inputs`` must be known, because the reduction axes
    are derived from ``inputs.shape``.
    """

    def __init__(self, **kwargs):
        super(ReduceSumEntirely, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], 1)

    def call(self, inputs):
        # Reduce axes 1 .. rank-1, i.e. everything after the leading axis.
        n_inner_axes = len(inputs.shape[1:].as_list())
        axes = list(range(1, n_inner_axes + 1))
        summed = tf.reduce_sum(inputs, axis=axes)
        # Restore a trailing unit axis so the result is (batch, 1).
        return tf.expand_dims(summed, axis=1)

    def get_config(self):
        # No constructor arguments of its own; return a copy of the base config.
        return dict(super(ReduceSumEntirely, self).get_config())


djc_global_layers_list['ReduceSumEntirely'] = ReduceSumEntirely