
Commit

add gate layer
yangxudong committed Jun 24, 2023
1 parent d1d16ac commit 0c85dd2
Showing 5 changed files with 65 additions and 2 deletions.
1 change: 1 addition & 0 deletions easy_rec/python/layers/keras/__init__.py
@@ -1,4 +1,5 @@
from .blocks import MLP
from .blocks import Gate
from .blocks import Highway
from .bst import BST
from .din import DIN
22 changes: 22 additions & 0 deletions easy_rec/python/layers/keras/blocks.py
@@ -134,3 +134,25 @@ def call(self, inputs, training=None, **kwargs):
        activation=self.activation,
        num_layers=self.num_layers,
        dropout=self.dropout_rate if training else 0.0)


class Gate(tf.keras.layers.Layer):
  """Weighted sum gate."""

  def __init__(self, params, name='gate', **kwargs):
    super(Gate, self).__init__(name=name, **kwargs)
    self.weight_index = params.get_or_default('weight_index', 0)

  def call(self, inputs, **kwargs):
    assert len(inputs) > 1, 'input of Gate layer must be a list containing at least 2 elements'
    weights = inputs[self.weight_index]
    j = 0
    for i, x in enumerate(inputs):
      if i == self.weight_index:
        continue
      if j == 0:
        # weights[:, j] has shape [batch]; the trailing None axis makes it
        # broadcast against x of shape [batch, dim]
        output = weights[:, j, None] * x
      else:
        output += weights[:, j, None] * x
      j += 1
    return output
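For context, the Gate layer selects one element of its input list (by `weight_index`) as the mixing weights and sums the remaining elements, each scaled by the matching weight column. A minimal sketch of the equivalent computation, with illustrative shapes that are not part of the commit:

import tensorflow as tf

batch, dim = 4, 8
# inputs[0] supplies the gate weights: one column per gated input
weights = tf.nn.softmax(tf.random.normal([batch, 2]))
x1 = tf.random.normal([batch, dim])
x2 = tf.random.normal([batch, dim])

# what Gate(weight_index=0) computes for inputs = [weights, x1, x2]
output = weights[:, 0, None] * x1 + weights[:, 1, None] * x2  # [batch, dim]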
39 changes: 38 additions & 1 deletion easy_rec/python/layers/keras/mask_net.py
@@ -8,10 +8,24 @@


class MaskBlock(tf.keras.layers.Layer):
  """MaskBlock used in MaskNet.

  Args:
    projection_dim: projection dimension used to reduce the computational
      cost. Defaults to `None`, in which case a full (`input_dim` by
      `aggregation_size`) matrix W is used. If set, a low-rank matrix
      W = U*V is used instead, where U is of size `input_dim` by
      `projection_dim` and V is of size `projection_dim` by
      `aggregation_size`. `projection_dim` needs to be smaller than
      `aggregation_size`/2 to improve model efficiency. In practice, we've
      observed that `projection_dim` = d/4 consistently preserves the
      accuracy of the full-rank version.
  """

  def __init__(self, params, name='mask_block', reuse=None, **kwargs):
    super(MaskBlock, self).__init__(name=name, **kwargs)
    self.config = params.get_pb_config()
    self.l2_reg = params.l2_regularizer
    self._projection_dim = params.get_or_default('projection_dim', None)
    self.reuse = reuse

  def call(self, inputs, **kwargs):
@@ -31,13 +45,33 @@ def call(self, inputs, **kwargs):

    # initializer = tf.initializers.variance_scaling()
    initializer = tf.glorot_uniform_initializer()

    if self._projection_dim is None:
      mask = tf.layers.dense(
          mask_input,
          aggregation_size,
          activation=tf.nn.relu,
          kernel_initializer=initializer,
          kernel_regularizer=self.l2_reg,
          name='%s/hidden' % self.name,
          reuse=self.reuse)
    else:
      u = tf.layers.dense(
          mask_input,
          self._projection_dim,
          kernel_initializer=initializer,
          kernel_regularizer=self.l2_reg,
          use_bias=False,
          name='%s/prj_u' % self.name,
          reuse=self.reuse)
      mask = tf.layers.dense(
          u,
          aggregation_size,
          activation=tf.nn.relu,
          kernel_initializer=initializer,
          kernel_regularizer=self.l2_reg,
          name='%s/prj_v' % self.name,
          reuse=self.reuse)
    mask = tf.layers.dense(
        mask, net.shape[-1], name='%s/mask' % self.name, reuse=self.reuse)
    masked_net = net * mask
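The new branch trades the single full-rank weight matrix for the factorization W = U*V described in the docstring. A back-of-the-envelope parameter count, using illustrative sizes that are assumptions rather than values from the commit:

input_dim, aggregation_size, projection_dim = 512, 512, 128

full_rank = input_dim * aggregation_size  # 262144 weights
low_rank = (input_dim * projection_dim +
            projection_dim * aggregation_size)  # 131072 weights, half the cost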
@@ -62,6 +96,7 @@ class MaskNet(tf.keras.layers.Layer):

  def __init__(self, params, name='mask_net', **kwargs):
    super(MaskNet, self).__init__(name=name, **kwargs)
    self.params = params
    self.config = params.get_pb_config()
    if self.config.HasField('mlp'):
      p = Parameter.make_from_pb(self.config.mlp)
@@ -75,6 +110,7 @@ def call(self, inputs, training=None, **kwargs):
      mask_outputs = []
      for i, block_conf in enumerate(self.config.mask_blocks):
        params = Parameter.make_from_pb(block_conf)
        params.l2_regularizer = self.params.l2_regularizer
        mask_layer = MaskBlock(params, name='%s/block_%d' % (self.name, i))
        mask_outputs.append(mask_layer((inputs, inputs)))
      all_mask_outputs = tf.concat(mask_outputs, axis=1)
@@ -88,6 +124,7 @@ def call(self, inputs, training=None, **kwargs):
      net = inputs
      for i, block_conf in enumerate(self.config.mask_blocks):
        params = Parameter.make_from_pb(block_conf)
        params.l2_regularizer = self.params.l2_regularizer
        mask_layer = MaskBlock(params, name='%s/block_%d' % (self.name, i))
        net = mask_layer((net, inputs))

4 changes: 3 additions & 1 deletion easy_rec/python/layers/utils.py
@@ -207,7 +207,9 @@ def get_or_default(self, key, def_val):
        return value
      return def_val
    else:  # pb message
      if self.params.HasField(key):
        return getattr(self.params, key)
      return def_val

  def check_required(self, keys):
    if not self.is_struct:
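The `HasField` guard matters because protobuf getters never raise on unset optional fields: `getattr` returns the type default (0 for `uint32`), so an unset `projection_dim` would silently read as 0 instead of falling back to `def_val`. A sketch of the behavior, assuming the generated module for layer.proto is importable as below:

from easy_rec.python.protos.layer_pb2 import MaskBlock  # assumed generated path

cfg = MaskBlock()                  # projection_dim left unset
getattr(cfg, 'projection_dim')     # -> 0, the uint32 type default
cfg.HasField('projection_dim')     # -> False, so get_or_default returns def_val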
1 change: 1 addition & 0 deletions easy_rec/python/protos/layer.proto
@@ -52,6 +52,7 @@ message MaskBlock {
  required uint32 output_size = 2;
  optional uint32 aggregation_size = 3;
  optional bool input_layer_norm = 4 [default = true];
  optional uint32 projection_dim = 5;
}

message MaskNet {
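With the new field in place, a mask block can opt into the factorized weights from its config. A hypothetical pbtxt fragment (values illustrative, surrounding message nesting omitted):

mask_blocks {
  output_size: 256
  aggregation_size: 1024
  projection_dim: 128  # enables the low-rank W = U*V path
}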
