From 58ef6ad95a480615891ec911d32c89a6bfa8a56e Mon Sep 17 00:00:00 2001
From: wwxxzz
Date: Thu, 21 Jul 2022 15:19:33 +0800
Subject: [PATCH] add graph-based input config and model: UltraGCN

---
 easy_rec/python/input/graph_input.py          | 601 ++++++++++++++++++
 easy_rec/python/model/ultragcn.py             | 122 ++++
 easy_rec/python/protos/data_source.proto      |   8 +
 easy_rec/python/protos/dataset.proto          |  26 +-
 easy_rec/python/protos/easy_rec_model.proto   |   2 +
 easy_rec/python/protos/pipeline.proto         |   2 +
 easy_rec/python/protos/ultragcn.proto         |  14 +
 easy_rec/python/utils/graph_utils.py          |  55 ++
 pai_jobs/run.py                               |  15 +-
 samples/model_config/ultragcn_on_graph.config |  82 +++
 10 files changed, 922 insertions(+), 5 deletions(-)
 create mode 100644 easy_rec/python/input/graph_input.py
 create mode 100644 easy_rec/python/model/ultragcn.py
 create mode 100644 easy_rec/python/protos/ultragcn.proto
 create mode 100644 easy_rec/python/utils/graph_utils.py
 create mode 100644 samples/model_config/ultragcn_on_graph.config

diff --git a/easy_rec/python/input/graph_input.py b/easy_rec/python/input/graph_input.py
new file mode 100644
index 000000000..ad0b9c14c
--- /dev/null
+++ b/easy_rec/python/input/graph_input.py
@@ -0,0 +1,601 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import graph_utils
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.check_utils import check_string_to_number
+from easy_rec.python.utils.expr_util import get_expression
+
+try:
+  import graphlearn as gl
+except ImportError:
+  logging.error(
+      'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"'  # noqa: E501
+  )
+  sys.exit(1)
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class GraphInput(Input):
+
+  _g = None
+
+  def __init__(self,
+               data_config,
+               feature_configs,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False):
+    super(GraphInput, self).__init__(data_config, feature_configs, input_path,
+                                     task_index, task_num, check_mode)
+
+    gl.set_shuffle_buffer_size(102400000)
+    if input_path:
+      if GraphInput._g is None:
+        if self._data_config.HasField('ultra_gcn_sampler'):
+          # u-i edges are loaded as undirected so that both user and item
+          # degrees are available; i-i edges keep weights for topk sampling.
+          GraphInput._g = gl.Graph()\
+            .node(tf.compat.as_str(input_path.user_node_input), node_type="u",
+                  decoder=gl.Decoder(attr_types=['int']))\
+            .node(tf.compat.as_str(input_path.item_node_input), node_type="i",
+                  decoder=gl.Decoder(attr_types=['int']))\
+            .edge(tf.compat.as_str(input_path.u2i_edge_input), edge_type=("u", "i", "u-i"),
+                  decoder=gl.Decoder(weighted=False), directed=False)\
+            .edge(tf.compat.as_str(input_path.i2i_edge_input), edge_type=("i", "i", "i-i"),
+                  decoder=gl.Decoder(weighted=True), directed=True)
+          graph_utils.graph_init(GraphInput._g,
+                                 os.environ.get('TF_CONFIG', None))
+
+  def _sample_generator_ultragcn(self):
+
+    def ultragcn_sampler():
+      epoch_id = 0
+      while self.num_epochs is None or epoch_id < self.num_epochs:
+        # The sampler set is identical for train and eval; GraphLearn
+        # samplers are re-created at the start of every epoch.
+        self.edge_sampler = GraphInput._g.edge_sampler(
+            "u-i", self._batch_size, strategy="shuffle")
+        self.u_sampler = GraphInput._g.node_sampler(
+            "u", self._batch_size, strategy="by_order")
+        self.i_sampler = GraphInput._g.node_sampler(
+            "i", self._batch_size, strategy="by_order")
+        self.i2i_nbr_sampler = GraphInput._g.neighbor_sampler(
+            "i-i", self._nbr_num, strategy="topk")
+        self.neg_sampler = GraphInput._g.negative_sampler(
+            "u-i", self._neg_num, "random")
+
+        while True:
+          try:
+            samples = []
+            edges = self.edge_sampler.get()
+            neg_items = self.neg_sampler.get(edges.src_ids)
+            nbr_items = self.i2i_nbr_sampler.get(edges.dst_ids)
+            samples.append(edges.src_ids)  # user ids
+            samples.append(GraphInput._g.out_degrees(edges.src_ids, 'u-i'))  # user degrees
+            samples.append(edges.dst_ids)  # item ids
+            samples.append(GraphInput._g.out_degrees(edges.dst_ids, 'u-i_reverse'))  # item degrees
+            samples.append(nbr_items.layer_nodes(1).ids)  # nbr item ids
+            samples.append(nbr_items.layer_edges(1).weights)  # nbr item weights
+            samples.append(neg_items.ids)  # neg item ids
+            yield tuple(samples)
+          except gl.OutOfRangeError:
+            break
+        if self._mode != tf.estimator.ModeKeys.TRAIN:
+          break
+        epoch_id += 1
+
+    self._nbr_num = self._data_config.ultra_gcn_sampler.nbr_num
+    self._neg_num = self._data_config.ultra_gcn_sampler.neg_num
+    # user ids, user degrees, item ids, item degrees,
+    # nbr item ids, nbr item weights, neg item ids
+    output_types = [tf.int64, tf.float32, tf.int64, tf.float32,
+                    tf.int64, tf.float32, tf.int64]
+    output_shapes = [tf.TensorShape([None]),
+                     tf.TensorShape([None]),
+                     tf.TensorShape([None]),
+                     tf.TensorShape([None]),
+                     tf.TensorShape([None, self._nbr_num]),
+                     tf.TensorShape([None, self._nbr_num]),
+                     tf.TensorShape([None, self._neg_num])]
+
+    dataset = tf.data.Dataset.from_generator(
+        ultragcn_sampler,
+        output_types=tuple(output_types),
+        output_shapes=tuple(output_shapes))
+    return dataset
+
+  def _to_fea_dict(self, *features):
+    return {'features': features}
+
+  def _get_features(self, field_dict_groups):
+    return {'features': field_dict_groups['features']}
+
+  def _get_labels(self, field_dict):
+    return {}
+
+  def _preprocess(self, field_dict):
+    """Preprocess the feature columns.
+ + preprocess some feature columns, such as TagFeature or LookupFeature, + it is expected to handle batch inputs and single input, + it could be customized in subclasses + + Args: + field_dict: string to tensor, tensors are dense, + could be of shape [batch_size], [batch_size, None], or of shape [] + + Returns: + output_dict: some of the tensors are transformed into sparse tensors, + such as input tensors of tag features and lookup features + """ + parsed_dict = {} + + if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT: + if self._mode != tf.estimator.ModeKeys.TRAIN: + self._sampler.set_eval_num_sample() + sampler_type = self._data_config.WhichOneof('sampler') + sampler_config = getattr(self._data_config, sampler_type) + item_ids = field_dict[sampler_config.item_id_field] + if sampler_type in ['negative_sampler', 'negative_sampler_in_memory']: + sampled = self._sampler.get(item_ids) + elif sampler_type == 'negative_sampler_v2': + user_ids = field_dict[sampler_config.user_id_field] + sampled = self._sampler.get(user_ids, item_ids) + elif sampler_type.startswith('hard_negative_sampler'): + user_ids = field_dict[sampler_config.user_id_field] + sampled = self._sampler.get(user_ids, item_ids) + else: + raise ValueError('Unknown sampler %s' % sampler_type) + for k, v in sampled.items(): + if k in field_dict: + field_dict[k] = tf.concat([field_dict[k], v], axis=0) + else: + print('appended fields: %s' % k) + parsed_dict[k] = v + self._appended_fields.append(k) + + for fc in self._feature_configs: + feature_name = fc.feature_name + feature_type = fc.feature_type + input_0 = fc.input_names[0] + if feature_type == fc.TagFeature: + input_0 = fc.input_names[0] + field = field_dict[input_0] + # Construct the output of TagFeature according to the dimension of field_dict. + # When the input field exceeds 2 dimensions, convert TagFeature to 2D output. + if len(field.get_shape()) < 2 or field.get_shape()[-1] == 1: + if len(field.get_shape()) == 0: + field = tf.expand_dims(field, axis=0) + elif len(field.get_shape()) == 2: + field = tf.squeeze(field, axis=-1) + if fc.HasField('kv_separator') and len(fc.input_names) > 1: + assert False, 'Tag Feature Error, ' \ + 'Cannot set kv_separator and multi input_names in one feature config. Feature: %s.' 
% input_0 + parsed_dict[input_0] = tf.string_split(field, fc.separator) + if fc.HasField('kv_separator'): + indices = parsed_dict[input_0].indices + tmp_kvs = parsed_dict[input_0].values + tmp_kvs = tf.string_split( + tmp_kvs, fc.kv_separator, skip_empty=False) + tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2]) + tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1] + + check_list = [ + tf.py_func( + check_string_to_number, [tmp_vs, input_0], Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + tmp_vs = tf.string_to_number( + tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0) + parsed_dict[input_0] = tf.sparse.SparseTensor( + indices, tmp_ks, parsed_dict[input_0].dense_shape) + input_wgt = input_0 + '_WEIGHT' + parsed_dict[input_wgt] = tf.sparse.SparseTensor( + indices, tmp_vs, parsed_dict[input_0].dense_shape) + self._appended_fields.append(input_wgt) + if not fc.HasField('hash_bucket_size'): + check_list = [ + tf.py_func( + check_string_to_number, + [parsed_dict[input_0].values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + vals = tf.string_to_number( + parsed_dict[input_0].values, + tf.int32, + name='tag_fea_%s' % input_0) + parsed_dict[input_0] = tf.sparse.SparseTensor( + parsed_dict[input_0].indices, vals, + parsed_dict[input_0].dense_shape) + if len(fc.input_names) > 1: + input_1 = fc.input_names[1] + field = field_dict[input_1] + if len(field.get_shape()) == 0: + field = tf.expand_dims(field, axis=0) + field = tf.string_split(field, fc.separator) + check_list = [ + tf.py_func( + check_string_to_number, [field.values, input_1], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + field_vals = tf.string_to_number( + field.values, + tf.float32, + name='tag_wgt_str_2_flt_%s' % input_1) + assert_op = tf.assert_equal( + tf.shape(field_vals)[0], + tf.shape(parsed_dict[input_0].values)[0], + message='TagFeature Error: The size of %s not equal to the size of %s. Please check input: %s and %s.' + % (input_0, input_1, input_0, input_1)) + with tf.control_dependencies([assert_op]): + field = tf.sparse.SparseTensor(field.indices, + tf.identity(field_vals), + field.dense_shape) + parsed_dict[input_1] = field + else: + parsed_dict[input_0] = field_dict[input_0] + if len(fc.input_names) > 1: + input_1 = fc.input_names[1] + parsed_dict[input_1] = field_dict[input_1] + elif feature_type == fc.LookupFeature: + assert feature_name is not None and feature_name != '' + assert len(fc.input_names) == 2 + parsed_dict[feature_name] = self._lookup_preprocess(fc, field_dict) + elif feature_type == fc.SequenceFeature: + input_0 = fc.input_names[0] + field = field_dict[input_0] + sub_feature_type = fc.sub_feature_type + # Construct the output of SeqFeature according to the dimension of field_dict. + # When the input field exceeds 2 dimensions, convert SeqFeature to 2D output. 
+ if len(field.get_shape()) < 2: + parsed_dict[input_0] = tf.strings.split(field, fc.separator) + if fc.HasField('seq_multi_sep'): + indices = parsed_dict[input_0].indices + values = parsed_dict[input_0].values + multi_vals = tf.string_split(values, fc.seq_multi_sep) + indices_1 = multi_vals.indices + indices = tf.gather(indices, indices_1[:, 0]) + out_indices = tf.concat([indices, indices_1[:, 1:]], axis=1) + # 3 dimensional sparse tensor + out_shape = tf.concat( + [parsed_dict[input_0].dense_shape, multi_vals.dense_shape[1:]], + axis=0) + parsed_dict[input_0] = tf.sparse.SparseTensor( + out_indices, multi_vals.values, out_shape) + if (fc.num_buckets > 1 and fc.max_val == fc.min_val): + check_list = [ + tf.py_func( + check_string_to_number, + [parsed_dict[input_0].values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + parsed_dict[input_0] = tf.sparse.SparseTensor( + parsed_dict[input_0].indices, + tf.string_to_number( + parsed_dict[input_0].values, + tf.int64, + name='sequence_str_2_int_%s' % input_0), + parsed_dict[input_0].dense_shape) + elif sub_feature_type == fc.RawFeature: + check_list = [ + tf.py_func( + check_string_to_number, + [parsed_dict[input_0].values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + parsed_dict[input_0] = tf.sparse.SparseTensor( + parsed_dict[input_0].indices, + tf.string_to_number( + parsed_dict[input_0].values, + tf.float32, + name='sequence_str_2_float_%s' % input_0), + parsed_dict[input_0].dense_shape) + if fc.num_buckets > 1 and fc.max_val > fc.min_val: + normalized_values = (parsed_dict[input_0].values - fc.min_val) / ( + fc.max_val - fc.min_val) + parsed_dict[input_0] = tf.sparse.SparseTensor( + parsed_dict[input_0].indices, normalized_values, + parsed_dict[input_0].dense_shape) + else: + parsed_dict[input_0] = field + if not fc.boundaries and fc.num_buckets <= 1 and fc.hash_bucket_size <= 0 and \ + self._data_config.sample_weight != input_0 and sub_feature_type == fc.RawFeature and \ + fc.raw_input_dim == 1: + # may need by wide model and deep model to project + # raw values to a vector, it maybe better implemented + # by a ProjectionColumn later + logging.info( + 'Not set boundaries or num_buckets or hash_bucket_size, %s will process as two dimension raw feature' + % input_0) + parsed_dict[input_0] = tf.sparse_to_dense( + parsed_dict[input_0].indices, + [tf.shape(parsed_dict[input_0])[0], fc.sequence_length], + parsed_dict[input_0].values) + sample_num = tf.to_int64(tf.shape(parsed_dict[input_0])[0]) + indices_0 = tf.range(sample_num, dtype=tf.int64) + indices_1 = tf.range(fc.sequence_length, dtype=tf.int64) + indices_0 = indices_0[:, None] + indices_1 = indices_1[None, :] + indices_0 = tf.tile(indices_0, [1, fc.sequence_length]) + indices_1 = tf.tile(indices_1, [sample_num, 1]) + indices_0 = tf.reshape(indices_0, [-1, 1]) + indices_1 = tf.reshape(indices_1, [-1, 1]) + indices = tf.concat([indices_0, indices_1], axis=1) + parsed_dict[input_0 + '_raw_proj_id'] = tf.SparseTensor( + indices=indices, + values=indices_1[:, 0], + dense_shape=[sample_num, fc.sequence_length]) + parsed_dict[input_0 + '_raw_proj_val'] = tf.SparseTensor( + indices=indices, + values=tf.reshape(parsed_dict[input_0], [-1]), + dense_shape=[sample_num, fc.sequence_length]) + self._appended_fields.append(input_0 + '_raw_proj_id') + self._appended_fields.append(input_0 + '_raw_proj_val') + elif not fc.boundaries and fc.num_buckets <= 1 and fc.hash_bucket_size <= 0 and \ + 
self._data_config.sample_weight != input_0 and sub_feature_type == fc.RawFeature and \ + fc.raw_input_dim > 1: + # for 3 dimension sequence feature input. + # may need by wide model and deep model to project + # raw values to a vector, it maybe better implemented + # by a ProjectionColumn later + logging.info( + 'Not set boundaries or num_buckets or hash_bucket_size, %s will process as three dimension raw feature' + % input_0) + parsed_dict[input_0] = tf.sparse_to_dense( + parsed_dict[input_0].indices, [ + tf.shape(parsed_dict[input_0])[0], fc.sequence_length, + fc.raw_input_dim + ], parsed_dict[input_0].values) + sample_num = tf.to_int64(tf.shape(parsed_dict[input_0])[0]) + indices_0 = tf.range(sample_num, dtype=tf.int64) + indices_1 = tf.range(fc.sequence_length, dtype=tf.int64) + indices_2 = tf.range(fc.raw_input_dim, dtype=tf.int64) + indices_0 = indices_0[:, None, None] + indices_1 = indices_1[None, :, None] + indices_2 = indices_2[None, None, :] + indices_0 = tf.tile(indices_0, + [1, fc.sequence_length, fc.raw_input_dim]) + indices_1 = tf.tile(indices_1, [sample_num, 1, fc.raw_input_dim]) + indices_2 = tf.tile(indices_2, [sample_num, fc.sequence_length, 1]) + indices_0 = tf.reshape(indices_0, [-1, 1]) + indices_1 = tf.reshape(indices_1, [-1, 1]) + indices_2 = tf.reshape(indices_2, [-1, 1]) + indices = tf.concat([indices_0, indices_1, indices_2], axis=1) + + parsed_dict[input_0 + '_raw_proj_id'] = tf.SparseTensor( + indices=indices, + values=indices_1[:, 0], + dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim]) + parsed_dict[input_0 + '_raw_proj_val'] = tf.SparseTensor( + indices=indices, + values=tf.reshape(parsed_dict[input_0], [-1]), + dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim]) + self._appended_fields.append(input_0 + '_raw_proj_id') + self._appended_fields.append(input_0 + '_raw_proj_val') + elif feature_type == fc.RawFeature: + input_0 = fc.input_names[0] + if field_dict[input_0].dtype == tf.string: + if fc.raw_input_dim > 1: + check_list = [ + tf.py_func( + check_split, [ + field_dict[input_0], fc.separator, fc.raw_input_dim, + input_0 + ], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + tmp_fea = tf.string_split(field_dict[input_0], fc.separator) + check_list = [ + tf.py_func( + check_string_to_number, [tmp_fea.values, input_0], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + tmp_vals = tf.string_to_number( + tmp_fea.values, + tf.float32, + name='multi_raw_fea_to_flt_%s' % input_0) + parsed_dict[input_0] = tf.sparse_to_dense( + tmp_fea.indices, + [tf.shape(field_dict[input_0])[0], fc.raw_input_dim], + tmp_vals, + default_value=0) + else: + check_list = [ + tf.py_func( + check_string_to_number, [field_dict[input_0], input_0], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + parsed_dict[input_0] = tf.string_to_number( + field_dict[input_0], tf.float32) + elif field_dict[input_0].dtype in [ + tf.int32, tf.int64, tf.double, tf.float32 + ]: + parsed_dict[input_0] = tf.to_float(field_dict[input_0]) + else: + assert False, 'invalid dtype[%s] for raw feature' % str( + field_dict[input_0].dtype) + if fc.max_val > fc.min_val: + parsed_dict[input_0] = (parsed_dict[input_0] - fc.min_val) /\ + (fc.max_val - fc.min_val) + if not fc.boundaries and fc.num_buckets <= 1 and \ + self._data_config.sample_weight != input_0: + # may need by wide model and deep model to project + # raw values to a vector, it maybe better implemented 
+ # by a ProjectionColumn later + sample_num = tf.to_int64(tf.shape(parsed_dict[input_0])[0]) + indices_0 = tf.range(sample_num, dtype=tf.int64) + indices_1 = tf.range(fc.raw_input_dim, dtype=tf.int64) + indices_0 = indices_0[:, None] + indices_1 = indices_1[None, :] + indices_0 = tf.tile(indices_0, [1, fc.raw_input_dim]) + indices_1 = tf.tile(indices_1, [sample_num, 1]) + indices_0 = tf.reshape(indices_0, [-1, 1]) + indices_1 = tf.reshape(indices_1, [-1, 1]) + indices = tf.concat([indices_0, indices_1], axis=1) + + parsed_dict[input_0 + '_raw_proj_id'] = tf.SparseTensor( + indices=indices, + values=indices_1[:, 0], + dense_shape=[sample_num, fc.raw_input_dim]) + parsed_dict[input_0 + '_raw_proj_val'] = tf.SparseTensor( + indices=indices, + values=tf.reshape(parsed_dict[input_0], [-1]), + dense_shape=[sample_num, fc.raw_input_dim]) + self._appended_fields.append(input_0 + '_raw_proj_id') + self._appended_fields.append(input_0 + '_raw_proj_val') + elif feature_type == fc.IdFeature: + input_0 = fc.input_names[0] + parsed_dict[input_0] = field_dict[input_0] + if fc.HasField('hash_bucket_size'): + if field_dict[input_0].dtype != tf.string: + if field_dict[input_0].dtype in [tf.float32, tf.double]: + assert fc.precision > 0, 'it is dangerous to convert float or double to string due to ' \ + 'precision problem, it is suggested to convert them into string ' \ + 'format during feature generalization before using EasyRec; ' \ + 'if you really need to do so, please set precision (the number of ' \ + 'decimal digits) carefully.' + precision = None + if field_dict[input_0].dtype in [tf.float32, tf.double]: + if fc.precision > 0: + precision = fc.precision + # convert to string + if 'as_string' in dir(tf.strings): + parsed_dict[input_0] = tf.strings.as_string( + field_dict[input_0], precision=precision) + else: + parsed_dict[input_0] = tf.as_string( + field_dict[input_0], precision=precision) + elif fc.num_buckets > 0: + if parsed_dict[input_0].dtype == tf.string: + check_list = [ + tf.py_func( + check_string_to_number, [parsed_dict[input_0], input_0], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + parsed_dict[input_0] = tf.string_to_number( + parsed_dict[input_0], tf.int32, name='%s_str_2_int' % input_0) + elif feature_type == fc.ExprFeature: + fea_name = fc.feature_name + prefix = 'expr_' + for input_name in fc.input_names: + new_input_name = prefix + input_name + if field_dict[input_name].dtype == tf.string: + check_list = [ + tf.py_func( + check_string_to_number, + [field_dict[input_name], input_name], + Tout=tf.bool) + ] if self._check_mode else [] + with tf.control_dependencies(check_list): + parsed_dict[new_input_name] = tf.string_to_number( + field_dict[input_name], + tf.float64, + name='%s_str_2_int_for_expr' % new_input_name) + elif field_dict[input_name].dtype in [ + tf.int32, tf.int64, tf.double, tf.float32 + ]: + parsed_dict[new_input_name] = tf.cast(field_dict[input_name], + tf.float64) + else: + assert False, 'invalid input dtype[%s] for expr feature' % str( + field_dict[input_name].dtype) + + expression = get_expression( + fc.expression, fc.input_names, prefix=prefix) + logging.info('expression: %s' % expression) + parsed_dict[fea_name] = eval(expression) + self._appended_fields.append(fea_name) + else: + for input_name in fc.input_names: + parsed_dict[input_name] = field_dict[input_name] + + for input_id, input_name in enumerate(self._label_fields): + if input_name not in field_dict: + continue + if field_dict[input_name].dtype == 
tf.string:
+        if self._label_dim[input_id] > 1:
+          logging.info('will split labels[%d]=%s' % (input_id, input_name))
+          check_list = [
+              tf.py_func(
+                  check_split, [
+                      field_dict[input_name], self._label_sep[input_id],
+                      self._label_dim[input_id], input_name
+                  ],
+                  Tout=tf.bool)
+          ] if self._check_mode else []
+          with tf.control_dependencies(check_list):
+            parsed_dict[input_name] = tf.string_split(
+                field_dict[input_name], self._label_sep[input_id]).values
+            parsed_dict[input_name] = tf.reshape(
+                parsed_dict[input_name], [-1, self._label_dim[input_id]])
+        else:
+          parsed_dict[input_name] = field_dict[input_name]
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [parsed_dict[input_name], input_name],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          parsed_dict[input_name] = tf.string_to_number(
+              parsed_dict[input_name], tf.float32, name=input_name)
+      else:
+        assert field_dict[input_name].dtype in [
+            tf.float32, tf.double, tf.int32, tf.int64
+        ], 'invalid label dtype: %s' % str(field_dict[input_name].dtype)
+        parsed_dict[input_name] = field_dict[input_name]
+
+    if self._data_config.HasField('sample_weight'):
+      if self._mode != tf.estimator.ModeKeys.PREDICT:
+        parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
+            self._data_config.sample_weight]
+
+    # 19 == DatasetConfig.InputType.GraphInput: graph features are already
+    # dense tensors, so they are passed through without further parsing.
+    if self._data_config.input_type == 19:
+      parsed_dict = {}
+      for fd in field_dict:
+        parsed_dict[fd] = field_dict[fd]
+
+    return parsed_dict
+
+  def _build(self, mode, params):
+    """Build graph dataset input for estimator.
+
+    Args:
+      mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
+      params: `dict` of hyper parameters, from Estimator
+
+    Returns:
+      dataset: dataset for graph models.
+    """
+    if self._data_config.HasField('ultra_gcn_sampler'):
+      dataset = self._sample_generator_ultragcn()
+    else:
+      raise ValueError(
+          'GraphInput requires data_config.ultra_gcn_sampler to be set')
+
+    # transform the feature list into a feature dict
+    dataset = dataset.map(map_func=self._to_fea_dict)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: self._get_features(x))
+
+    return dataset
+
+  def __del__(self):
+    if GraphInput._g is not None:
+      GraphInput._g.close()
diff --git a/easy_rec/python/model/ultragcn.py b/easy_rec/python/model/ultragcn.py
new file mode 100644
index 000000000..f3c0d4f3b
--- /dev/null
+++ b/easy_rec/python/model/ultragcn.py
@@ -0,0 +1,122 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
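+#
+# UltraGCN learns user/item ID embeddings directly and replaces explicit
+# graph convolution with degree-weighted BCE losses over u-i edges, i-i
+# neighbor edges and sampled negatives (see build_loss_graph below).
+# Reference (for background): Mao et al., "UltraGCN: Ultra Simplification
+# of Graph Convolutional Networks for Recommendation", CIKM 2021.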
+import tensorflow as tf
+
+from easy_rec.python.model.easy_rec_model import EasyRecModel
+from easy_rec.python.protos.ultragcn_pb2 import ULTRAGCN as ULTRAGCNConfig  # NOQA
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class ULTRAGCN(EasyRecModel):
+
+  def __init__(self,
+               model_config,
+               feature_configs,
+               features,
+               labels=None,
+               is_training=False):
+    super(ULTRAGCN, self).__init__(model_config, feature_configs, features,
+                                   labels, is_training)
+    self._model_config = model_config.ultragcn
+    assert isinstance(self._model_config, ULTRAGCNConfig)
+    self._user_num = self._model_config.user_num
+    self._item_num = self._model_config.item_num
+    self._emb_dim = self._model_config.output_dim
+    self._i2i_weight = self._model_config.i2i_weight
+    self._neg_weight = self._model_config.neg_weight
+    self._l2_weight = self._model_config.l2_weight
+
+    if features.get('features') is not None:
+      # field order follows GraphInput._sample_generator_ultragcn
+      self._user_ids = features.get('features')[0]
+      self._user_degrees = features.get('features')[1]
+      self._item_ids = features.get('features')[2]
+      self._item_degrees = features.get('features')[3]
+      self._nbr_ids = features.get('features')[4]
+      self._nbr_weights = features.get('features')[5]
+      self._neg_ids = features.get('features')[6]
+    else:
+      # fallback when only a raw 'id' column is fed, e.g. at export time
+      self._user_ids = features.get('id')
+      self._user_degrees = None
+      self._item_ids = features.get('id')
+      self._item_degrees = None
+      self._nbr_ids = features.get('id')
+      self._nbr_weights = None
+      self._neg_ids = features.get('id')
+
+    self._user_emb = tf.get_variable(
+        'user_emb', [self._user_num, self._emb_dim], trainable=True)
+    self._item_emb = tf.get_variable(
+        'item_emb', [self._item_num, self._emb_dim], trainable=True)
+
+  def build_predict_graph(self):
+    user_emb = tf.nn.embedding_lookup(self._user_emb, self._user_ids)
+    item_emb = tf.nn.embedding_lookup(self._item_emb, self._item_ids)
+    nbr_emb = tf.nn.embedding_lookup(self._item_emb, self._nbr_ids)
+    neg_emb = tf.nn.embedding_lookup(self._item_emb, self._neg_ids)
+    self._prediction_dict['user_emb'] = user_emb
+    self._prediction_dict['item_emb'] = item_emb
+    self._prediction_dict['nbr_emb'] = nbr_emb
+    self._prediction_dict['neg_emb'] = neg_emb
+    self._prediction_dict['user_embedding'] = tf.reduce_join(
+        tf.as_string(user_emb), axis=-1, separator=',')
+    self._prediction_dict['item_embedding'] = tf.reduce_join(
+        tf.as_string(item_emb), axis=-1, separator=',')
+
+    return self._prediction_dict
+
+  def build_loss_graph(self):
+    # u2i loss: BCE on observed u-i edges, weighted by 1 + 1/sqrt(d_u * d_i),
+    # plus weighted BCE on randomly sampled negative items.
+    pos_logit = tf.reduce_sum(
+        self._prediction_dict['user_emb'] * self._prediction_dict['item_emb'],
+        axis=-1)
+    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        labels=tf.ones_like(pos_logit), logits=pos_logit)
+    neg_logit = tf.reduce_sum(
+        tf.expand_dims(self._prediction_dict['user_emb'], axis=1) *
+        self._prediction_dict['neg_emb'],
+        axis=-1)
+    negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        labels=tf.zeros_like(neg_logit), logits=neg_logit)
+    loss_u2i = tf.reduce_sum(
+        true_xent *
+        (1 + 1 / tf.sqrt(self._user_degrees * self._item_degrees))) + \
+        self._neg_weight * tf.reduce_sum(tf.reduce_mean(negative_xent, axis=-1))
+    # i2i loss: BCE on topk i-i neighbors, weighted by 1 + edge weight.
+    nbr_logit = tf.reduce_sum(
+        tf.expand_dims(self._prediction_dict['user_emb'], axis=1) *
+        self._prediction_dict['nbr_emb'],
+        axis=-1)  # [batch_size, nbr_num]
+    nbr_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        labels=tf.ones_like(nbr_logit), logits=nbr_logit)
+    loss_i2i = tf.reduce_sum(nbr_xent * (1 + self._nbr_weights))
+    # L2 regularization over all looked-up embeddings.
+    loss_l2 = tf.nn.l2_loss(self._prediction_dict['user_emb']) + \
+        tf.nn.l2_loss(self._prediction_dict['item_emb']) + \
+        tf.nn.l2_loss(self._prediction_dict['nbr_emb']) + \
+        tf.nn.l2_loss(self._prediction_dict['neg_emb'])
+
+    loss = loss_u2i + self._i2i_weight * loss_i2i + self._l2_weight * loss_l2
+    return {'cross_entropy': loss}
+
+  def build_metric_graph(self, eval_config):
+    # The model only exports embeddings; retrieval metrics such as recall@k
+    # are expected to be computed offline, so no eval metrics are defined.
+    return {}
+
+  def get_outputs(self):
+    return ['user_embedding', 'item_embedding']
diff --git a/easy_rec/python/protos/data_source.proto b/easy_rec/python/protos/data_source.proto
index eb90ec6ea..00d360a87 100644
--- a/easy_rec/python/protos/data_source.proto
+++ b/easy_rec/python/protos/data_source.proto
@@ -25,3 +25,11 @@ message BinaryDataInput {
   repeated string dense_path = 2;
   repeated string label_path = 3;
 }
+
+message GraphLearnInput {
+  optional string user_node_input = 1;
+  optional string item_node_input = 2;
+  optional string u2i_edge_input = 3;
+  optional string i2i_edge_input = 4;
+  optional string u2u_edge_input = 5;
+}
diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto
index ecf1acd89..db8c5e253 100644
--- a/easy_rec/python/protos/dataset.proto
+++ b/easy_rec/python/protos/dataset.proto
@@ -131,6 +131,28 @@ message HardNegativeSamplerV2 {
   optional string field_delimiter = 12 [default="\001"];
 }
 
+// for graph-based algorithms, e.g. UltraGCN.
+message UltraGCNSampler {
+
+  optional uint32 user_num = 1 [default=10];
+
+  optional uint32 item_num = 2 [default=10];
+
+  optional uint32 output_dim = 3 [default=10];
+
+  optional uint32 nbr_num = 4 [default=10];
+
+  optional uint32 neg_num = 5 [default=10];
+
+  optional float neg_weight = 6 [default=10];
+
+  optional float i2i_weight = 7 [default=10];
+
+  optional float l2_weight = 8 [default=10];
+
+  optional string neg_sampler = 9 [default = 'random'];
+}
+
 message DatasetConfig {
   // mini batch size to use for training and evaluation.
optional uint32 batch_size = 1 [default = 32]; @@ -218,6 +240,7 @@ message DatasetConfig { HiveRTPInput = 17; HiveParquetInput = 18; CriteoInput = 1001; + GraphInput = 19; } required InputType input_type = 10; @@ -288,6 +311,7 @@ message DatasetConfig { NegativeSamplerInMemory negative_sampler_in_memory = 105; } optional uint32 eval_batch_size = 1001 [default = 4096]; - + + optional UltraGCNSampler ultra_gcn_sampler = 26; } diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto index d0df557d7..c8db77fe6 100644 --- a/easy_rec/python/protos/easy_rec_model.proto +++ b/easy_rec/python/protos/easy_rec_model.proto @@ -19,6 +19,7 @@ import "easy_rec/python/protos/dcn.proto"; import "easy_rec/python/protos/cmbf.proto"; import "easy_rec/python/protos/autoint.proto"; import "easy_rec/python/protos/mind.proto"; +import "easy_rec/python/protos/ultragcn.proto"; import "easy_rec/python/protos/loss.proto"; import "easy_rec/python/protos/rocket_launching.proto"; import "easy_rec/python/protos/variational_dropout.proto"; @@ -67,6 +68,7 @@ message EasyRecModel { MIND mind = 202; DropoutNet dropoutnet = 203; CoMetricLearningI2I metric_learning = 204; + ULTRAGCN ultragcn = 205; MMoE mmoe = 301; ESMM esmm = 302; diff --git a/easy_rec/python/protos/pipeline.proto b/easy_rec/python/protos/pipeline.proto index 301b9bcfa..07b74c5fd 100644 --- a/easy_rec/python/protos/pipeline.proto +++ b/easy_rec/python/protos/pipeline.proto @@ -19,6 +19,7 @@ message EasyRecConfig { DatahubServer datahub_train_input = 12; HiveConfig hive_train_input = 101; BinaryDataInput binary_train_input = 102; + GraphLearnInput graph_train_input_path = 103; } oneof eval_path { string eval_input_path = 3; @@ -26,6 +27,7 @@ message EasyRecConfig { DatahubServer datahub_eval_input = 13; HiveConfig hive_eval_input= 201; BinaryDataInput binary_eval_input = 202; + GraphLearnInput graph_eval_input_path = 203; } required string model_dir = 5; diff --git a/easy_rec/python/protos/ultragcn.proto b/easy_rec/python/protos/ultragcn.proto new file mode 100644 index 000000000..16d0cc68c --- /dev/null +++ b/easy_rec/python/protos/ultragcn.proto @@ -0,0 +1,14 @@ +syntax = "proto2"; +package protos; + +message ULTRAGCN { + optional float l2_regularization = 2 [default=0.0]; + optional uint32 user_num = 3 [default=1]; + optional uint32 item_num = 4 [default=1]; + optional uint32 output_dim = 5 [default=1]; + optional uint32 nbr_num = 6 [default=1]; + optional uint32 neg_num = 7 [default=1]; + optional float neg_weight = 8 [default=0.0]; + optional float i2i_weight = 9 [default=0.0]; + optional float l2_weight = 10 [default=0.0]; +} diff --git a/easy_rec/python/utils/graph_utils.py b/easy_rec/python/utils/graph_utils.py new file mode 100644 index 000000000..17d0cf471 --- /dev/null +++ b/easy_rec/python/utils/graph_utils.py @@ -0,0 +1,55 @@ +# -*- encoding:utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. 
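+#
+# graph_init: initialize the shared GraphLearn graph according to TF_CONFIG.
+# Illustrative TF_CONFIG for the ps-mode branch below (layout assumed):
+#   {"cluster": {"ps": ["ps0:2222"], "chief": ["c0:2222"],
+#                "worker": ["w0:2222", "w1:2222"]},
+#    "task": {"type": "worker", "index": 0}}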
+import json
+import logging
+
+from easy_rec.python.utils import pai_util  # noqa: F401
+
+
+def graph_init(graph, tf_config=None):
+  if tf_config:
+    if isinstance(tf_config, str) or isinstance(tf_config, type(u'')):
+      tf_config = json.loads(tf_config)
+    # Assume one evaluator client by default; switch to the pai_util check
+    # below if the job may run without an evaluator.
+    evaluator_cnt = 1
+    # evaluator_cnt = 1 if pai_util.has_evaluator() else 0
+    # if evaluator_cnt == 0:
+    #   logging.warning(
+    #       'evaluator is not set as an client in GraphLearn,'
+    #       'if you actually set evaluator in TF_CONFIG, please do: export'
+    #       ' HAS_EVALUATOR=1.')
+    if 'ps' in tf_config['cluster']:
+      # ps mode
+      logging.info('ps mode')
+      ps_count = len(tf_config['cluster']['ps'])
+      task_count = len(tf_config['cluster']['worker']) + 1 + evaluator_cnt
+      cluster = {'server_count': ps_count, 'client_count': task_count}
+      if tf_config['task']['type'] in ['chief', 'master']:
+        graph.init(cluster=cluster, job_name='client', task_index=0)
+      elif tf_config['task']['type'] == 'worker':
+        graph.init(
+            cluster=cluster,
+            job_name='client',
+            task_index=tf_config['task']['index'] + 2)
+      # TODO(hongsheng.jhs): check cluster has evaluator or not?
+      elif tf_config['task']['type'] == 'evaluator':
+        graph.init(
+            cluster=cluster,
+            job_name='client',
+            task_index=tf_config['task']['index'] + 1)
+      elif tf_config['task']['type'] == 'ps':
+        graph.init(
+            cluster=cluster,
+            job_name='server',
+            task_index=tf_config['task']['index'])
+    else:
+      # worker mode
+      logging.info('worker mode')
+      task_count = len(tf_config['cluster']['worker']) + 1
+      if tf_config['task']['type'] in ['chief', 'master']:
+        graph.init(task_index=0, task_count=task_count)
+      elif tf_config['task']['type'] == 'worker':
+        graph.init(
+            task_index=tf_config['task']['index'] + evaluator_cnt,
+            task_count=task_count)
+  else:
+    # local mode
+    graph.init()
diff --git a/pai_jobs/run.py b/pai_jobs/run.py
index 957772329..5d17118d0 100644
--- a/pai_jobs/run.py
+++ b/pai_jobs/run.py
@@ -230,10 +230,5 @@ def main(argv):
   if pipeline_config.fg_json_path:
     fg_util.load_fg_json_to_config(pipeline_config)
 
-  if FLAGS.edit_config_json:
-    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
-    config_json = yaml.safe_load(FLAGS.edit_config_json)
-    config_util.edit_config(pipeline_config, config_json)
-
   if FLAGS.model_dir:
     pipeline_config.model_dir = FLAGS.model_dir
@@ -273,6 +273,13 @@
     print('[run.py] train_tables: %s' % pipeline_config.train_input_path)
     print('[run.py] eval_tables: %s' % pipeline_config.eval_input_path)
 
+  # edit_config_json is applied here, after the input tables are set up,
+  # so that it can also override graph/table input paths.
+  if FLAGS.edit_config_json:
+    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
+    config_json = yaml.safe_load(FLAGS.edit_config_json)
+    config_util.edit_config(pipeline_config, config_json)
+    logging.info('edit json complete')
+    logging.info(pipeline_config)
+
   if FLAGS.fine_tune_checkpoint:
     pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint
diff --git a/samples/model_config/ultragcn_on_graph.config b/samples/model_config/ultragcn_on_graph.config
new file mode 100644
index 000000000..1c6620abf
--- /dev/null
+++ b/samples/model_config/ultragcn_on_graph.config
@@ -0,0 +1,82 @@
+model_dir: "experiments/ultragcn_on_graph_ckpt"
+
+graph_train_input_path {
+  user_node_input: "easy_rec/data/graph_data/gl_user.txt"
+  item_node_input: "easy_rec/data/graph_data/gl_item.txt"
+  u2i_edge_input: "easy_rec/data/graph_data/gl_train.txt"
+  i2i_edge_input: "easy_rec/data/graph_data/gl_i2i.txt"
+}
+graph_eval_input_path {
+  user_node_input: "easy_rec/data/graph_data/gl_user.txt"
+  item_node_input: "easy_rec/data/graph_data/gl_item.txt"
+  u2i_edge_input: "easy_rec/data/graph_data/gl_train.txt"
+  i2i_edge_input: "easy_rec/data/graph_data/gl_i2i.txt"
+}
+
+train_config {
+  log_step_count_steps: 100
+  optimizer_config: {
+    adam_optimizer: {
+      learning_rate: {
+        constant_learning_rate {
+          learning_rate: 1e-3
+        }
+      }
+    }
+    use_moving_average: false
+  }
+  save_checkpoints_steps: 2000
+  save_summary_steps: 100
+  sync_replicas: true
+  num_steps: 20000
+}
+
+eval_config {
+}
+
+data_config {
+  input_fields {
+    input_name: 'id'
+    input_type: INT64
+  }
+  ultra_gcn_sampler {
+    nbr_num: 10
+    neg_num: 10
+    neg_sampler: 'random'
+  }
+
+  batch_size: 512
+  num_epochs: 10
+  prefetch_size: 5
+  input_type: GraphInput
+}
+
+feature_config: {
+  features: {
+    input_names: 'id'
+    feature_type: IdFeature
+    embedding_dim: 128
+    hash_bucket_size: 100000
+  }
+}
+
+model_config: {
+  model_class: "ULTRAGCN"
+  ultragcn {
+    l2_regularization: 1e-6
+    user_num: 52643
+    item_num: 91599
+    output_dim: 128
+    nbr_num: 10
+    neg_num: 10
+    neg_weight: 10
+    i2i_weight: 2.75
+    l2_weight: 1e-4
+  }
+  loss_type: SOFTMAX_CROSS_ENTROPY
+  embedding_regularization: 0.0
+}
+
+export_config {
+
+}
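+
+# Illustrative launch command (an assumption, not part of this patch;
+# assumes the graph_data files above exist locally):
+#   python -m easy_rec.python.train_eval \
+#     --pipeline_config_path samples/model_config/ultragcn_on_graph.config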