

[feat] add vector retrieve (#109)
* add vector retrieve

* [feat]: remove graphlearn dependency
yangxudong authored Feb 19, 2022
1 parent f7e36c7 commit 1a44891
Showing 15 changed files with 566 additions and 19 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

 

## What is EasyRec?

![intro.png](docs/images/intro.png)

@@ -58,6 +58,10 @@ EasyRec implements state of the art deep learning models used in common recommendation
- Easy to implement [customized models](docs/source/models/user_define.md)
- No need to care about data pipelines

### Fast vector retrieve

- Run the [knn algorithm](docs/source/vector_retrieve.md) on vectors in a distributed environment

 

## Get Started
5 changes: 2 additions & 3 deletions docs/source/feature/feature.rst
@@ -345,9 +345,8 @@ Configure the corresponding fields in the rank model:
- regularization\_lambda: regularization coefficient of the variational dropout layer
- embedding\_wise\_variational\_dropout: whether the variational dropout layer operates at the embedding dimension (true: embedding dimension; false: feature dimension; default false)




Note:
**This is configured under model_config, at the same level as model_class**


Delimiter
6 changes: 5 additions & 1 deletion docs/source/intro.md
@@ -1,6 +1,6 @@
# EasyRec Introduction

## What is EasyRec?

![intro.png](../images/intro.png)

@@ -56,4 +56,8 @@ EasyRec implements state of the art machine learning models used in common recommendation
- Easy to implement customized models
- No need to care about data pipelines

### Fast vector retrieve

- Run the [`knn algorithm`](vector_retrieve.md) on vectors in a distributed environment

Welcome to join the EasyRec recommendation algorithm discussion group, DingTalk group number: 32260796
116 changes: 116 additions & 0 deletions docs/source/vector_retrieve.md
@@ -0,0 +1,116 @@
# Vector Nearest-Neighbour Retrieval

## PAI Command

```sql
pai -name easy_rec_ext -project algo_public
-Dcmd=vector_retrieve
-Dquery_table=odps://pai_online_project/tables/query_vector_table
-Ddoc_table=odps://pai_online_project/tables/doc_vector_table
-Doutput_table=odps://pai_online_project/tables/result_vector_table
-Dcluster='{"worker" : {"count":3, "cpu":600, "gpu":100, "memory":10000}}'
-Dknn_distance=inner_product
-Dknn_num_neighbours=100
-Dknn_feature_dims=128
-Dknn_index_type=gpu_ivfflat
-Dknn_feature_delimiter=','
-Dbuckets='oss://${oss_bucket}/'
-Darn='acs:ram::${xxxxxxxxxxxxx}:role/AliyunODPSPAIDefaultRole'
-DossHost='oss-cn-hangzhou-internal.aliyuncs.com'
```

## Parameters

| Parameter | Default | Description |
| --------------------- | ------------- | ------------------------------------------------------------------------- |
| query_table || Input query table, schema: (id bigint, vector string) |
| doc_table || Input index table, schema: (id bigint, vector string) |
| output_table || Output table, schema: (query_id bigint, doc_id bigint, distance double) |
| knn_distance | inner_product | Distance metric: l2, inner_product |
| knn_num_neighbours || Top n: how many nearest neighbours to output per query |
| knn_feature_dims || Vector dimension |
| knn_feature_delimiter | , | Delimiter of the vector string |
| knn_index_type | ivfflat | Vector index type: 'flat', 'ivfflat', 'ivfpq', 'gpu_flat', 'gpu_ivfflat', 'gpu_ivfpq' |
| knn_nlist | 5 | Number of clusters to split into on each worker |
| knn_nprobe | 2 | Number of nearest clusters probed per query on each worker |
| knn_compress_dim | 8 | Compressed dimension when index_type is `ivfpq` or `gpu_ivfpq`; must be a factor of the number of float attributes |
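
For intuition, the exhaustive search that the `flat` index types perform can be sketched in plain NumPy (a toy illustration, not the actual GraphLearn-backed implementation; the `ivf*` variants approximate this by probing only `knn_nprobe` of `knn_nlist` clusters per worker):

```python
import numpy as np

def brute_force_knn(query, docs, k, metric='inner_product'):
    # Exhaustive top-k search over all doc vectors.
    if metric == 'inner_product':
        dist = query @ docs.T                       # larger = closer
        order = np.argsort(-dist, axis=1)[:, :k]
    else:                                           # 'l2'
        dist = np.linalg.norm(query[:, None, :] - docs[None, :, :], axis=2)
        order = np.argsort(dist, axis=1)[:, :k]
    return order, np.take_along_axis(dist, order, axis=1)

query = np.array([[0.1, 0.2, -0.4, 0.5]], dtype='float32')
docs = np.array([[0.1, 0.2, 0.4, 0.5],
                 [-0.1, 0.2, 0.4, 0.5],
                 [0.5, 0.2, 0.4, 0.5]], dtype='float32')
idx, dist = brute_force_knn(query, docs, k=2)
# idx[0] -> [2, 0]: the 3rd and 1st doc rows are the top-2 by inner product
```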

## Usage Example

### 1. Create the Query Table

```sql
create table query_table(pk BIGINT, vector string) partitioned by (pt string);

INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410')
VALUES
(1, '0.1,0.2,-0.4,0.5'),
(2, '-0.1,0.8,0.4,0.5'),
(3, '0.59,0.2,0.4,0.15'),
(10, '0.1,-0.2,0.4,-0.5'),
(20, '-0.1,-0.2,0.4,0.5'),
(30, '0.5,0.2,0.43,0.15')
;
```

### 2. Create the Index Table

```sql
create table doc_table(pk BIGINT, vector string) partitioned by (pt string);

INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410')
VALUES
(1, '0.1,0.2,0.4,0.5'),
(2, '-0.1,0.2,0.4,0.5'),
(3, '0.5,0.2,0.4,0.5'),
(10, '0.1,0.2,0.4,0.5'),
(20, '-0.1,-0.2,0.4,0.5'),
(30, '0.5,0.2,0.43,0.15')
;
```

### 3. Run the Vector Retrieval

```sql
pai -name easy_rec_ext -project algo_public_dev
-Dcmd='vector_retrieve'
-DentryFile='run.py'
-Dquery_table='odps://${project}/tables/query_table/pt=20190410'
-Ddoc_table='odps://${project}/tables/doc_table/pt=20190410'
-Doutput_table='odps://${project}/tables/knn_result_table/pt=20190410'
-Dknn_distance=inner_product
-Dknn_num_neighbours=2
-Dknn_feature_dims=4
-Dknn_index_type='ivfflat'
-Dknn_feature_delimiter=','
-Dbuckets='oss://${oss_bucket}/'
-Darn='acs:ram::${xxxxxxxxxxxxx}:role/AliyunODPSPAIDefaultRole'
-DossHost='oss-cn-shenzhen-internal.aliyuncs.com'
-Dcluster='{
\"worker\" : {
\"count\" : 1,
\"cpu\" : 600
}
}';
```

### 4. View the Results

```sql
SELECT * from knn_result_table where pt='20190410';

-- query doc distance
-- 1 3 0.17999999225139618
-- 1 1 0.13999998569488525
-- 2 2 0.5800000429153442
-- 2 1 0.5600000619888306
-- 3 3 0.5699999928474426
-- 3 30 0.5295000076293945
-- 10 30 0.10700000077486038
-- 10 20 -0.0599999874830246
-- 20 20 0.46000003814697266
-- 20 2 0.3800000250339508
-- 30 3 0.5370000004768372
-- 30 30 0.4973999857902527
```
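
The inner_product "distances" above are plain dot products (larger means more similar) and can be spot-checked by hand; e.g. the rows (1, 3, 0.18) and (1, 1, 0.14) pair query pk=1 with doc pk=3 and pk=1:

```python
import numpy as np

q1 = np.array([0.1, 0.2, -0.4, 0.5], dtype='float32')   # query pk=1
doc3 = np.array([0.5, 0.2, 0.4, 0.5], dtype='float32')  # doc pk=3
doc1 = np.array([0.1, 0.2, 0.4, 0.5], dtype='float32')  # doc pk=1

print(float(q1 @ doc3))  # ~0.18, matches the (1, 3) row above
print(float(q1 @ doc1))  # ~0.14, matches the (1, 1) row above
```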
8 changes: 3 additions & 5 deletions easy_rec/python/feature_column/feature_column.py
@@ -400,11 +400,7 @@ def parse_sequence_feature(self, config):

assert config.embedding_dim > 0

if config.HasField('sequence_combiner'):
fc.sequence_combiner = config.sequence_combiner
self._deep_columns[feature_name] = fc
else:
self._add_deep_embedding_column(fc, config)
self._add_deep_embedding_column(fc, config)

def _build_partitioner(self, max_partitions):
if max_partitions > 1:
@@ -477,4 +473,6 @@ def _add_deep_embedding_column(self, fc, config):
if config.feature_type != config.SequenceFeature:
self._deep_columns[feature_name] = fc
else:
if config.HasField('sequence_combiner'):
fc.sequence_combiner = config.sequence_combiner
self._sequence_columns[feature_name] = fc
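
The net effect of the two hunks above is that `parse_sequence_feature` always delegates to `_add_deep_embedding_column`, which now owns the sequence_combiner handling. The resulting routing, in simplified standalone form (hypothetical names and plain dicts, not EasyRec's actual classes):

```python
from types import SimpleNamespace

def add_deep_embedding_column(fc, config, deep_columns, sequence_columns):
    # Simplified sketch of the post-refactor routing.
    if config.feature_type != 'SequenceFeature':
        deep_columns[fc.name] = fc
    else:
        # sequence_combiner is now applied here, not in parse_sequence_feature
        if config.sequence_combiner is not None:
            fc.sequence_combiner = config.sequence_combiner
        sequence_columns[fc.name] = fc

fc = SimpleNamespace(name='click_seq', sequence_combiner=None)
config = SimpleNamespace(feature_type='SequenceFeature', sequence_combiner='sum')
deep, seq = {}, {}
add_deep_embedding_column(fc, config, deep, seq)
# 'click_seq' lands in seq with its sequence_combiner set
```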
124 changes: 124 additions & 0 deletions easy_rec/python/inference/vector_retrieve.py
@@ -0,0 +1,124 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
from datetime import datetime

import common_io
import numpy as np
import tensorflow as tf
try:
  import graphlearn as gl
except ImportError:
  logging.warning(
      'GraphLearn is not installed. You can install it by "pip install http://odps-release.cn-hangzhou.oss-cdn.aliyun-inc.com/graphlearn/tunnel/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"'  # noqa: E501
  )


if tf.__version__ >= '2.0':
tf = tf.compat.v1


class VectorRetrieve(object):

def __init__(self,
query_table,
doc_table,
out_table,
ndim,
delimiter=',',
batch_size=4,
index_type='ivfflat',
nlist=10,
nprobe=2,
distance=1,
m=8):
"""Retrieve top n neighbours by query vector.

    Args:
query_table: query vector table
doc_table: document vector table
out_table: output table
ndim: int, number of feature dimensions
delimiter: delimiter for feature vectors
batch_size: query batch size
index_type: search model `flat`, `ivfflat`, `ivfpq`, `gpu_ivfflat`
nlist: number of split part on each worker
nprobe: probe part on each worker
distance: type of distance, 0 is l2 distance(default), 1 is inner product.
m: number of dimensions for each node after compress
"""
self.query_table = query_table
self.doc_table = doc_table
self.out_table = out_table
self.ndim = ndim
self.delimiter = delimiter
self.batch_size = batch_size

gl.set_inter_threadnum(8)
gl.set_knn_metric(distance)
knn_option = gl.IndexOption()
knn_option.name = 'knn'
knn_option.index_type = index_type
knn_option.nlist = nlist
knn_option.nprobe = nprobe
knn_option.m = m
self.knn_option = knn_option

def __call__(self, top_n, task_index, task_count, *args, **kwargs):
g = gl.Graph()
g.node(
self.doc_table,
'doc',
decoder=gl.Decoder(
attr_types=['float'] * self.ndim, attr_delimiter=self.delimiter),
option=self.knn_option)
g.init(task_index=task_index, task_count=task_count)

query_reader = common_io.table.TableReader(
self.query_table, slice_id=task_index, slice_count=task_count)
num_records = query_reader.get_row_count()
total_batch_num = num_records // self.batch_size + 1.0
batch_num = 0
print('total input records: {}'.format(query_reader.get_row_count()))
print('total_batch_num: {}'.format(total_batch_num))
print('output_table: {}'.format(self.out_table))

output_table_writer = common_io.table.TableWriter(self.out_table,
task_index)
count = 0
while True:
try:
batch_query_nodes, batch_query_feats = zip(
*query_reader.read(self.batch_size, allow_smaller_final_batch=True))
batch_num += 1.0
print('{} process: {:.2f}'.format(datetime.now().time(),
batch_num / total_batch_num))
feats = to_np_array(batch_query_feats, self.delimiter)
rt_ids, rt_dists = g.search('doc', feats, gl.KnnOption(k=top_n))

for query_node, nodes, dists in zip(batch_query_nodes, rt_ids,
rt_dists):
query = np.array([query_node] * len(nodes), dtype='int64')
output_table_writer.write(
zip(query, nodes, dists), (0, 1, 2), allow_type_cast=False)
count += 1
if np.mod(count, 100) == 0:
print('write ', count, ' query nodes totally')
except Exception as e:
print(e)
break

print('==finished==')
output_table_writer.close()
query_reader.close()
g.close()


def to_np_array(batch_query_feats, attr_delimiter):
  # list(...) keeps this correct on Python 3, where map() returns an iterator
  return np.array(
      [list(map(float, feat.split(attr_delimiter)))
       for feat in batch_query_feats],
      dtype='float32')
12 changes: 12 additions & 0 deletions easy_rec/python/test/odps_run.py
@@ -195,6 +195,18 @@ def test_boundary_test(self):
tot.start_test()
tot.drop_table()

def test_vector_retrieve(self):
start_files = [
'vector_retrieve/create_inner_vector_table.sql'
]
test_files = [
'vector_retrieve/run_vector_retrieve.sql'
]
end_file = ['vector_retrieve/drop_table.sql']
tot = OdpsTest(start_files, test_files, end_file, odps_oss_config)
tot.start_test()
tot.drop_table()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
