diff --git a/docs/source/component/custom_op.md b/docs/source/component/custom_op.md
new file mode 100644
index 000000000..16a37c223
--- /dev/null
+++ b/docs/source/component/custom_op.md
@@ -0,0 +1,134 @@
+# Using Custom OPs
+
+When the built-in tf operators cannot meet business needs, or a requirement implemented by composing existing operators performs poorly, consider writing a custom tf OP.
+
+1. Implement the custom operator and compile it into a dynamic library
+   - See the official example: [TensorFlow Custom Op](https://github.com/tensorflow/custom-op/)
+   - Note: the tf version the custom Op is compiled against must match the tf version used at runtime
+   - You may need to compile two dynamic libraries against different dependency environments, one for offline training and one for the online inference service
+     - On the PAI platform, compile against tf 1.12 (download the official pai-tf image first)
+     - Using a custom Op in the EAS [EasyRec Processor](https://help.aliyun.com/zh/pai/user-guide/easyrec) requires compiling against tf 2.10.1
+1. Steps to use a custom Op in `EasyRec`
+   1. Download the latest EasyRec [source code](https://github.com/alibaba/EasyRec)
+   1. Put the dynamic library compiled in the previous step under the `easy_rec/python/ops/${tf_version}` directory; note that the tf version must match the subdirectory name
+   1. Develop a component that uses the custom Op (see the sketch below)
+      - Add the new component's code to `easy_rec/python/layers/keras/custom_ops.py`
+      - `custom_ops.py` provides an example of a component backed by a custom Op
+      - Declare the new component by adding an export statement to `easy_rec/python/layers/keras/__init__.py`
+   1. Write a model config file that builds the model in the component-based way and includes the newly defined component (see below)
+   1. Run the `pai_jobs/deploy_ext.sh` script to package EasyRec, and upload the resulting resource package (`easy_rec_ext_${version}_res.tar.gz`) to the MaxCompute project
+   1. Train & evaluate & export the model (in DataWorks or with the odpscmd client tool)
+
+## Exporting the custom Op's dynamic library into the saved_model's assets directory
+
+```bash
+pai -name easy_rec_ext
+-Dcmd='export'
+-Dconfig='oss://cold-start/EasyRec/custom_op/pipeline.config'
+-Dexport_dir='oss://cold-start/EasyRec/custom_op/export/final_with_lib'
+-Dextra_params='--asset_files oss://cold-start/EasyRec/config/libedit_distance.so'
+-Dres_project='pai_rec_test_dev'
+-Dversion='0.7.5'
+-Dbuckets='oss://cold-start/'
+-Darn='acs:ram::XXXXXXXXXX:role/aliyunodpspaidefaultrole'
+-DossHost='oss-cn-beijing-internal.aliyuncs.com'
+;
+```
+
+**Note**:
+
+1. In the train, evaluate and export commands, use `-Dres_project` to specify the name of the MaxCompute project the easyrec resource package was uploaded to
+1. In the train, evaluate and export commands, use `-Dversion` to specify the version of the resource package
+1. The dynamic library specified by the asset_files parameter is loaded by the online inference service, so it must be compiled against the same tf version the online inference service depends on (currently the EasyRec Processor on EAS depends on tf 2.10.1).
+   - If the asset_files parameter needs to specify other file paths as well (e.g. fg.json), separate the paths with commas.
+1. To emphasize once more: **the tf version the exported dynamic library depends on must match the tf version the inference service depends on**
+
+## Example of a custom Op
+
+Use a custom OP to compute the term match ratio between two pieces of input text
+
+```protobuf
+feature_config: {
+  ...
+  features: {
+    feature_name: 'raw_genres'
+    input_names: 'genres'
+    feature_type: PassThroughFeature
+  }
+  features: {
+    feature_name: 'raw_title'
+    input_names: 'title'
+    feature_type: PassThroughFeature
+  }
+}
+model_config: {
+  model_class: 'RankModel'
+  model_name: 'MLP'
+  feature_groups: {
+    group_name: 'text'
+    feature_names: 'raw_genres'
+    feature_names: 'raw_title'
+    wide_deep: DEEP
+  }
+  feature_groups: {
+    group_name: 'features'
+    feature_names: 'user_id'
+    feature_names: 'movie_id'
+    feature_names: 'gender'
+    feature_names: 'age'
+    feature_names: 'occupation'
+    feature_names: 'zip_id'
+    feature_names: 'movie_year_bin'
+    wide_deep: DEEP
+  }
+  backbone {
+    blocks {
+      name: 'text'
+      inputs {
+        feature_group_name: 'text'
+      }
+      raw_input {
+      }
+    }
+    blocks {
+      name: 'match_ratio'
+      inputs {
+        block_name: 'text'
+      }
+      keras_layer {
+        class_name: 'OverlapFeature'
+        overlap {
+          separator: " "
+          default_value: "0"
+          methods: "query_common_ratio"
+        }
+      }
+    }
+    blocks {
+      name: 'mlp'
+      inputs {
+        feature_group_name: 'features'
+      }
+      inputs {
+        block_name: 'match_ratio'
+      }
+      keras_layer {
+        class_name: 'MLP'
+        mlp {
+          hidden_units: [256, 128]
+        }
+      }
+    }
+  }
+  model_params {
+    l2_regularization: 1e-5
+  }
+  embedding_regularization: 1e-6
+}
+```
+
+1. If the custom Op needs to process raw input features, specify `feature_type: PassThroughFeature` when defining those features
+   - Features that are not of type `PassThroughFeature` are transformed during preprocessing, so the component code cannot get their raw values
+1. Put the raw input features the custom Op needs to process into the same `feature group`, in order
+1. Configure an input component of type `raw_input` to obtain the raw input features
+   - This is currently the only way EasyRec supports reading raw input features
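The component wrapper referenced in step 3 is typically only a few lines. Below is a minimal sketch of what such a wrapper can look like, modeled on the classes in `easy_rec/python/layers/keras/custom_ops.py`; the library name `libmy_ops.so` and the op name `my_term_overlap` are placeholders for your own op, not part of EasyRec.

```python
# Minimal sketch of an EasyRec component wrapping a custom op.
# Assumptions: a hypothetical libmy_ops.so exporting an op `my_term_overlap`
# was placed under easy_rec/python/ops/${tf_version} as described above.
import logging
import os

import tensorflow as tf
from tensorflow.python.keras.layers import Layer

import easy_rec

try:
  # easy_rec.ops_dir points at the tf-version-specific ops directory
  my_ops = tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libmy_ops.so'))
except Exception as ex:
  logging.warning('load libmy_ops.so failed: %s' % str(ex))
  my_ops = None


class TermOverlap(Layer):
  """Term match ratio of two raw text inputs, computed by the custom op."""

  def __init__(self, params, name='term_overlap', reuse=None, **kwargs):
    super(TermOverlap, self).__init__(name=name, **kwargs)
    self.separator = params.get_or_default('separator', ' ')

  def call(self, inputs, training=None, **kwargs):
    # inputs come from a `raw_input` block, so they are raw string tensors
    query, title = inputs[:2]
    return my_ops.my_term_overlap(
        query=query, title=title, separator=self.separator)
```

Remember to export the new class from `easy_rec/python/layers/keras/__init__.py`, otherwise the `class_name` lookup in the model config will not find it.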
diff --git a/docs/source/kd.md b/docs/source/kd.md
index 4ec2c4ae5..7e2de1270 100644
--- a/docs/source/kd.md
+++ b/docs/source/kd.md
@@ -20,7 +20,7 @@
 
 - label_is_logits: whether the target is logits or probs; defaults to logits
 
-- loss_type: type of the loss, can be CROSS_ENTROPY_LOSS or L2_LOSS
+- loss_type: type of the loss, can be CROSS_ENTROPY_LOSS, L2_LOSS, BINARY_CROSS_ENTROPY_LOSS, KL_DIVERGENCE_LOSS, PAIRWISE_HINGE_LOSS, LISTWISE_RANK_LOSS, etc.
 
 - loss_weight: weight of the loss; defaults to 1.0
 
@@ -63,6 +63,45 @@
 }
 ```
 
+Besides the conventional approach of "distilling" knowledge from the teacher model's predictions into the student model, in search and recommendation scenarios it is more advisable to learn
+the teacher's ranking of different items in a pairwise or listwise fashion (learning the partial order of its item predictions), for example:
+
+- pairwise knowledge distillation
+
+```protobuf
+  kd {
+    loss_name: 'ctcvr_rank_loss'
+    soft_label_name: 'pay_logits'
+    pred_name: 'logits'
+    loss_type: PAIRWISE_HINGE_LOSS
+    loss_weight: 1.0
+    pairwise_hinge_loss {
+      session_name: "raw_query"
+      use_exponent: false
+      use_label_margin: true
+    }
+  }
+```
+
+- listwise knowledge distillation
+
+```protobuf
+  kd {
+    loss_name: 'ctcvr_rank_loss'
+    soft_label_name: 'pay_logits'
+    pred_name: 'logits'
+    loss_type: LISTWISE_RANK_LOSS
+    loss_weight: 1.0
+    listwise_rank_loss {
+      session_name: "raw_query"
+      temperature: 3.0
+      label_is_logits: true
+    }
+  }
+```
+
+Parameters can be configured for these loss functions; see the [loss function](models/loss.md) parameter documentation for how.
+
 ### Training command
 
 The training command does not change; see [Model Training](./train.md) for details
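For intuition, the following toy numpy sketch mirrors what `PAIRWISE_HINGE_LOSS` with `use_label_margin: true` computes for the pairwise distillation example above: within each session, every pair the teacher ranks in one order is pushed to keep at least the teacher's logit gap. Temperature scaling and OHEM are omitted; this is an illustration, not the EasyRec implementation.

```python
# Toy rendition of pairwise hinge distillation (use_label_margin=true):
# for every in-session pair where the teacher scores i above j, the student
# is penalized when its logit gap is smaller than the teacher's gap.
import numpy as np

teacher = np.array([2.0, 0.5, 1.0, -1.0])  # soft labels, e.g. pay_logits
student = np.array([1.5, 1.0, 0.2, -0.5])  # student logits
session = np.array([1, 1, 2, 2])           # e.g. hashed raw_query

losses = []
for i in range(len(teacher)):
  for j in range(len(teacher)):
    if session[i] == session[j] and teacher[i] > teacher[j]:
      margin = teacher[i] - teacher[j]  # teacher's gap as the margin
      losses.append(max(0.0, margin - (student[i] - student[j])))
print(np.mean(losses))  # mean hinge loss over all valid pairs
```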
diff --git a/docs/source/models/bst.md b/docs/source/models/bst.md
index e6a62fda3..681b08a53 100644
--- a/docs/source/models/bst.md
+++ b/docs/source/models/bst.md
@@ -158,8 +158,8 @@ model_config: {
     group_name: 'sequence'
     feature_names: "cate_id"
     feature_names: "brand"
-    feature_names: "tag_brand_list"
     feature_names: "tag_category_list"
+    feature_names: "tag_brand_list"
     wide_deep: DEEP
   }
   backbone {
@@ -219,6 +219,7 @@ model_config: {
 - feature_groups: feature groups
   - contains two feature_groups: the dense and sparse groups
   - wide_deep: the BST model uses only Deep features, so they are all set to DEEP
+  - for how to configure the feature_group corresponding to sequence components, see the [reference doc](../component/sequence.md)
 - backbone: the backbone network built in the component-based way, [reference doc](../component/backbone.md)
   - blocks: a directed acyclic graph (DAG) made up of multiple `component blocks`; the framework executes the code logic associated with each `component block` in the DAG's topological order, building a subgraph of the TF Graph
   - name/inputs: each `block` has a unique name and one or more inputs and outputs
diff --git a/docs/source/models/cl4srec.md b/docs/source/models/cl4srec.md
index 665336c46..e08558566 100644
--- a/docs/source/models/cl4srec.md
+++ b/docs/source/models/cl4srec.md
@@ -157,6 +157,7 @@ model_config: {
 - use_package_input: when the `package`'s input is dynamic, set this input placeholder to indicate that the current `block`'s input is supplied when the `package` is invoked
 - keras_layer: loads the custom or built-in keras layer specified by `class_name` and runs its code logic; [reference doc](../component/backbone.md#keraslayer)
   - SeqAugment: sequence data augmentation component, see the [reference doc](../component/component.md#id5) for its parameters
+  - SeqAugmentOps: setting `class_name` to `SeqAugmentOps` uses the custom-OP version of the sequence data augmentation component, which performs better
   - AuxiliaryLoss: component that computes auxiliary-task losses, see the [reference doc](../component/component.md#id7) for its parameters
 - concat_blocks: the DAG's output nodes are defined by the `concat_blocks` option; if `concat_blocks` is not configured, the framework automatically concatenates all leaf nodes of the DAG as the output.
 - model_params:
diff --git a/docs/source/models/din.md b/docs/source/models/din.md
index b86444c81..b54f4f363 100644
--- a/docs/source/models/din.md
+++ b/docs/source/models/din.md
@@ -133,8 +133,8 @@ model_config: {
     group_name: 'sequence'
     feature_names: "cate_id"
     feature_names: "brand"
-    feature_names: "tag_brand_list"
     feature_names: "tag_category_list"
+    feature_names: "tag_brand_list"
     wide_deep: DEEP
   }
   backbone {
@@ -192,6 +192,7 @@ model_config: {
 - feature_groups: feature groups
   - contains two feature_groups: the dense and sparse groups
   - wide_deep: the DIN model uses only Deep features, so they are all set to DEEP
+  - for how to configure the feature_group corresponding to sequence components, see the [reference doc](../component/sequence.md)
 - backbone: the backbone network built in the component-based way, [reference doc](../component/backbone.md)
   - blocks: a directed acyclic graph (DAG) made up of multiple `component blocks`; the framework executes the code logic associated with each `component block` in the DAG's topological order, building a subgraph of the TF Graph
   - name/inputs: each `block` has a unique name and one or more inputs and outputs
diff --git a/docs/source/models/loss.md b/docs/source/models/loss.md
index e5d81bae8..d0c028d5d 100644
--- a/docs/source/models/loss.md
+++ b/docs/source/models/loss.md
@@ -10,6 +10,8 @@ EasyRec supports two ways of configuring loss functions: 1) a single loss function; 2
 | L2_LOSS                   | squared loss |
 | SIGMOID_L2_LOSS           | squared loss computed on the output of a sigmoid |
 | CROSS_ENTROPY_LOSS        | log loss (negative log-likelihood loss) |
+| BINARY_CROSS_ENTROPY_LOSS | BCE loss used only in knowledge distillation |
+| KL_DIVERGENCE_LOSS        | KL-divergence loss used only in knowledge distillation |
 | CIRCLE_LOSS               | dedicated to the CoMetricLearningI2I model |
 | MULTI_SIMILARITY_LOSS     | dedicated to the CoMetricLearningI2I model |
 | SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | multi-class softmax_cross_entropy with automatic negative sampling, used in binary classification tasks |
@@ -17,9 +19,12 @@
 | PAIR_WISE_LOSS            | rank loss aimed at optimizing global AUC |
 | PAIRWISE_FOCAL_LOSS       | pair-level focal loss, supports custom pair grouping |
 | PAIRWISE_LOGISTIC_LOSS    | pair-level logistic loss, supports custom pair grouping |
+| PAIRWISE_HINGE_LOSS       | pair-level hinge loss, supports custom pair grouping |
 | JRC_LOSS                  | binary classification + listwise ranking loss |
 | F1_REWEIGHTED_LOSS        | loss that adjusts the relative weight of recall and precision in binary classification, effective against positive/negative sample imbalance |
 | ORDER_CALIBRATE_LOSS      | auxiliary loss that calibrates predictions using dependencies between objectives; see the [AITM](aitm.md) model |
+| LISTWISE_RANK_LOSS        | listwise ranking loss |
+| LISTWISE_DISTILL_LOSS     | loss for distilling a given list order, quite similar to the listwise rank loss |
 
 - Note: SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
   - supports parameter configuration; upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317),
@@ -99,6 +104,16 @@
   - margin: the difference of a pair's logits has this value subtracted before entering the computation, i.e. the positive-negative logit gap must be at least margin; defaults to 0
   - temperature: temperature coefficient; logits are divided by this value before entering the computation, defaults to 1.0
 
+- Parameter configuration for PAIRWISE_HINGE_LOSS
+
+  - session_name: name of the field that groups pairs, e.g. user_id
+  - temperature: temperature coefficient; logits are divided by this value before entering the computation, defaults to 1.0
+  - margin: when the difference of a pair's logits is greater than this value, the loss of the current sample is 0; defaults to 1.0
+  - ohem_ratio: percentage of hard examples; only this fraction of hard examples enters the loss computation, defaults to 1.0
+  - label_is_logits: bool, marks whether the label is the teacher model's output logits; defaults to true
+  - use_label_margin: bool, whether to use the diff of the input pair's labels as the margin; when set to true the `margin` parameter has no effect, defaults to true
+  - use_exponent: bool, whether to apply a pairwise exponential transformation to the model outputs; defaults to false
+
 Note: all of the PAIRWISE\_\*\_LOSS losses above build positive-negative sample pairs within the mini-batch, aiming to make the logit gap of each pair as large as possible
 
 - Parameter configuration for BINARY_FOCAL_LOSS
@@ -115,6 +130,13 @@
   - reference paper: 《 [Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model](https://arxiv.org/pdf/2208.06164.pdf) 》
   - usage example: [dbmtl_with_jrc_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/dbmtl_on_taobao_with_multi_loss.config)
 
+- Parameter configuration for LISTWISE_RANK_LOSS
+
+  - temperature: temperature coefficient; logits are divided by this value before entering the computation, defaults to 1.0
+  - session_name: name of the field that groups lists, e.g. user_id
+  - label_is_logits: bool, marks whether the label is the teacher model's output logits; defaults to false
+  - scale_logits: bool, whether to linearly rescale the model's logits; defaults to false
+
 A complete example of a ranking model using multiple loss functions simultaneously:
 [cmbf_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/cmbf_with_multi_loss.config)
@@ -159,5 +181,6 @@
 ### Reference papers:
 
 - 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》
-- 《 [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603) 》
+- [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)
 - [AITM: Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising](https://arxiv.org/pdf/2105.08489.pdf)
+- [Pairwise Ranking Distillation for Deep Face Recognition](https://ceur-ws.org/Vol-2744/paper30.pdf)
diff --git
a/easy_rec/python/builders/loss_builder.py b/easy_rec/python/builders/loss_builder.py index ec4ab57c8..36cdd95b4 100644 --- a/easy_rec/python/builders/loss_builder.py +++ b/easy_rec/python/builders/loss_builder.py @@ -6,7 +6,10 @@ from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits from easy_rec.python.loss.jrc_loss import jrc_loss +from easy_rec.python.loss.listwise_loss import listwise_distill_loss +from easy_rec.python.loss.listwise_loss import listwise_rank_loss from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss +from easy_rec.python.loss.pairwise_loss import pairwise_hinge_loss from easy_rec.python.loss.pairwise_loss import pairwise_logistic_loss from easy_rec.python.loss.pairwise_loss import pairwise_loss from easy_rec.python.protos.loss_pb2 import LossType @@ -36,6 +39,9 @@ def build(loss_type, labels=label, logits=pred, weights=loss_weight, **kwargs) elif loss_type == LossType.CROSS_ENTROPY_LOSS: return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs) + elif loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS: + losses = tf.keras.backend.binary_crossentropy(label, pred, from_logits=True) + return tf.reduce_mean(losses) elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]: logging.info('%s is used' % LossType.Name(loss_type)) return tf.losses.mean_squared_error( @@ -72,6 +78,7 @@ def build(loss_type, hinge_margin = None if loss_param is not None and loss_param.HasField('hinge_margin'): hinge_margin = loss_param.hinge_margin + lbl_margin = False if loss_param is None else loss_param.use_label_margin return pairwise_logistic_loss( label, pred, @@ -80,6 +87,30 @@ def build(loss_type, hinge_margin=hinge_margin, ohem_ratio=ohem_ratio, weights=loss_weight, + use_label_margin=lbl_margin, + name=loss_name) + elif loss_type == LossType.PAIRWISE_HINGE_LOSS: + session = kwargs.get('session_ids', None) + temp, ohem_ratio, margin = 1.0, 1.0, 1.0 + label_is_logits, use_label_margin, use_exponent = True, True, False + if loss_param is not None: + temp = loss_param.temperature + ohem_ratio = loss_param.ohem_ratio + margin = loss_param.margin + label_is_logits = loss_param.label_is_logits + use_label_margin = loss_param.use_label_margin + use_exponent = loss_param.use_exponent + return pairwise_hinge_loss( + label, + pred, + session_ids=session, + temperature=temp, + margin=margin, + ohem_ratio=ohem_ratio, + weights=loss_weight, + label_is_logits=label_is_logits, + use_label_margin=use_label_margin, + use_exponent=use_exponent, name=loss_name) elif loss_type == LossType.PAIRWISE_FOCAL_LOSS: session = kwargs.get('session_ids', None) @@ -100,6 +131,42 @@ def build(loss_type, temperature=loss_param.temperature, weights=loss_weight, name=loss_name) + elif loss_type == LossType.LISTWISE_RANK_LOSS: + session = kwargs.get('session_ids', None) + trans_fn, temp, label_is_logits, scale = None, 1.0, False, False + if loss_param is not None: + temp = loss_param.temperature + label_is_logits = loss_param.label_is_logits + scale = loss_param.scale_logits + if loss_param.HasField('transform_fn'): + trans_fn = loss_param.transform_fn + return listwise_rank_loss( + label, + pred, + session, + temperature=temp, + label_is_logits=label_is_logits, + transform_fn=trans_fn, + scale_logits=scale, + weights=loss_weight) + elif loss_type == LossType.LISTWISE_DISTILL_LOSS: + session = kwargs.get('session_ids', None) + trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False + if loss_param is not None: + temp = loss_param.temperature + 
label_clip_max_value = loss_param.label_clip_max_value + scale = loss_param.scale_logits + if loss_param.HasField('transform_fn'): + trans_fn = loss_param.transform_fn + return listwise_distill_loss( + label, + pred, + session, + temperature=temp, + label_clip_max_value=label_clip_max_value, + transform_fn=trans_fn, + scale_logits=scale, + weights=loss_weight) elif loss_type == LossType.F1_REWEIGHTED_LOSS: f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square label_smoothing = 0 if loss_param is None else loss_param.label_smoothing @@ -130,13 +197,14 @@ def build(loss_type, raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type)) -def build_kd_loss(kds, prediction_dict, label_dict): +def build_kd_loss(kds, prediction_dict, label_dict, feature_dict): """Build knowledge distillation loss. Args: kds: list of knowledge distillation object of type KD. prediction_dict: dict of predict_name to predict tensors. label_dict: ordered dict of label_name to label tensors. + feature_dict: dict of feature name to feature value Return: knowledge distillation loss will be add to loss_dict with key: kd_loss. @@ -152,35 +220,106 @@ def build_kd_loss(kds, prediction_dict, label_dict): loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_') loss_name += '_' + kd.soft_label_name.replace('/', '_') + loss_weight = kd.loss_weight + if kd.HasField('task_space_indicator_name') and kd.HasField( + 'task_space_indicator_value'): + in_task_space = tf.to_float( + tf.equal(feature_dict[kd.task_space_indicator_name], + kd.task_space_indicator_value)) + loss_weight = loss_weight * ( + kd.in_task_space_weight * in_task_space + kd.out_task_space_weight * + (1 - in_task_space)) + label = label_dict[kd.soft_label_name] pred = prediction_dict[kd.pred_name] + epsilon = tf.keras.backend.epsilon() + num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1] - if kd.loss_type == LossType.CROSS_ENTROPY_LOSS: + if kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS: + if not kd.label_is_logits: # label is prob + label = tf.clip_by_value(label, epsilon, 1 - epsilon) + label = tf.log(label / (1 - label)) + if not kd.pred_is_logits: + pred = tf.clip_by_value(pred, epsilon, 1 - epsilon) + pred = tf.log(pred / (1 - pred)) + if kd.temperature > 0: + label = label / kd.temperature + pred = pred / kd.temperature + label = tf.nn.sigmoid(label) # convert to prob + elif kd.loss_type == LossType.KL_DIVERGENCE_LOSS: + if not kd.label_is_logits: # label is prob + if num_class == 1: # for binary classification + label = tf.clip_by_value(label, epsilon, 1 - epsilon) + label = tf.log(label / (1 - label)) + else: + label = tf.math.log(label + epsilon) + label -= tf.reduce_max(label) + if not kd.pred_is_logits: + if num_class == 1: # for binary classification + pred = tf.clip_by_value(pred, epsilon, 1 - epsilon) + pred = tf.log(pred / (1 - pred)) + else: + pred = tf.math.log(pred + epsilon) + pred -= tf.reduce_max(pred) + if kd.temperature > 0: + label = label / kd.temperature + pred = pred / kd.temperature + if num_class > 1: + label = tf.nn.softmax(label) + pred = tf.nn.softmax(pred) + else: + label = tf.nn.sigmoid(label) # convert to prob + pred = tf.nn.sigmoid(pred) # convert to prob + elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS: if not kd.label_is_logits: - label = tf.math.log(label + 1e-7) + label = tf.math.log(label + epsilon) if not kd.pred_is_logits: - pred = tf.math.log(pred + 1e-7) - - if kd.temperature > 0 and kd.loss_type == LossType.CROSS_ENTROPY_LOSS: - label = label / kd.temperature - 
pred = pred / kd.temperature - - if kd.loss_type == LossType.CROSS_ENTROPY_LOSS: - num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1] + pred = tf.math.log(pred + epsilon) + if kd.temperature > 0: + label = label / kd.temperature + pred = pred / kd.temperature if num_class > 1: label = tf.nn.softmax(label) pred = tf.nn.softmax(pred) elif num_class == 1: label = tf.nn.sigmoid(label) - pred = tf.nn.sigmoid(label) + pred = tf.nn.sigmoid(pred) - if kd.loss_type == LossType.CROSS_ENTROPY_LOSS: + if kd.loss_type == LossType.KL_DIVERGENCE_LOSS: + if num_class == 1: + label = tf.expand_dims(label, 1) # [B, 1] + labels = tf.concat([1 - label, label], axis=1) # [B, 2] + pred = tf.expand_dims(pred, 1) # [B, 1] + preds = tf.concat([1 - pred, pred], axis=1) # [B, 2] + else: + labels = label + preds = pred + losses = tf.keras.losses.KLD(labels, preds) + loss_dict[loss_name] = tf.reduce_mean( + losses, name=loss_name) * loss_weight + elif kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS: + losses = tf.keras.backend.binary_crossentropy( + label, pred, from_logits=True) + loss_dict[loss_name] = tf.reduce_mean( + losses, name=loss_name) * loss_weight + elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS: loss_dict[loss_name] = tf.losses.log_loss( - label, pred, weights=kd.loss_weight) + label, pred, weights=loss_weight) elif kd.loss_type == LossType.L2_LOSS: loss_dict[loss_name] = tf.losses.mean_squared_error( - labels=label, predictions=pred, weights=kd.loss_weight) + labels=label, predictions=pred, weights=loss_weight) else: - assert False, 'unsupported loss type for kd: %s' % LossType.Name( - kd.loss_type) + loss_param = kd.WhichOneof('loss_param') + kwargs = {} + if loss_param is not None: + loss_param = getattr(kd, loss_param) + if hasattr(loss_param, 'session_name'): + kwargs['session_ids'] = feature_dict[loss_param.session_name] + loss_dict[loss_name] = build( + kd.loss_type, + label, + pred, + loss_weight=loss_weight, + loss_param=loss_param, + **kwargs) return loss_dict diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py index fc850fb62..fe4c12132 100644 --- a/easy_rec/python/compat/early_stopping.py +++ b/easy_rec/python/compat/early_stopping.py @@ -21,9 +21,9 @@ import os import threading import time +from distutils.version import LooseVersion import tensorflow as tf -from distutils.version import LooseVersion from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops diff --git a/easy_rec/python/feature_column/feature_column.py b/easy_rec/python/feature_column/feature_column.py index 5fa7059c2..1b74ea344 100644 --- a/easy_rec/python/feature_column/feature_column.py +++ b/easy_rec/python/feature_column/feature_column.py @@ -132,7 +132,7 @@ def _cmp_embed_config(a, b): self.parse_sequence_feature(config) elif config.feature_type == config.ExprFeature: self.parse_expr_feature(config) - else: + elif config.feature_type != config.PassThroughFeature: assert False, 'invalid feature type: %s' % config.feature_type except FeatureKeyError: pass diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py index fba819d69..93220f047 100644 --- a/easy_rec/python/inference/predictor.py +++ b/easy_rec/python/inference/predictor.py @@ -19,6 +19,7 @@ from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import signature_constants +import easy_rec from easy_rec.python.utils import numpy_utils from 
easy_rec.python.utils.config_util import get_configs_from_pipeline_file from easy_rec.python.utils.config_util import get_input_name_from_fg_json @@ -26,6 +27,11 @@ from easy_rec.python.utils.input_utils import get_type_defaults from easy_rec.python.utils.load_class import get_register_class_meta +try: + tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libcustom_ops.so')) +except Exception as ex: + logging.warning('exception: %s' % str(ex)) + if tf.__version__ >= '2.0': tf = tf.compat.v1 diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py index d94b1de13..f53c4ee45 100644 --- a/easy_rec/python/input/input.py +++ b/easy_rec/python/input/input.py @@ -921,8 +921,16 @@ def _preprocess(self, field_dict): ], 'invalid label dtype: %s' % str(field_dict[input_name].dtype) label_dict[input_name] = field_dict[input_name] - if self._data_config.HasField('sample_weight'): - if self._mode != tf.estimator.ModeKeys.PREDICT: + if self._mode != tf.estimator.ModeKeys.PREDICT: + for func_config in self._data_config.extra_label_func: + lbl_name = func_config.label_name + func_name = func_config.label_func + logging.info('generating new label `%s` by transform: %s' % + (lbl_name, func_name)) + lbl_fn = load_by_path(func_name) + label_dict[lbl_name] = lbl_fn(label_dict) + + if self._data_config.HasField('sample_weight'): parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[ self._data_config.sample_weight] diff --git a/easy_rec/python/layers/backbone.py b/easy_rec/python/layers/backbone.py index e9425ed4d..5c503dd31 100644 --- a/easy_rec/python/layers/backbone.py +++ b/easy_rec/python/layers/backbone.py @@ -56,7 +56,7 @@ def __init__(self, config, features, input_layer, l2_reg=None): self._dag.add_node(block.name) self._name_to_blocks[block.name] = block layer = block.WhichOneof('layer') - if layer == 'input_layer': + if layer in {'input_layer', 'raw_input'}: if len(block.inputs) != 1: raise ValueError('input layer `%s` takes only one input' % block.name) one_input = block.inputs[0] @@ -71,8 +71,11 @@ def __init__(self, config, features, input_layer, l2_reg=None): if group in input_feature_groups: logging.warning('input `%s` already exists in other block' % group) else: - input_fn = EnhancedInputLayer(self._input_layer, self._features, - group, reuse) + if layer == 'input_layer': + input_fn = EnhancedInputLayer(self._input_layer, self._features, + group, reuse) + else: + input_fn = self._input_layer.get_raw_features(self._features, group) input_feature_groups[group] = input_fn self._name_to_layer[block.name] = input_fn else: @@ -91,7 +94,7 @@ def __init__(self, config, features, input_layer, l2_reg=None): num_pkg_input = 0 for block in config.blocks: layer = block.WhichOneof('layer') - if layer == 'input_layer': + if layer in {'input_layer', 'raw_input'}: continue name = block.name if name in input_feature_groups: @@ -270,6 +273,8 @@ def call(self, is_training, **kwargs): if layer is None: # identity layer output = self.block_input(config, block_outputs, is_training, **kwargs) block_outputs[block] = output + elif layer == 'raw_input': + block_outputs[block] = self._name_to_layer[block] elif layer == 'input_layer': input_fn = self._name_to_layer[block] input_config = config.input_layer @@ -440,7 +445,7 @@ def __call__(self, is_training, **kwargs): params = Parameter.make_from_pb(self._config.top_mlp) params.l2_regularizer = self._l2_reg final_mlp = MLP(params, name='backbone_top_mlp') - output = final_mlp(output, training=is_training) + output = final_mlp(output, training=is_training, 
**kwargs) return output @classmethod diff --git a/easy_rec/python/layers/input_layer.py b/easy_rec/python/layers/input_layer.py index b704190ce..088fb79e3 100644 --- a/easy_rec/python/layers/input_layer.py +++ b/easy_rec/python/layers/input_layer.py @@ -190,6 +190,21 @@ def get_sequence_feature(self, features, group_name): self._embedding_regularizer, weights_list=embedding_reg_lst) return seq_features + def get_raw_features(self, features, group_name): + """Get features by group_name. + + Args: + features: input tensor dict + group_name: feature_group name + + Return: + features: all raw features in list + """ + assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % ( + group_name, ','.join([x for x in self._feature_groups])) + feature_group = self._feature_groups[group_name] + return [features[x] for x in feature_group.feature_names] + def __call__(self, features, group_name, is_combine=True, is_dict=False): """Get features by group_name. diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py index f029b9c66..ffbd5c8c3 100644 --- a/easy_rec/python/layers/keras/__init__.py +++ b/easy_rec/python/layers/keras/__init__.py @@ -5,6 +5,11 @@ from .blocks import Highway from .blocks import TextCNN from .bst import BST +from .custom_ops import EditDistance +from .custom_ops import MappedDotProduct +from .custom_ops import OverlapFeature +from .custom_ops import SeqAugmentOps +from .custom_ops import TextNormalize from .data_augment import SeqAugment from .din import DIN from .fibinet import BiLinear diff --git a/easy_rec/python/layers/keras/attention.py b/easy_rec/python/layers/keras/attention.py index d7f717cb5..cf686356f 100644 --- a/easy_rec/python/layers/keras/attention.py +++ b/easy_rec/python/layers/keras/attention.py @@ -101,7 +101,7 @@ def build(self, input_shape): dtype=self.dtype, trainable=True, ) - self.built = True + super(Attention, self).build(input_shape) # Be sure to call this somewhere! def _calculate_scores(self, query, key): """Calculates attention scores as a query-key dot product. 
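Several changes in this diff replace a hand-written `self.built = True` at the end of `build()` with a call to `super().build(input_shape)`, and pass `name=name` as a keyword to `Layer.__init__`. The sketch below shows the pattern these fixes converge on; it is a generic Keras illustration, not EasyRec code.

```python
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
  """Minimal layer following the pattern used in this diff."""

  def __init__(self, name='scale', **kwargs):
    # pass name as a keyword: Layer.__init__'s first positional argument
    # is `trainable`, so `super().__init__(name)` would silently misbind it
    super(Scale, self).__init__(name=name, **kwargs)

  def build(self, input_shape):
    self.w = self.add_weight(
        name='w', shape=(int(input_shape[-1]),), initializer='ones')
    # let the base class mark the layer as built instead of
    # setting self.built = True by hand
    super(Scale, self).build(input_shape)

  def call(self, inputs):
    return inputs * self.w


layer = Scale()
x = tf.ones((2, 3))
layer(x)  # build() runs once on the first call
layer(x)  # already built; build() is skipped
```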
diff --git a/easy_rec/python/layers/keras/blocks.py b/easy_rec/python/layers/keras/blocks.py index 13cd14612..5c411922c 100644 --- a/easy_rec/python/layers/keras/blocks.py +++ b/easy_rec/python/layers/keras/blocks.py @@ -32,6 +32,7 @@ class MLP(Layer): def __init__(self, params, name='mlp', reuse=None, **kwargs): super(MLP, self).__init__(name=name, **kwargs) + self.layer_name = name params.check_required('hidden_units') use_bn = params.get_or_default('use_bn', True) use_final_bn = params.get_or_default('use_final_bn', True) @@ -50,6 +51,7 @@ def __init__(self, params, name='mlp', reuse=None, **kwargs): final_activation, use_bias, initializer, use_bn_after_act)) assert len(units) > 0, 'MLP(%s) takes at least one hidden units' % name self.reuse = reuse + self.add_to_outputs = params.get_or_default('add_to_outputs', False) num_dropout = len(dropout_rate) self._sub_layers = [] @@ -119,13 +121,17 @@ def call(self, x, training=None, **kwargs): add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS) else: x = layer(x) + if self.add_to_outputs and 'prediction_dict' in kwargs: + outputs = kwargs['prediction_dict'] + outputs[self.layer_name] = tf.squeeze(x, axis=1) + logging.info('add `%s` to model outputs' % self.layer_name) return x class Highway(Layer): def __init__(self, params, name='highway', reuse=None, **kwargs): - super(Highway, self).__init__(name, **kwargs) + super(Highway, self).__init__(name=name, **kwargs) self.emb_size = params.get_or_default('emb_size', None) self.num_layers = params.get_or_default('num_layers', 1) self.activation = params.get_or_default('activation', 'relu') @@ -175,7 +181,7 @@ class Gate(Layer): """Weighted sum gate.""" def __init__(self, params, name='gate', reuse=None, **kwargs): - super(Gate, self).__init__(name, **kwargs) + super(Gate, self).__init__(name=name, **kwargs) self.weight_index = params.get_or_default('weight_index', 0) def call(self, inputs, **kwargs): @@ -203,7 +209,7 @@ class TextCNN(Layer): """ def __init__(self, params, name='text_cnn', reuse=None, **kwargs): - super(TextCNN, self).__init__(name, **kwargs) + super(TextCNN, self).__init__(name=name, **kwargs) self.config = params.get_pb_config() self.pad_seq_length = self.config.pad_sequence_length if self.pad_seq_length <= 0: diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py new file mode 100644 index 000000000..f0b04b2ab --- /dev/null +++ b/easy_rec/python/layers/keras/custom_ops.py @@ -0,0 +1,247 @@ +# -*- encoding:utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. 
+"""Convenience blocks for using custom ops.""" +import logging +import os + +import tensorflow as tf +from tensorflow.python.framework import ops +from tensorflow.python.keras.layers import Layer + +curr_dir, _ = os.path.split(__file__) +parent_dir = os.path.dirname(curr_dir) +ops_idr = os.path.dirname(parent_dir) +ops_dir = os.path.join(ops_idr, 'ops') +if 'PAI' in tf.__version__: + ops_dir = os.path.join(ops_dir, '1.12_pai') +elif tf.__version__.startswith('1.12'): + ops_dir = os.path.join(ops_dir, '1.12') +elif tf.__version__.startswith('1.15'): + if 'IS_ON_PAI' in os.environ: + ops_dir = os.path.join(ops_dir, 'DeepRec') + else: + ops_dir = os.path.join(ops_dir, '1.15') +elif tf.__version__.startswith('2.12'): + ops_dir = os.path.join(ops_dir, '2.12') + +logging.info('ops_dir is %s' % ops_dir) +custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so') +try: + custom_ops = tf.load_op_library(custom_op_path) + logging.info('load custom op from %s succeed' % custom_op_path) +except Exception as ex: + logging.warning('load custom op from %s failed: %s' % + (custom_op_path, str(ex))) + custom_ops = None + +# if tf.__version__ >= '2.0': +# tf = tf.compat.v1 + + +class SeqAugmentOps(Layer): + """Do data augmentation for input sequence embedding.""" + + def __init__(self, params, name='sequence_aug', reuse=None, **kwargs): + super(SeqAugmentOps, self).__init__(name=name, **kwargs) + self.reuse = reuse + self.seq_aug_params = params.get_pb_config() + self.seq_augment = custom_ops.my_seq_augment + + def call(self, inputs, training=None, **kwargs): + assert isinstance( + inputs, + (list, tuple)), 'the inputs of SeqAugmentOps must be type of list/tuple' + assert len(inputs) >= 2, 'SeqAugmentOps must have at least 2 inputs' + seq_input, seq_len = inputs[:2] + embedding_dim = int(seq_input.shape[-1]) + with tf.variable_scope(self.name, reuse=self.reuse): + mask_emb = tf.get_variable( + 'mask', (embedding_dim,), dtype=tf.float32, trainable=True) + seq_len = tf.to_int32(seq_len) + aug_seq, aug_len = self.seq_augment(seq_input, seq_len, mask_emb, + self.seq_aug_params.crop_rate, + self.seq_aug_params.reorder_rate, + self.seq_aug_params.mask_rate) + return aug_seq, aug_len + + +class TextNormalize(Layer): + + def __init__(self, params, name='text_normalize', reuse=None, **kwargs): + super(TextNormalize, self).__init__(name=name, **kwargs) + self.txt_normalizer = custom_ops.text_normalize_op + self.norm_parameter = params.get_or_default('norm_parameter', 0) + self.remove_space = params.get_or_default('remove_space', False) + + def call(self, inputs, training=None, **kwargs): + inputs = inputs if type(inputs) in (tuple, list) else [inputs] + result = [ + self.txt_normalizer( + txt, parameter=self.norm_parameter, remove_space=self.remove_space) + for txt in inputs + ] + if len(result) == 1: + return result[0] + return result + + +class MappedDotProduct(Layer): + + def __init__(self, params, name='mapped_dot_product', reuse=None, **kwargs): + super(MappedDotProduct, self).__init__(name=name, **kwargs) + self.mapped_dot_product = custom_ops.mapped_dot_product + self.bucketize = custom_ops.my_bucketize + self.default_value = params.get_or_default('default_value', 0) + self.separator = params.get_or_default('separator', '\035') + self.norm_fn = params.get_or_default('normalize_fn', None) + self.boundaries = list(params.get_or_default('boundaries', [])) + self.emb_dim = params.get_or_default('embedding_dim', 0) + self.print_first_n = params.get_or_default('print_first_n', 0) + self.summarize = 
params.get_or_default('summarize', None)
+    if self.emb_dim > 0:
+      vocab_size = len(self.boundaries) + 1
+      with tf.variable_scope(self.name, reuse=reuse):
+        self.embedding_table = tf.get_variable(
+            name='dot_product_emb_table',
+            shape=[vocab_size, self.emb_dim],
+            dtype=tf.float32)
+
+  def call(self, inputs, training=None, **kwargs):
+    query, doc = inputs[:2]
+    with ops.device('/CPU:0'):
+      feature = self.mapped_dot_product(
+          query=query,
+          document=doc,
+          feature_name=self.name,
+          separator=self.separator,
+          default_value=self.default_value)
+    tf.summary.scalar(self.name, tf.reduce_mean(feature))
+    if self.print_first_n:
+      encode_q = tf.regex_replace(query, self.separator, ' ')
+      encode_t = tf.regex_replace(doc, self.separator, ' ')
+      feature = tf.Print(
+          feature, [encode_q, encode_t, feature],
+          message=self.name,
+          first_n=self.print_first_n,
+          summarize=self.summarize)
+    if self.norm_fn is not None:
+      fn = eval(self.norm_fn)
+      feature = fn(feature)
+      tf.summary.scalar('normalized_%s' % self.name, tf.reduce_mean(feature))
+      if self.print_first_n:
+        feature = tf.Print(
+            feature, [feature],
+            message='normalized %s' % self.name,
+            first_n=self.print_first_n,
+            summarize=self.summarize)
+    if self.boundaries:
+      feature = self.bucketize(feature, boundaries=self.boundaries)
+      tf.summary.histogram('bucketized_%s' % self.name, feature)
+    if self.emb_dim > 0 and self.boundaries:
+      vocab_size = len(self.boundaries) + 1
+      one_hot_input_ids = tf.one_hot(feature, depth=vocab_size)
+      return tf.matmul(one_hot_input_ids, self.embedding_table)
+    return tf.expand_dims(feature, axis=-1)
+
+
+class OverlapFeature(Layer):
+
+  def __init__(self, params, name='overlap_feature', reuse=None, **kwargs):
+    super(OverlapFeature, self).__init__(name=name, **kwargs)
+    self.overlap_feature = custom_ops.overlap_fg_op
+    methods = params.get_or_default('methods', [])
+    assert methods, 'overlap feature methods must be set'
+    self.methods = [str(method) for method in methods]
+    self.norm_fn = params.get_or_default('normalize_fn', None)
+    self.boundaries = list(params.get_or_default('boundaries', []))
+    self.separator = params.get_or_default('separator', '\035')
+    self.default_value = params.get_or_default('default_value', '-1')
+    self.emb_dim = params.get_or_default('embedding_dim', 0)
+    self.print_first_n = params.get_or_default('print_first_n', 0)
+    self.summarize = params.get_or_default('summarize', None)
+    if self.emb_dim > 0:
+      vocab_size = len(self.boundaries) + 1
+      vocab_size *= len(self.methods)
+      with tf.variable_scope(self.name, reuse=reuse):
+        self.embedding_table = tf.get_variable(
+            name='overlap_emb_table',
+            shape=[vocab_size, self.emb_dim],
+            dtype=tf.float32)
+
+  def call(self, inputs, training=None, **kwargs):
+    query, title = inputs[:2]
+    with ops.device('/CPU:0'):
+      feature = self.overlap_feature(
+          query=query,
+          title=title,
+          feature_name=self.name,
+          separator=self.separator,
+          default_value=self.default_value,
+          boundaries=self.boundaries,
+          methods=self.methods,
+          dtype=tf.int32 if self.boundaries else tf.float32)
+
+    for i, method in enumerate(self.methods):
+      # warning: feature[:, i] may not be the result of `method`
+      if self.boundaries:
+        tf.summary.histogram('bucketized_%s' % method, feature[:, i])
+      else:
+        tf.summary.scalar(method, tf.reduce_mean(feature[:, i]))
+    if self.print_first_n:
+      encode_q = tf.regex_replace(query, self.separator, ' ')
+      encode_t = tf.regex_replace(title, self.separator, ' ')
+      feature = tf.Print(
+          feature, [encode_q, encode_t, feature],
+          message=self.name,
+
first_n=self.print_first_n, + summarize=self.summarize) + if self.norm_fn is not None: + fn = eval(self.norm_fn) + feature = fn(feature) + + if self.emb_dim > 0 and self.boundaries: + # This vocab will be small so we always do one-hot here, since it is always + # faster for a small vocabulary. + batch_size = tf.shape(feature)[0] + vocab_size = len(self.boundaries) + 1 + num_indices = len(self.methods) + # Compute offsets, add to every column indices + offsets = tf.range(num_indices) * vocab_size # Shape: [3] + offsets = tf.reshape(offsets, [1, num_indices]) # Shape: [1, 3] + offsets = tf.tile(offsets, + [batch_size, 1]) # Shape: [batch_size, num_indices] + shifted_indices = feature + offsets # Shape: [batch_size, num_indices] + flat_feature_ids = tf.reshape(shifted_indices, [-1]) + one_hot_ids = tf.one_hot(flat_feature_ids, depth=vocab_size * num_indices) + feature_embeddings = tf.matmul(one_hot_ids, self.embedding_table) + feature_embeddings = tf.reshape(feature_embeddings, + [batch_size, num_indices * self.emb_dim]) + return feature_embeddings + return feature + + +class EditDistance(Layer): + + def __init__(self, params, name='edit_distance', reuse=None, **kwargs): + super(EditDistance, self).__init__(name=name, **kwargs) + self.edit_distance = custom_ops.my_edit_distance + self.txt_encoding = params.get_or_default('text_encoding', 'utf-8') + self.emb_size = params.get_or_default('embedding_size', 512) + emb_dim = params.get_or_default('embedding_dim', 4) + with tf.variable_scope(self.name, reuse=reuse): + self.embedding_table = tf.get_variable('embedding_table', + [self.emb_size, emb_dim], + tf.float32) + + def call(self, inputs, training=None, **kwargs): + input1, input2 = inputs[:2] + with ops.device('/CPU:0'): + dist = self.edit_distance( + input1, + input2, + normalize=False, + dtype=tf.int32, + encoding=self.txt_encoding) + ids = tf.clip_by_value(dist, 0, self.emb_size - 1) + embed = tf.nn.embedding_lookup(self.embedding_table, ids) + return embed diff --git a/easy_rec/python/layers/keras/data_augment.py b/easy_rec/python/layers/keras/data_augment.py index f48a27350..6f12f8666 100644 --- a/easy_rec/python/layers/keras/data_augment.py +++ b/easy_rec/python/layers/keras/data_augment.py @@ -99,7 +99,7 @@ class SeqAugment(Layer): """Do data augmentation for input sequence embedding.""" def __init__(self, params, name='seq_aug', reuse=None, **kwargs): - super(SeqAugment, self).__init__(name, **kwargs) + super(SeqAugment, self).__init__(name=name, **kwargs) self.reuse = reuse self.seq_aug_params = params.get_pb_config() diff --git a/easy_rec/python/layers/keras/fibinet.py b/easy_rec/python/layers/keras/fibinet.py index 52c762bb1..220c57cb5 100644 --- a/easy_rec/python/layers/keras/fibinet.py +++ b/easy_rec/python/layers/keras/fibinet.py @@ -30,7 +30,7 @@ class SENet(Layer): """ def __init__(self, params, name='SENet', reuse=None, **kwargs): - super(SENet, self).__init__(name, **kwargs) + super(SENet, self).__init__(name=name, **kwargs) self.config = params.get_pb_config() self.reuse = reuse if tf.__version__ >= '2.0': @@ -58,6 +58,7 @@ def build(self, input_shape): name='W1') self.excite_layer = Dense( units=emb_size, kernel_initializer='glorot_normal', name='W2') + super(SENet, self).build(input_shape) # Be sure to call this somewhere! 
def call(self, inputs, **kwargs): g = self.config.num_squeeze_group @@ -123,7 +124,7 @@ class BiLinear(Layer): """ def __init__(self, params, name='bilinear', reuse=None, **kwargs): - super(BiLinear, self).__init__(name, **kwargs) + super(BiLinear, self).__init__(name=name, **kwargs) self.reuse = reuse params.check_required(['num_output_units']) bilinear_plus = params.get_or_default('use_plus', True) @@ -169,6 +170,7 @@ def build(self, input_shape): units=int(input_shape[j][-1]), name='interaction_%d_%d' % (i, j)) for i, j in itertools.combinations(range(field_num), 2) ] + super(BiLinear, self).build(input_shape) # Be sure to call this somewhere! def call(self, inputs, **kwargs): embeddings = inputs @@ -210,7 +212,7 @@ class FiBiNet(Layer): """ def __init__(self, params, name='fibinet', reuse=None, **kwargs): - super(FiBiNet, self).__init__(name, **kwargs) + super(FiBiNet, self).__init__(name=name, **kwargs) self.reuse = reuse self._config = params.get_pb_config() diff --git a/easy_rec/python/layers/keras/interaction.py b/easy_rec/python/layers/keras/interaction.py index f3426d02a..b44f96f28 100644 --- a/easy_rec/python/layers/keras/interaction.py +++ b/easy_rec/python/layers/keras/interaction.py @@ -18,7 +18,7 @@ class FM(tf.keras.layers.Layer): """ def __init__(self, params, name='fm', reuse=None, **kwargs): - super(FM, self).__init__(name, **kwargs) + super(FM, self).__init__(name=name, **kwargs) self.use_variant = params.get_or_default('use_variant', False) def call(self, inputs, **kwargs): @@ -67,9 +67,9 @@ class DotInteraction(tf.keras.layers.Layer): """ def __init__(self, params, name=None, reuse=None, **kwargs): + super(DotInteraction, self).__init__(name=name, **kwargs) self._self_interaction = params.get_or_default('self_interaction', False) self._skip_gather = params.get_or_default('skip_gather', False) - super(DotInteraction, self).__init__(name=name, **kwargs) def call(self, inputs, **kwargs): """Performs the interaction operation on the tensors in the list. @@ -187,8 +187,8 @@ class Cross(tf.keras.layers.Layer): Output shape: A single (batch_size, `input_dim`) dimensional output. """ - def __init__(self, params, reuse=None, **kwargs): - super(Cross, self).__init__(**kwargs) + def __init__(self, params, name='cross', reuse=None, **kwargs): + super(Cross, self).__init__(name=name, **kwargs) self._projection_dim = params.get_or_default('projection_dim', None) self._diag_scale = params.get_or_default('diag_scale', 0.0) self._use_bias = params.get_or_default('use_bias', True) @@ -244,7 +244,7 @@ def build(self, input_shape): dtype=self.dtype, activation=self._preactivation, ) - self.built = True + super(Cross, self).build(input_shape) # Be sure to call this somewhere! def call(self, inputs, **kwargs): """Computes the feature cross. diff --git a/easy_rec/python/layers/keras/layer_norm.py b/easy_rec/python/layers/keras/layer_norm.py index 2f64984e5..7d6c81d5f 100644 --- a/easy_rec/python/layers/keras/layer_norm.py +++ b/easy_rec/python/layers/keras/layer_norm.py @@ -259,7 +259,8 @@ def build(self, input_shape): self.beta = None self._fused = self._fused_can_be_used(rank) - self.built = True + super(LayerNormalization, + self).build(input_shape) # Be sure to call this somewhere! 
def call(self, inputs): # Compute the axes along which to reduce the mean / variance diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py index de5624944..08ea698c0 100644 --- a/easy_rec/python/layers/keras/mask_net.py +++ b/easy_rec/python/layers/keras/mask_net.py @@ -25,7 +25,7 @@ class MaskBlock(Layer): """ def __init__(self, params, name='mask_block', reuse=None, **kwargs): - super(MaskBlock, self).__init__(name, **kwargs) + super(MaskBlock, self).__init__(name=name, **kwargs) self.config = params.get_pb_config() self.l2_reg = params.l2_regularizer self._projection_dim = params.get_or_default('projection_dim', None) @@ -33,9 +33,12 @@ def __init__(self, params, name='mask_block', reuse=None, **kwargs): self.final_relu = Activation('relu', name='relu') def build(self, input_shape): - assert len(input_shape) >= 2, 'MaskBlock must has at least two inputs' - input_dim = int(input_shape[0][-1]) - mask_input_dim = int(input_shape[1][-1]) + if type(input_shape) in (tuple, list): + assert len(input_shape) >= 2, 'MaskBlock must has at least two inputs' + input_dim = int(input_shape[0][-1]) + mask_input_dim = int(input_shape[1][-1]) + else: + input_dim, mask_input_dim = input_shape[-1], input_shape[-1] if self.config.HasField('reduction_factor'): aggregation_size = int(mask_input_dim * self.config.reduction_factor) elif self.config.HasField('aggregation_size') is not None: @@ -73,9 +76,14 @@ def build(self, input_shape): name='output_ln') else: self.output_layer_norm = LayerNormalization(name='output_ln') + super(MaskBlock, self).build(input_shape) def call(self, inputs, **kwargs): - net, mask_input = inputs + if type(inputs) in (tuple, list): + net, mask_input = inputs[:2] + else: + net, mask_input = inputs, inputs + if self.config.input_layer_norm: net = self.input_layer_norm(net) @@ -103,7 +111,7 @@ class MaskNet(Layer): """ def __init__(self, params, name='mask_net', reuse=None, **kwargs): - super(MaskNet, self).__init__(name, **kwargs) + super(MaskNet, self).__init__(name=name, **kwargs) self.reuse = reuse self.params = params self.config = params.get_pb_config() diff --git a/easy_rec/python/layers/keras/multi_task.py b/easy_rec/python/layers/keras/multi_task.py index ec9f1e5cf..d092fd4d8 100644 --- a/easy_rec/python/layers/keras/multi_task.py +++ b/easy_rec/python/layers/keras/multi_task.py @@ -24,7 +24,7 @@ class MMoE(tf.keras.layers.Layer): """Multi-gate Mixture-of-Experts model.""" def __init__(self, params, name='MMoE', reuse=None, **kwargs): - super(MMoE, self).__init__(name, **kwargs) + super(MMoE, self).__init__(name=name, **kwargs) params.check_required(['num_expert', 'num_task']) self._reuse = reuse self._num_expert = params.num_expert diff --git a/easy_rec/python/layers/keras/numerical_embedding.py b/easy_rec/python/layers/keras/numerical_embedding.py index 5fa32904e..a5205dfed 100644 --- a/easy_rec/python/layers/keras/numerical_embedding.py +++ b/easy_rec/python/layers/keras/numerical_embedding.py @@ -107,7 +107,7 @@ class PeriodicEmbedding(tf.keras.layers.Layer): """ def __init__(self, params, name='periodic_embedding', reuse=None, **kwargs): - super(PeriodicEmbedding, self).__init__(name, **kwargs) + super(PeriodicEmbedding, self).__init__(name=name, **kwargs) self.reuse = reuse params.check_required(['embedding_dim', 'sigma']) self.embedding_dim = int(params.embedding_dim) @@ -164,7 +164,7 @@ class AutoDisEmbedding(tf.keras.layers.Layer): """ def __init__(self, params, name='auto_dis_embedding', reuse=None, **kwargs): - 
super(AutoDisEmbedding, self).__init__(name, **kwargs)
+    super(AutoDisEmbedding, self).__init__(name=name, **kwargs)
     self.reuse = reuse
     params.check_required(['embedding_dim', 'num_bins', 'temperature'])
     self.emb_dim = int(params.embedding_dim)
diff --git a/easy_rec/python/loss/contrastive_loss.py b/easy_rec/python/loss/contrastive_loss.py
index 842c2c695..3fd2be645 100644
--- a/easy_rec/python/loss/contrastive_loss.py
+++ b/easy_rec/python/loss/contrastive_loss.py
@@ -48,8 +48,8 @@ def info_nce_loss(query, positive, temperature=0.1):
 
 
 def get_mask_matrix(batch_size):
-  mat = tf.ones((batch_size, batch_size), bool)
-  diag = tf.fill([batch_size], False)
+  mat = tf.ones((batch_size, batch_size), dtype=tf.bool)
+  diag = tf.zeros([batch_size], dtype=tf.bool)
   mask = tf.linalg.set_diag(mat, diag)
   mask = tf.tile(mask, [2, 2])
   return mask
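The new `listwise_loss.py` that follows groups samples by session and applies a softmax cross entropy within each group. As a reading aid, here is a toy numpy rendition of `listwise_rank_loss` with `label_is_logits=True` (temperature, `transform_fn` and `scale_logits` omitted); an illustration, not the implementation itself.

```python
# Toy rendition of the listwise rank loss defined below: per session,
# softmax-normalize the teacher labels into a target distribution and
# score it against the student's log-softmax; average over sessions.
import numpy as np


def softmax(x):
  e = np.exp(x - x.max())
  return e / e.sum()


labels = np.array([2.0, 0.5, 1.0, -1.0])  # teacher logits
logits = np.array([1.5, 1.0, 0.2, -0.5])  # student logits
session = np.array([1, 1, 2, 2])

losses = []
for s in np.unique(session):
  mask = session == s
  y = softmax(labels[mask])              # target distribution for the session
  log_p = np.log(softmax(logits[mask]))  # student log-probabilities
  losses.append(-(y * log_p).sum())      # cross entropy for this session
print(np.mean(losses))                   # mean over sessions
```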
diff --git a/easy_rec/python/loss/listwise_loss.py b/easy_rec/python/loss/listwise_loss.py
new file mode 100644
index 000000000..24bd5864f
--- /dev/null
+++ b/easy_rec/python/loss/listwise_loss.py
@@ -0,0 +1,161 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.utils.load_class import load_by_path
+
+
+def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
+  mask = tf.equal(x, session_ids)
+  logits = tf.boolean_mask(logits, mask)
+  labels = tf.boolean_mask(labels, mask)
+  y = tf.nn.softmax(labels) if label_is_logits else labels
+  y_hat = tf.nn.log_softmax(logits)
+  return -tf.reduce_sum(y * y_hat)
+
+
+def _list_prob_loss(x, labels, logits, session_ids):
+  mask = tf.equal(x, session_ids)
+  logits = tf.boolean_mask(logits, mask)
+  labels = tf.boolean_mask(labels, mask)
+  y = labels / tf.reduce_sum(labels)
+  y_hat = tf.nn.log_softmax(logits)
+  return -tf.reduce_sum(y * y_hat)
+
+
+def listwise_rank_loss(labels,
+                       logits,
+                       session_ids,
+                       transform_fn=None,
+                       temperature=1.0,
+                       label_is_logits=False,
+                       scale_logits=False,
+                       weights=1.0,
+                       name='listwise_loss'):
+  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+  Definition:
+  $$
+  \mathcal{L}(\{y\}, \{s\}) =
+  -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+  $$
+
+  Args:
+    labels: A `Tensor` of the same shape as `logits` representing graded
+      relevance.
+    logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids used to
+      group samples into lists, e.g. user_id.
+    transform_fn: an affine transformation function of labels
+    temperature: (Optional) The temperature to use for scaling the logits.
+    label_is_logits: Whether `labels` is expected to be a logits tensor.
+      By default, we consider that `labels` encodes a probability distribution.
+    scale_logits: Whether to scale the logits.
+    weights: sample weights
+    name: the name of loss
+  """
+  loss_name = name if name else 'listwise_rank_loss'
+  logging.info('[{}] temperature: {}, scale logits: {}'.format(
+      loss_name, temperature, scale_logits))
+  labels = tf.to_float(labels)
+  if scale_logits:
+    with tf.variable_scope(loss_name):
+      w = tf.get_variable(
+          'scale_w',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.ones_initializer())
+      b = tf.get_variable(
+          'scale_b',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.zeros_initializer())
+    logits = logits * tf.abs(w) + b
+  if temperature != 1.0:
+    logits /= temperature
+    if label_is_logits:
+      labels /= temperature
+  if transform_fn is not None:
+    trans_fn = load_by_path(transform_fn)
+    labels = trans_fn(labels)
+
+  sessions, _ = tf.unique(tf.squeeze(session_ids))
+  tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+  losses = tf.map_fn(
+      lambda x: _list_wise_loss(x, labels, logits, session_ids,
+                                label_is_logits),
+      sessions,
+      dtype=tf.float32)
+  if tf.is_numeric_tensor(weights):
+    logging.error('[%s] use unsupported sample weight' % loss_name)
+    return tf.reduce_mean(losses)
+  else:
+    return tf.reduce_mean(losses) * weights
+
+
+def listwise_distill_loss(labels,
+                          logits,
+                          session_ids,
+                          transform_fn=None,
+                          temperature=1.0,
+                          label_clip_max_value=512,
+                          scale_logits=False,
+                          weights=1.0,
+                          name='listwise_distill_loss'):
+  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+  Definition:
+  $$
+  \mathcal{L}(\{y\}, \{s\}) =
+  -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+  $$
+
+  Args:
+    labels: A `Tensor` of the same shape as `logits` representing the rank
+      position of a base model.
+    logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids used to
+      group samples into lists, e.g. user_id.
+    transform_fn: a transformation function of labels.
+    temperature: (Optional) The temperature to use for scaling the logits.
+    label_clip_max_value: clip the labels to this value.
+    scale_logits: Whether to scale the logits.
+ weights: sample weights + name: the name of loss + """ + loss_name = name if name else 'listwise_rank_loss' + logging.info('[{}] temperature: {}'.format(loss_name, temperature)) + labels = tf.to_float(labels) # supposed to be positions of a teacher model + labels = tf.clip_by_value(labels, 1, label_clip_max_value) + if transform_fn is not None: + trans_fn = load_by_path(transform_fn) + labels = trans_fn(labels) + else: + labels = tf.log1p(label_clip_max_value) - tf.log(labels) + + if scale_logits: + with tf.variable_scope(loss_name): + w = tf.get_variable( + 'scale_w', + dtype=tf.float32, + shape=(1,), + initializer=tf.ones_initializer()) + b = tf.get_variable( + 'scale_b', + dtype=tf.float32, + shape=(1,), + initializer=tf.zeros_initializer()) + logits = logits * tf.abs(w) + b + if temperature != 1.0: + logits /= temperature + + sessions, _ = tf.unique(tf.squeeze(session_ids)) + tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions)) + losses = tf.map_fn( + lambda x: _list_prob_loss(x, labels, logits, session_ids), + sessions, + dtype=tf.float32) + if tf.is_numeric_tensor(weights): + logging.error('[%s] use unsupported sample weight' % loss_name) + return tf.reduce_mean(losses) + else: + return tf.reduce_mean(losses) * weights diff --git a/easy_rec/python/loss/pairwise_loss.py b/easy_rec/python/loss/pairwise_loss.py index a421cdbba..5b48f8dff 100644 --- a/easy_rec/python/loss/pairwise_loss.py +++ b/easy_rec/python/loss/pairwise_loss.py @@ -132,6 +132,7 @@ def pairwise_logistic_loss(labels, hinge_margin=None, weights=1.0, ohem_ratio=1.0, + use_label_margin=False, name=''): r"""Computes pairwise logistic loss between `labels` and `logits`. @@ -150,6 +151,7 @@ def pairwise_logistic_loss(labels, hinge_margin: the margin between positive and negative logits weights: A scalar, a `Tensor` with shape [batch_size] for each sample ohem_ratio: the percent of hard examples to be mined + use_label_margin: whether to use the diff `label[i]-label[j]` as margin name: the name of loss """ loss_name = name if name else 'pairwise_logistic_loss' @@ -159,11 +161,14 @@ def pairwise_logistic_loss(labels, if temperature != 1.0: logits /= temperature + if use_label_margin: + labels /= temperature pairwise_logits = tf.math.subtract( tf.expand_dims(logits, -1), tf.expand_dims(logits, 0)) + pairwise_labels = tf.math.subtract( + tf.expand_dims(labels, -1), tf.expand_dims(labels, 0)) - pairwise_mask = tf.greater( - tf.expand_dims(labels, -1) - tf.expand_dims(labels, 0), 0) + pairwise_mask = tf.greater(pairwise_labels, 0) if hinge_margin is not None: hinge_mask = tf.less(pairwise_logits, hinge_margin) pairwise_mask = tf.logical_and(pairwise_mask, hinge_mask) @@ -173,6 +178,8 @@ def pairwise_logistic_loss(labels, tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0)) pairwise_mask = tf.logical_and(pairwise_mask, group_equal) + if use_label_margin: + pairwise_logits -= pairwise_labels pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask) num_pair = tf.size(pairwise_logits) tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair) @@ -200,3 +207,100 @@ def pairwise_logistic_loss(labels, topk = tf.nn.top_k(losses, k) losses = tf.boolean_mask(topk.values, topk.values > 0) return tf.reduce_mean(losses) + + +def pairwise_hinge_loss(labels, + logits, + session_ids=None, + temperature=1.0, + margin=1.0, + weights=1.0, + ohem_ratio=1.0, + label_is_logits=True, + use_label_margin=True, + use_exponent=False, + name=''): + r"""Computes pairwise hinge loss between `labels` and `logits`. 
+
+  Definition:
+  $$
+  \mathcal{L}(\{y\}, \{s\}) =
+  \sum_i \sum_j I[y_i > y_j] \max(0, 1 - (s_i - s_j))
+  $$
+
+  Args:
+    labels: A `Tensor` of the same shape as `logits` representing graded
+      relevance.
+    logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids used to
+      group pairs, e.g. user_id.
+    temperature: (Optional) The temperature to use for scaling the logits.
+    margin: the margin between positive and negative logits
+    weights: A scalar, a `Tensor` with shape [batch_size] for each sample
+    ohem_ratio: the percent of hard examples to be mined
+    label_is_logits: Whether `labels` is expected to be a logits tensor.
+    use_label_margin: whether to use the diff `label[i]-label[j]` as margin
+    use_exponent: whether to use exponential difference
+    name: the name of loss
+  """
+  loss_name = name if name else 'pairwise_hinge_loss'
+  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+  logging.info(
+      '[{}] margin: {}, ohem_ratio: {}, temperature: {}, use_exponent: {}, label_is_logits: {}, use_label_margin: {}'
+      .format(loss_name, margin, ohem_ratio, temperature, use_exponent,
+              label_is_logits, use_label_margin))
+
+  if temperature != 1.0:
+    logits /= temperature
+    if label_is_logits:
+      labels /= temperature
+  if use_exponent:
+    labels = tf.nn.sigmoid(labels)
+    logits = tf.nn.sigmoid(logits)
+
+  pairwise_logits = tf.math.subtract(
+      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
+  pairwise_labels = tf.math.subtract(
+      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
+
+  pairwise_mask = tf.greater(pairwise_labels, 0)
+  if session_ids is not None:
+    logging.info('[%s] use session ids' % loss_name)
+    group_equal = tf.equal(
+        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
+    pairwise_mask = tf.logical_and(pairwise_mask, group_equal)
+
+  pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
+  pairwise_labels = tf.boolean_mask(pairwise_labels, pairwise_mask)
+  num_pair = tf.size(pairwise_logits)
+  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)
+
+  if use_label_margin:
+    diff = pairwise_labels - pairwise_logits
+  else:
+    diff = margin - pairwise_logits
+  if use_exponent:
+    threshold = 88.0  # the max value of float32 is 3.4028235e+38
+    safe_diff = tf.clip_by_value(diff, -threshold, threshold)
+    losses = tf.nn.relu(tf.exp(safe_diff) - 1.0)
+  else:
+    losses = tf.nn.relu(diff)
+
+  if tf.is_numeric_tensor(weights):
+    logging.info('[%s] use sample weight' % loss_name)
+    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
+    batch_size, _ = get_shape_list(weights, 2)
+    pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
+    pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
+  else:
+    pairwise_weights = weights
+
+  if ohem_ratio == 1.0:
+    return compute_weighted_loss(losses, pairwise_weights)
+
+  losses = compute_weighted_loss(
+      losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
+  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
+  k = tf.to_int32(tf.math.rint(k))
+  topk = tf.nn.top_k(losses, k)
+  losses = tf.boolean_mask(topk.values, topk.values > 0)
+  return tf.reduce_mean(losses)
diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py
index c80d49c95..e35d69030 100644
--- a/easy_rec/python/model/dssm.py
+++ b/easy_rec/python/model/dssm.py
@@ -62,11 +62,11 @@ def build_predict_graph(self):
         name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))
 
     if self._model_config.simi_func ==
diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py
index c80d49c95..e35d69030 100644
--- a/easy_rec/python/model/dssm.py
+++ b/easy_rec/python/model/dssm.py
@@ -62,11 +62,11 @@ def build_predict_graph(self):
         name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))

     if self._model_config.simi_func == Similarity.COSINE:
-      user_tower_emb = self.norm(user_tower_emb)
-      item_tower_emb = self.norm(item_tower_emb)
-      temperature = self._model_config.temperature
-    else:
-      temperature = 1.0
+      user_tower_emb = self.norm(user_tower_emb)
+      item_tower_emb = self.norm(item_tower_emb)
+      temperature = self._model_config.temperature
+    else:
+      temperature = 1.0

     user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
     if self._model_config.scale_simi:
diff --git a/easy_rec/python/model/match_model.py b/easy_rec/python/model/match_model.py
index 294bb3dd4..9a3bb9187 100644
--- a/easy_rec/python/model/match_model.py
+++ b/easy_rec/python/model/match_model.py
@@ -192,7 +192,7 @@ def _build_point_wise_loss_graph(self):
     # build kd loss
     kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
-                                              self._labels)
+                                              self._labels, self._feature_dict)
     self._loss_dict.update(kd_loss_dict)
     return self._loss_dict
diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py
index fa6ce8948..6361138bf 100644
--- a/easy_rec/python/model/multi_task_model.py
+++ b/easy_rec/python/model/multi_task_model.py
@@ -42,6 +42,8 @@ def build_predict_graph(self):
     model = self._model_config.WhichOneof('model')
     assert model == 'model_params', '`model_params` must be configured'
     config = self._model_config.model_params
+    for out in config.outputs:
+      self._outputs.append(out)

     self._init_towers(config.task_towers)

@@ -187,10 +189,12 @@ def build_loss_weight(self):
       losses = task_tower_cfg.losses
       n = len(losses)
       if n > 0:
-        loss_weights[tower_name] = [loss.weight for loss in losses]
+        loss_weights[tower_name] = [
+            loss.weight * task_tower_cfg.weight for loss in losses
+        ]
         num_loss += n
       else:
-        loss_weights[tower_name] = [1.0]
+        loss_weights[tower_name] = [task_tower_cfg.weight]
         num_loss += 1

     strategy = self._base_model_config.loss_weight_strategy
@@ -223,7 +227,7 @@ def build_loss_graph(self):
     task_loss_weights = self.build_loss_weight()
     for task_tower_cfg in self._task_towers:
       tower_name = task_tower_cfg.tower_name
-      loss_weight = task_tower_cfg.weight
+      loss_weight = 1.0
       if task_tower_cfg.use_sample_weight:
         loss_weight *= self._sample_weight

@@ -284,13 +288,15 @@ def build_loss_graph(self):
       self._loss_dict.update(loss_dict)

     kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
-                                              self._labels)
+                                              self._labels, self._feature_dict)
     self._loss_dict.update(kd_loss_dict)
     return self._loss_dict

   def get_outputs(self):
     outputs = []
+    if self._outputs:
+      outputs.extend(self._outputs)
     for task_tower_cfg in self._task_towers:
       tower_name = task_tower_cfg.tower_name
       if len(task_tower_cfg.losses) == 0:
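The `build_loss_weight` / `build_loss_graph` changes above fold the task tower's weight into the per-loss weights, while `build_loss_graph` now starts from `1.0`, so a tower's weight is applied exactly once instead of twice. A minimal sketch of the resulting weighting, using hypothetical stand-in classes rather than the real protobuf configs:

```python
# Plain-Python sketch of the new weighting scheme; `Loss` and `TaskTower`
# are illustrative stand-ins for the protobuf config messages.
class Loss:
  def __init__(self, weight):
    self.weight = weight

class TaskTower:
  def __init__(self, tower_name, weight, losses):
    self.tower_name = tower_name
    self.weight = weight
    self.losses = losses

def build_loss_weight(task_towers):
  loss_weights = {}
  for cfg in task_towers:
    if cfg.losses:
      # Each configured loss is scaled by its own weight times the tower weight.
      loss_weights[cfg.tower_name] = [l.weight * cfg.weight for l in cfg.losses]
    else:
      # A tower without explicit losses contributes one loss at the tower weight.
      loss_weights[cfg.tower_name] = [cfg.weight]
  return loss_weights

towers = [TaskTower('ctr', 1.0, [Loss(1.0)]),
          TaskTower('cvr', 0.5, [Loss(1.0), Loss(0.2)])]
print(build_loss_weight(towers))  # {'ctr': [1.0], 'cvr': [0.5, 0.1]}
```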
diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py
index 79e271483..fc8e5214c 100644
--- a/easy_rec/python/model/rank_model.py
+++ b/easy_rec/python/model/rank_model.py
@@ -27,13 +27,23 @@ def __init__(self,
     self._num_class = self._model_config.num_class
     self._losses = self._model_config.losses
     if self._labels is not None:
-      self._label_name = list(self._labels.keys())[0]
+      if model_config.HasField('label_name'):
+        self._label_name = model_config.label_name
+      else:
+        self._label_name = list(self._labels.keys())[0]
+    self._outputs = []

   def build_predict_graph(self):
     if not self.has_backbone:
       raise NotImplementedError(
           'method `build_predict_graph` must be implemented when backbone network do not exits'
       )
+    model = self._model_config.WhichOneof('model')
+    assert model == 'model_params', '`model_params` must be configured'
+    config = self._model_config.model_params
+    for out in config.outputs:
+      self._outputs.append(out)
+
     output = self.backbone
     if int(output.shape[-1]) != self._num_class:
       logging.info('add head logits layer for rank model')
@@ -50,7 +60,9 @@ def _output_to_prediction_impl(self,
     binary_loss_type = {
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS
+        LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS,
+        LossType.LISTWISE_DISTILL_LOSS
     }
     if loss_type in binary_loss_type:
       assert num_class == 1, 'num_class must be 1 when loss type is %s' % loss_type.name
@@ -64,6 +76,7 @@ def _output_to_prediction_impl(self,
       probs = tf.nn.softmax(output, axis=1)
       tf.summary.scalar('prediction/probs', tf.reduce_mean(probs[:, 1]))
       prediction_dict['logits' + suffix] = output
+      prediction_dict['pos_logits' + suffix] = output[:, 1]
       prediction_dict['probs' + suffix] = probs[:, 1]
     elif loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
@@ -75,7 +88,9 @@ def _output_to_prediction_impl(self,
       else:
         probs = tf.nn.softmax(output, axis=1)
         prediction_dict['logits' + suffix] = output
+        prediction_dict['logits' + suffix + '_1'] = output[:, 1]
         prediction_dict['probs' + suffix] = probs
+        prediction_dict['probs' + suffix + '_1'] = probs[:, 1]
         prediction_dict['logits' + suffix + '_y'] = math_ops.reduce_max(
             output, axis=1)
         prediction_dict['probs' + suffix + '_y'] = math_ops.reduce_max(
@@ -122,7 +137,8 @@ def build_rtp_output_dict(self):
         LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
-        LossType.JRC_LOSS
+        LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
+        LossType.LISTWISE_RANK_LOSS
     }
     if loss_types & binary_loss_set:
       if 'probs' in self._prediction_dict:
@@ -163,9 +179,13 @@ def _build_loss_impl(self,
     binary_loss_type = {
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS
+        LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS,
+        LossType.LISTWISE_DISTILL_LOSS
     }
-    if loss_type == LossType.CLASSIFICATION:
+    if loss_type in {
+        LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS
+    }:
       loss_name = loss_name if loss_name else 'cross_entropy_loss' + suffix
       pred = self._prediction_dict['logits' + suffix]
     elif loss_type in binary_loss_type:
@@ -247,7 +267,7 @@ def build_loss_graph(self):
     # build kd loss
     kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
-                                              self._labels)
+                                              self._labels, self._feature_dict)
     self._loss_dict.update(kd_loss_dict)
     return self._loss_dict
@@ -266,7 +286,8 @@ def _build_metric_impl(self,
         LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
-        LossType.JRC_LOSS
+        LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
+        LossType.LISTWISE_RANK_LOSS
     }
     metric_dict = {}
     if metric.WhichOneof('metric') == 'auc':
@@ -404,19 +425,23 @@ def build_metric_graph(self, eval_config):

   def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
     binary_loss_set = {
-        LossType.F1_REWEIGHTED_LOSS, LossType.JRC_LOSS, LossType.PAIR_WISE_LOSS,
+        LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
        LossType.BINARY_FOCAL_LOSS, 
LossType.PAIRWISE_FOCAL_LOSS, - LossType.PAIRWISE_LOGISTIC_LOSS + LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS, + LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS } if loss_type in binary_loss_set: return ['probs' + suffix, 'logits' + suffix] + if loss_type == LossType.JRC_LOSS: + return ['probs' + suffix, 'pos_logits' + suffix] if loss_type == LossType.CLASSIFICATION: if num_class == 1: return ['probs' + suffix, 'logits' + suffix] else: return [ 'y' + suffix, 'probs' + suffix, 'logits' + suffix, - 'probs' + suffix + '_y', 'logits' + suffix + '_y' + 'probs' + suffix + '_y', 'logits' + suffix + '_y', + 'probs' + suffix + '_1', 'logits' + suffix + '_1' ] elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]: return ['y' + suffix] @@ -425,9 +450,14 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''): def get_outputs(self): if len(self._losses) == 0: - return self._get_outputs_impl(self._loss_type, self._num_class) + outputs = self._get_outputs_impl(self._loss_type, self._num_class) + if self._outputs: + outputs.extend(self._outputs) + return list(set(outputs)) all_outputs = [] + if self._outputs: + all_outputs.extend(self._outputs) for loss in self._losses: outputs = self._get_outputs_impl(loss.loss_type, self._num_class) all_outputs.extend(outputs) diff --git a/easy_rec/python/ops/1.12/libcustom_ops.so b/easy_rec/python/ops/1.12/libcustom_ops.so new file mode 100755 index 000000000..ee2e09322 Binary files /dev/null and b/easy_rec/python/ops/1.12/libcustom_ops.so differ diff --git a/easy_rec/python/ops/1.12_pai/libcustom_ops.so b/easy_rec/python/ops/1.12_pai/libcustom_ops.so new file mode 100755 index 000000000..76bbc78f4 Binary files /dev/null and b/easy_rec/python/ops/1.12_pai/libcustom_ops.so differ diff --git a/easy_rec/python/ops/1.15/libcustom_ops.so b/easy_rec/python/ops/1.15/libcustom_ops.so new file mode 100755 index 000000000..6fac5578f Binary files /dev/null and b/easy_rec/python/ops/1.15/libcustom_ops.so differ diff --git a/easy_rec/python/ops/2.12/libcustom_ops.so b/easy_rec/python/ops/2.12/libcustom_ops.so new file mode 100755 index 000000000..4739657eb Binary files /dev/null and b/easy_rec/python/ops/2.12/libcustom_ops.so differ diff --git a/easy_rec/python/ops/DeepRec/libcustom_ops.so b/easy_rec/python/ops/DeepRec/libcustom_ops.so new file mode 100755 index 000000000..6fac5578f Binary files /dev/null and b/easy_rec/python/ops/DeepRec/libcustom_ops.so differ diff --git a/easy_rec/python/protos/backbone.proto b/easy_rec/python/protos/backbone.proto index 86589a297..8bf5fd7cc 100644 --- a/easy_rec/python/protos/backbone.proto +++ b/easy_rec/python/protos/backbone.proto @@ -17,6 +17,9 @@ message InputLayer { optional bool concat_seq_feature = 10 [default = true]; } +message RawInputLayer { +} + message Lambda { required string expression = 1; } @@ -78,6 +81,7 @@ message Block { KerasLayer keras_layer = 103; RecurrentLayer recurrent = 104; RepeatLayer repeat = 105; + RawInputLayer raw_input = 106; } } diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto index 5ffefd064..ca19dcd04 100644 --- a/easy_rec/python/protos/dataset.proto +++ b/easy_rec/python/protos/dataset.proto @@ -176,6 +176,14 @@ message DatasetConfig { // are labels have dimension > 1 repeated uint32 label_dim = 42; + message LabelFunction { + required string label_name = 1; + required string label_func = 2; + } + + // extra transformation functions that generate new labels + repeated LabelFunction extra_label_func = 43; + // whether 
to shuffle data optional bool shuffle = 5 [default = true]; diff --git a/easy_rec/python/protos/dnn.proto b/easy_rec/python/protos/dnn.proto index 4a9c1743a..10fb6631f 100644 --- a/easy_rec/python/protos/dnn.proto +++ b/easy_rec/python/protos/dnn.proto @@ -29,4 +29,5 @@ message MLP { optional string initializer = 8 [default = 'he_uniform']; optional bool use_bn_after_activation = 9; optional bool use_final_bias = 10 [default = false]; + optional bool add_to_outputs = 11 [default = false]; } diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto index b0f79fe0f..56f5b713e 100644 --- a/easy_rec/python/protos/easy_rec_model.proto +++ b/easy_rec/python/protos/easy_rec_model.proto @@ -35,7 +35,7 @@ message DummyModel { // configure backbone network common parameters message ModelParams { optional float l2_regularization = 1; - + repeated string outputs = 2; repeated BayesTaskTower task_towers = 3; } @@ -49,11 +49,33 @@ message KD { required string soft_label_name = 21; // default to be logits optional bool label_is_logits = 22 [default=true]; - // currently only support CROSS_ENTROPY_LOSS and L2_LOSS required LossType loss_type = 3; optional float loss_weight = 4 [default=1.0]; - // only for loss_type == CROSS_ENTROPY_LOSS + // only for loss_type == CROSS_ENTROPY_LOSS or BINARY_CROSS_ENTROPY_LOSS or KL_DIVERGENCE_LOSS optional float temperature = 5 [default=1.0]; + // field name for indicating the sample space for this task + optional string task_space_indicator_name = 6; + // field value for indicating the sample space for this task + optional string task_space_indicator_value = 7; + // the loss weight for sample in the task space + optional float in_task_space_weight = 8 [default = 1.0]; + // the loss weight for sample out the task space + optional float out_task_space_weight = 9 [default = 1.0]; + + oneof loss_param { + F1ReweighedLoss f1_reweighted_loss = 101; + SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102; + CircleLoss circle_loss = 103; + MultiSimilarityLoss multi_simi_loss = 104; + BinaryFocalLoss binary_focal_loss = 105; + PairwiseLoss pairwise_loss = 106; + PairwiseFocalLoss pairwise_focal_loss = 107; + PairwiseLogisticLoss pairwise_logistic_loss = 108; + JRCLoss jrc_loss = 109; + PairwiseHingeLoss pairwise_hinge_loss = 110; + ListwiseRankLoss listwise_rank_loss = 111; + ListwiseDistillLoss listwise_distill_loss = 112; + } } message EasyRecModel { @@ -123,4 +145,7 @@ message EasyRecModel { required LossWeightStrategy loss_weight_strategy = 16 [default = Fixed]; optional BackboneTower backbone = 17; + + // label name for rank_model to select one label between multiple labels + optional string label_name = 18; } diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto index 19a09f131..8c0c0c214 100644 --- a/easy_rec/python/protos/feature_config.proto +++ b/easy_rec/python/protos/feature_config.proto @@ -43,6 +43,7 @@ message FeatureConfig { LookupFeature = 4; SequenceFeature = 5; ExprFeature = 6; + PassThroughFeature = 7; } enum FieldType { diff --git a/easy_rec/python/protos/keras_layer.proto b/easy_rec/python/protos/keras_layer.proto index a8b92d1a7..d9b8b2e11 100644 --- a/easy_rec/python/protos/keras_layer.proto +++ b/easy_rec/python/protos/keras_layer.proto @@ -27,5 +27,7 @@ message KerasLayer { PPNet ppnet = 16; TextCNN text_cnn = 17; HighWayTower highway = 18; + OverlapFeature overlap = 19; + MappedDotProduct dot_product = 20; } } diff --git a/easy_rec/python/protos/layer.proto 
b/easy_rec/python/protos/layer.proto index c0a01686a..63f32a85a 100644 --- a/easy_rec/python/protos/layer.proto +++ b/easy_rec/python/protos/layer.proto @@ -98,3 +98,24 @@ message TextCNN { optional string activation = 4 [default = 'relu']; optional MLP mlp = 5; } + +message OverlapFeature { + optional string separator = 1; + optional string default_value = 2; + repeated string methods = 3; + optional string normalize_fn = 4; + repeated float boundaries = 5; + optional int32 embedding_dim = 6; + optional int32 print_first_n = 7 [default = 0]; + optional int32 summarize = 8; +} + +message MappedDotProduct { + optional string separator = 1; + optional float default_value = 2; + optional string normalize_fn = 3; + repeated float boundaries = 4; + optional int32 embedding_dim = 5; + optional int32 print_first_n = 6 [default = 0]; + optional int32 summarize = 7; +} diff --git a/easy_rec/python/protos/loss.proto b/easy_rec/python/protos/loss.proto index 5098518b3..b377cd75c 100644 --- a/easy_rec/python/protos/loss.proto +++ b/easy_rec/python/protos/loss.proto @@ -16,8 +16,13 @@ enum LossType { BINARY_FOCAL_LOSS = 10; PAIRWISE_FOCAL_LOSS = 11; PAIRWISE_LOGISTIC_LOSS = 12; + PAIRWISE_HINGE_LOSS = 17; JRC_LOSS = 13; ORDER_CALIBRATE_LOSS = 14; + BINARY_CROSS_ENTROPY_LOSS = 15; + KL_DIVERGENCE_LOSS = 16; + LISTWISE_RANK_LOSS = 18; + LISTWISE_DISTILL_LOSS = 19; } message Loss { @@ -35,6 +40,9 @@ message Loss { PairwiseFocalLoss pairwise_focal_loss = 107; PairwiseLogisticLoss pairwise_logistic_loss = 108; JRCLoss jrc_loss = 109; + PairwiseHingeLoss pairwise_hinge_loss = 110; + ListwiseRankLoss listwise_rank_loss = 111; + ListwiseDistillLoss listwise_distill_loss = 112; } }; @@ -87,8 +95,19 @@ message PairwiseFocalLoss { message PairwiseLogisticLoss { required float temperature = 1 [default = 1.0]; optional string session_name = 2; - optional float hinge_margin = 3 [default = 1.0]; + optional float hinge_margin = 3; optional float ohem_ratio = 4 [default = 1.0]; + optional bool use_label_margin = 5 [default = false]; +} + +message PairwiseHingeLoss { + required float temperature = 1 [default = 1.0]; + optional string session_name = 2; + optional float margin = 3 [default = 1.0]; + optional float ohem_ratio = 4 [default = 1.0]; + optional bool label_is_logits = 5 [default = true]; + optional bool use_label_margin = 6 [default = true]; + optional bool use_exponent = 7 [default = false]; } message JRCLoss { @@ -97,3 +116,19 @@ message JRCLoss { optional bool same_label_loss = 3 [default = true]; required string loss_weight_strategy = 4 [default = 'fixed']; } + +message ListwiseRankLoss { + required float temperature = 1 [default = 1.0]; + optional string session_name = 2; + optional string transform_fn = 3; + optional bool label_is_logits = 4 [default = false]; + optional bool scale_logits = 5 [default = false]; +} + +message ListwiseDistillLoss { + required float temperature = 1 [default = 1.0]; + optional string session_name = 2; + optional string transform_fn = 3; + optional float label_clip_max_value = 4 [default = 512.0]; + optional bool scale_logits = 5 [default = false]; +} diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 68d0b8656..0f7b82a28 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -7,11 +7,11 @@ import threading import time import unittest +from distutils.version import LooseVersion import numpy as np import six import tensorflow as tf -from distutils.version import LooseVersion from 
tensorflow.python.platform import gfile from easy_rec.python.main import predict @@ -409,6 +409,15 @@ def test_highway(self): 'samples/model_config/highway_on_movielens.config', self._test_dir) self.assertTrue(self._success) + # @unittest.skipIf( + # LooseVersion(tf.__version__) >= LooseVersion('2.0.0'), + # 'has no CustomOp when tf version == 2.4') + # def test_custom_op(self): + # self._success = test_utils.test_single_train_eval( + # 'samples/model_config/cl4srec_on_taobao_with_custom_op.config', + # self._test_dir) + # self.assertTrue(self._success) + def test_cdn(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/cdn_on_taobao.config', self._test_dir) diff --git a/easy_rec/version.py b/easy_rec/version.py index 2ae7769a6..2a148c007 100644 --- a/easy_rec/version.py +++ b/easy_rec/version.py @@ -1,4 +1,4 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -__version__ = '0.8.1' +__version__ = '0.8.2' diff --git a/samples/model_config/bst_backbone_on_taobao.config b/samples/model_config/bst_backbone_on_taobao.config index a46f0e09e..b801f87ef 100644 --- a/samples/model_config/bst_backbone_on_taobao.config +++ b/samples/model_config/bst_backbone_on_taobao.config @@ -257,8 +257,8 @@ model_config: { group_name: 'sequence' feature_names: "cate_id" feature_names: "brand" - feature_names: "tag_brand_list" feature_names: "tag_category_list" + feature_names: "tag_brand_list" wide_deep: DEEP } backbone { diff --git a/samples/model_config/cl4srec_on_taobao_with_custom_op.config b/samples/model_config/cl4srec_on_taobao_with_custom_op.config new file mode 100644 index 000000000..18ee40dde --- /dev/null +++ b/samples/model_config/cl4srec_on_taobao_with_custom_op.config @@ -0,0 +1,375 @@ +train_input_path: "data/test/tb_data/taobao_train_data" +eval_input_path: "data/test/tb_data/taobao_test_data" +model_dir: "experiments/cl4srec_on_taobao_custom_op_ckpt" + +train_config { + log_step_count_steps: 100 + optimizer_config: { + adam_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 0.00001 + } + } + } + use_moving_average: false + } + save_checkpoints_steps: 100 + sync_replicas: True + num_steps: 10 +} + +eval_config { + metrics_set: { + auc {} + } +} + +data_config { + input_fields { + input_name:'clk' + input_type: INT32 + } + input_fields { + input_name:'buy' + input_type: INT32 + } + input_fields { + input_name: 'pid' + input_type: STRING + } + input_fields { + input_name: 'adgroup_id' + input_type: STRING + } + input_fields { + input_name: 'cate_id' + input_type: STRING + } + input_fields { + input_name: 'campaign_id' + input_type: STRING + } + input_fields { + input_name: 'customer' + input_type: STRING + } + input_fields { + input_name: 'brand' + input_type: STRING + } + input_fields { + input_name: 'user_id' + input_type: STRING + } + input_fields { + input_name: 'cms_segid' + input_type: STRING + } + input_fields { + input_name: 'cms_group_id' + input_type: STRING + } + input_fields { + input_name: 'final_gender_code' + input_type: STRING + } + input_fields { + input_name: 'age_level' + input_type: STRING + } + input_fields { + input_name: 'pvalue_level' + input_type: STRING + } + input_fields { + input_name: 'shopping_level' + input_type: STRING + } + input_fields { + input_name: 'occupation' + input_type: STRING + } + input_fields { + input_name: 'new_user_class_level' + input_type: STRING + } + input_fields { + input_name: 
'tag_category_list' + input_type: STRING + } + input_fields { + input_name: 'tag_brand_list' + input_type: STRING + } + input_fields { + input_name: 'price' + input_type: INT32 + } + + label_fields: 'clk' + batch_size: 1024 + num_epochs: 1 + prefetch_size: 1 + input_type: CSVInput +} + +feature_config: { + features: { + input_names: 'pid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'adgroup_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'cate_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + embedding_name: 'cate_id' + } + features: { + input_names: 'campaign_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'customer' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'brand' + feature_type: IdFeature + embedding_dim: 16 + embedding_name: 'brand' + hash_bucket_size: 100000 + } + features: { + input_names: 'user_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'cms_segid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: 'cms_group_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: 'final_gender_code' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'age_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'pvalue_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'shopping_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'occupation' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'new_user_class_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'tag_category_list' + feature_type: SequenceFeature + separator: '|' + hash_bucket_size: 10000 + embedding_dim: 16 + embedding_name: 'cate_id' + max_seq_len: 50 + } + features: { + input_names: 'tag_brand_list' + feature_type: SequenceFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 + embedding_name: 'brand' + max_seq_len: 50 + } + features: { + input_names: 'price' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 + } +} +model_config: { + model_name: 'CL4SRec' + model_class: 'RankModel' + feature_groups: { + group_name: 'item' + feature_names: 'adgroup_id' + feature_names: 'campaign_id' + feature_names: 'cate_id' + feature_names: 'brand' + feature_names: 'customer' + feature_names: 'price' + feature_names: 'pid' + wide_deep: DEEP + } + feature_groups: { + group_name: 'user' + feature_names: 'user_id' + feature_names: 'cms_segid' + feature_names: 'cms_group_id' + feature_names: 'age_level' + feature_names: 'pvalue_level' + feature_names: 'shopping_level' + feature_names: 'occupation' + feature_names: 'new_user_class_level' + wide_deep: DEEP + } + feature_groups: { + group_name: 'user_seq' + feature_names: "tag_brand_list" + feature_names: "tag_category_list" + wide_deep: DEEP + } + backbone { + blocks { + name: 'user_seq' + inputs { + feature_group_name: 'user_seq' + } + input_layer { + output_seq_and_normal_feature: true + } + } + packages { + name: 
'seq_augment' + blocks { + name: 'augment' + inputs { + block_name: 'user_seq' + } + keras_layer { + class_name: 'SeqAugmentOps' + seq_aug { + mask_rate: 0.6 + crop_rate: 0.2 + reorder_rate: 0.6 + } + } + } + } + packages { + name: 'seq_encoder' + blocks { + name: 'BST' + inputs { + use_package_input: true + } + keras_layer { + class_name: 'BST' + bst { + hidden_size: 128 + num_attention_heads: 2 + num_hidden_layers: 2 + intermediate_size: 128 + hidden_act: 'gelu' + max_position_embeddings: 50 + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0 + output_all_token_embeddings: false + } + } + } + } + blocks { + name: 'contrastive' + inputs { + package_name: 'seq_encoder' + package_input: 'seq_augment' + } + inputs { + package_name: 'seq_encoder' + package_input: 'seq_augment' + } + merge_inputs_into_list: true + keras_layer { + class_name: 'AuxiliaryLoss' + st_params { + fields { + key: 'loss_type' + value: { string_value: 'nce_loss' } + } + fields { + key: 'loss_weight' + value: { number_value: 0.1 } + } + fields { + key: 'temperature' + value: { number_value: 0.15 } + } + } + } + } + blocks { + name: 'main' + inputs { + package_name: 'seq_encoder' + package_input: 'user_seq' + } + inputs { + feature_group_name: 'user' + } + inputs { + feature_group_name: 'item' + } + } + concat_blocks: 'main' + top_mlp { + hidden_units: [256, 64] + } + } + model_params { + l2_regularization: 0 + } + embedding_regularization: 0 +} + +export_config { + multi_placeholder: false +} diff --git a/samples/model_config/din_backbone_on_taobao.config b/samples/model_config/din_backbone_on_taobao.config index d5e705747..7cb48ac56 100644 --- a/samples/model_config/din_backbone_on_taobao.config +++ b/samples/model_config/din_backbone_on_taobao.config @@ -257,8 +257,8 @@ model_config: { group_name: 'sequence' feature_names: "cate_id" feature_names: "brand" - feature_names: "tag_brand_list" feature_names: "tag_category_list" + feature_names: "tag_brand_list" wide_deep: DEEP } backbone { diff --git a/samples/model_config/dssm_distribute_eval_reg_on_taobao.config b/samples/model_config/dssm_distribute_eval_reg_on_taobao.config index b8dd91fe4..093d7c4a5 100644 --- a/samples/model_config/dssm_distribute_eval_reg_on_taobao.config +++ b/samples/model_config/dssm_distribute_eval_reg_on_taobao.config @@ -275,7 +275,7 @@ model_config:{ } embedding_regularization: 5e-5 loss_type: L2_LOSS - + } export_config { diff --git a/setup.cfg b/setup.cfg index b43211827..337833a0f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos
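Alongside the config and proto changes, this PR checks in prebuilt `libcustom_ops.so` binaries under per-TF-version directories (`easy_rec/python/ops/1.12`, `1.12_pai`, `1.15`, `2.12`, `DeepRec`), which back components such as `SeqAugmentOps` used in the sample config above. As a rough illustration of how such a per-version library can be picked up at runtime (a hypothetical sketch; the path lookup below is illustrative and not EasyRec's actual loader code):

```python
import os

import tensorflow as tf

# Hypothetical loader sketch: pick the bundled custom-op library matching
# the running TensorFlow version, mirroring the per-version directories
# added in this PR, then load it so the custom ops become resolvable.
tf_dir = '.'.join(tf.__version__.split('.')[:2])  # e.g. '2.12'
so_path = os.path.join('easy_rec', 'python', 'ops', tf_dir, 'libcustom_ops.so')
custom_ops = tf.load_op_library(so_path)
```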