-
Notifications
You must be signed in to change notification settings - Fork 28
/
esmm.py
110 lines (94 loc) · 5.1 KB
/
esmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-
# @Time : 2021-04-13 14:42
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : ESMM-style multi-task model in PyTorch (shared feature embeddings, one DNN tower per task)
import torch
import torch.nn as nn
class ESMM(nn.Module):
    """Multi-task network with per-feature embeddings and one DNN tower per task.

    Each feature whose declared unique count is greater than 1 is treated as
    categorical and gets its own ``nn.Embedding``; every other feature is used
    as a single raw scalar column.  All columns are concatenated (user features
    first, then item features) and fed to ``num_task`` independent towers made
    of ``Linear -> BatchNorm1d -> Dropout`` stages followed by a final
    ``Linear`` layer.

    NOTE(review): no non-linearity is applied between the tower stages and no
    sigmoid is applied to the outputs -- confirm this is intended.

    NOTE: embeddings are registered on the module under the feature name
    itself, so a feature name must not clash with an existing ``nn.Module``
    attribute, and a name used in both dicts ends up with a single embedding
    (the item one replaces the user one).
    """

    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, hidden_dim=(128, 64),
                 dropouts=(0.5, 0.5), output_size=1, num_task=2):
        """
        esmm model input parameters
        :param user_feature_dict: user feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param item_feature_dict: item feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param emb_dim: int, embedding size
        :param hidden_dim: sequence of hidden sizes shared by every task tower
        :param dropouts: sequence of dropout probabilities, at least one per hidden layer
        :param output_size: int, output size of each task tower
        :param num_task: int default 2, number of task towers
        :raises TypeError: if either feature dict is None or not a dict
        :raises ValueError: if ``dropouts`` has fewer entries than ``hidden_dim``
        """
        super(ESMM, self).__init__()

        # check input parameters
        if user_feature_dict is None or item_feature_dict is None:
            raise TypeError("input parameter user_feature_dict and item_feature_dict must be not None")
        if not isinstance(user_feature_dict, dict) or not isinstance(item_feature_dict, dict):
            raise TypeError("input parameter user_feature_dict and item_feature_dict must be dict")

        # Copy to lists so tuples (or any other sequence) are accepted and the
        # caller's objects are never aliased.
        hidden_dim = list(hidden_dim)
        dropouts = list(dropouts)
        if len(dropouts) < len(hidden_dim):
            raise ValueError(
                "dropouts must provide one probability per hidden layer: got {} for {} layers".format(
                    len(dropouts), len(hidden_dim)))

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_task = num_task

        # embedding initialisation (user features first, then item features)
        user_cate_feature_nums = self._register_embeddings(user_feature_dict, emb_dim)
        item_cate_feature_nums = self._register_embeddings(item_feature_dict, emb_dim)

        # Tower input width: emb_dim per categorical feature plus one scalar
        # column per remaining (dense) feature.
        hidden_size = (emb_dim * (user_cate_feature_nums + item_cate_feature_nums)
                       + (len(user_feature_dict) - user_cate_feature_nums)
                       + (len(item_feature_dict) - item_cate_feature_nums))

        # independent DNN tower for every task
        layer_sizes = [hidden_size] + hidden_dim
        for task_idx in range(1, self.num_task + 1):
            tower = nn.ModuleList()
            for j, (in_size, out_size) in enumerate(zip(layer_sizes, layer_sizes[1:])):
                tower.add_module('ctr_hidden_{}'.format(j), nn.Linear(in_size, out_size))
                tower.add_module('ctr_batchnorm_{}'.format(j), nn.BatchNorm1d(out_size))
                tower.add_module('ctr_dropout_{}'.format(j), nn.Dropout(dropouts[j]))
            tower.add_module('task_last_layer', nn.Linear(layer_sizes[-1], output_size))
            setattr(self, 'task_{}_dnn'.format(task_idx), tower)

    def _register_embeddings(self, feature_dict, emb_dim):
        """Create one embedding per categorical feature and return how many were created."""
        cate_feature_nums = 0
        for feature_name, spec in feature_dict.items():
            # spec[0] is the number of unique values; > 1 marks a categorical feature.
            if spec[0] > 1:
                cate_feature_nums += 1
                setattr(self, feature_name, nn.Embedding(spec[0], emb_dim))
        return cate_feature_nums

    def forward(self, x):
        """Run every task tower on the embedded input.

        :param x: tensor whose second dimension holds one column per declared
            feature; each feature is read from the column index stored in its
            dict entry
        :return: list with one tower output tensor per task
        :raises ValueError: if ``x`` does not have one column per declared feature
        """
        expected_columns = len(self.item_feature_dict) + len(self.user_feature_dict)
        if x.dim() < 2 or x.size(1) != expected_columns:
            raise ValueError("expected input with {} feature columns, got shape {}".format(
                expected_columns, tuple(x.shape)))

        # embedding lookup / raw scalar extraction, user features first
        columns = []
        for feature_dict in (self.user_feature_dict, self.item_feature_dict):
            for feature_name, spec in feature_dict.items():
                values = x[:, spec[1]]
                if spec[0] > 1:
                    columns.append(getattr(self, feature_name)(values.long()))
                else:
                    # Cast dense columns explicitly instead of relying on
                    # dtype promotion inside torch.cat.
                    columns.append(values.unsqueeze(1).float())

        # fuse all feature columns into the shared tower input
        hidden = torch.cat(columns, dim=1).float()

        # task towers
        task_outputs = []
        for task_idx in range(1, self.num_task + 1):
            out = hidden
            for layer in getattr(self, 'task_{}_dnn'.format(task_idx)):
                out = layer(out)
            task_outputs.append(out)
        return task_outputs
if __name__ == "__main__":
    import numpy as np

    # Smoke test: push a tiny hand-written batch through a default-sized model
    # and print the per-task outputs.  Columns 0-3 hold categorical ids and
    # columns 4-5 hold dense values, matching the indices declared below.
    sample_rows = [
        [1, 2, 4, 2, 0.5, 0.1],
        [4, 5, 3, 8, 0.6, 0.43],
        [6, 3, 2, 9, 0.12, 0.32],
        [9, 1, 1, 1, 0.12, 0.45],
        [8, 3, 1, 4, 0.21, 0.67],
    ]
    batch = torch.from_numpy(np.array(sample_rows))

    # {feature_name: (feature_unique_num, feature_index)}
    user_features = dict(user_id=(11, 0), user_list=(12, 3), user_num=(1, 4))
    item_features = dict(item_id=(8, 1), item_cate=(6, 2), item_num=(1, 5))

    model = ESMM(user_features, item_features)
    outputs = model(batch)
    print(outputs)