diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 52a5811..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 RUCAIBox - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md index a36f00f..093c818 100644 --- a/README.md +++ b/README.md @@ -1,286 +1,15 @@ # CRSLab -[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab) -[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases) -[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) -[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939) -[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest) +This branch integrates the new evaluation approach from iEvaLM, currently supporting ChatGPT models and ReDial dataset. We will complete the evaluation for all models and datasets in the future. Please continue to follow us! -[Paper](https://arxiv.org/pdf/2101.00939.pdf) | [Docs](https://crslab.readthedocs.io/en/latest/?badge=latest) -| [中文版](./README_CN.md) +# Quick Start 🚀 -**CRSLab** is an open-source toolkit for building Conversational Recommender System (CRS). It is developed based on -Python and PyTorch. CRSLab has the following highlights: - -- **Comprehensive benchmark models and datasets**: We have integrated commonly-used 6 datasets and 18 models, including graph neural network and pre-training models such as R-GCN, BERT and GPT-2. We have preprocessed these datasets to support these models, and release for downloading. -- **Extensive and standard evaluation protocols**: We support a series of widely-adopted evaluation protocols for testing and comparing different CRS. -- **General and extensible structure**: We design a general and extensible structure to unify various conversational recommendation datasets and models, in which we integrate various built-in interfaces and functions for quickly development. -- **Easy to get started**: We provide simple yet flexible configuration for new researchers to quickly start in our library. -- **Human-machine interaction interfaces**: We provide flexible human-machine interaction interfaces for researchers to conduct qualitative analysis. - -

- RecBole v0.1 architecture -
- Figure 1: The overall framework of CRSLab -

- - - - -- [Installation](#Installation) -- [Quick-Start](#Quick-Start) -- [Models](#Models) -- [Datasets](#Datasets) -- [Performance](#Performance) -- [Releases](#Releases) -- [Contributions](#Contributions) -- [Citing](#Citing) -- [Team](#Team) -- [License](#License) - - - -## Installation - -CRSLab works with the following operating systems: - -- Linux -- Windows 10 -- macOS X - -CRSLab requires Python version 3.6 or later. - -CRSLab requires torch version 1.4.0 or later. If you want to use CRSLab with GPU, please ensure that CUDA or CUDAToolkit version is 9.2 or later. Please use the combinations shown in this [Link](https://pytorch-geometric.com/whl/) to ensure the normal operation of PyTorch Geometric. - - - -### Install PyTorch - -Use PyTorch [Locally Installation](https://pytorch.org/get-started/locally/) or [Previous Versions Installation](https://pytorch.org/get-started/previous-versions/) commands to install PyTorch. For example, on Linux and Windows 10: - -```bash -# CUDA 10.1 -pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` - -If you want to use CRSLab with GPU, make sure the following command prints `True` after installation: - -```bash -$ python -c "import torch; print(torch.cuda.is_available())" ->>> True -``` - - - -### Install PyTorch Geometric - -Ensure that at least PyTorch 1.4.0 is installed: - -```bash -$ python -c "import torch; print(torch.__version__)" ->>> 1.6.0 -``` - -Find the CUDA version PyTorch was installed with: - -```bash -$ python -c "import torch; print(torch.version.cuda)" ->>> 10.1 -``` - -Install the relevant packages: - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-geometric -``` - -where `${CUDA}` and `${TORCH}` should be replaced by your specific CUDA version (`cpu`, `cu92`, `cu101`, `cu102`, `cu110`) and PyTorch version (`1.4.0`, `1.5.0`, `1.6.0`, `1.7.0`) respectively. For example, for PyTorch 1.6.0 and CUDA 10.1, type: - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-geometric -``` - - - -### Install CRSLab - -You can install from pip: - -```bash -pip install crslab -``` - -OR install from source: - -```bash -git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab -pip install -e . -``` - - - -## Quick-Start - -With the source code, you can use the provided script for initial usage of our library with cpu by default: - -```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml -``` - -The system will complete the data preprocessing, and training, validation, testing of each model in turn. Finally it will get the evaluation results of specified models. - -If you want to save pre-processed datasets and training results of models, you can use the following command: +## Modify your OpenAI API key 🔑 +Please open the config/iEvaLM/chatgpt/redial.yaml file and replace the **your_api_key** with your own OpenAI API key. +## Evaluate 🤔 +You can use two types of interaction: attribute-based question answering and free-form chit-chat. ```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system -``` - -In summary, there are following arguments in `run_crslab.py`: - -- `--config` or `-c`: relative path for configuration file(yaml). -- `--gpu` or `-g`: specify GPU id(s) to use, we now support multiple GPUs. Defaults to CPU(-1). -- `--save_data` or `-sd`: save pre-processed dataset. -- `--restore_data` or `-rd`: restore pre-processed dataset from file. -- `--save_system` or `-ss`: save trained system. -- `--restore_system` or `-rs`: restore trained system from file. -- `--debug` or `-d`: use validation dataset to debug your system. -- `--interact` or `-i`: interact with your system instead of training. -- `--tensorboard` or `-tb`: enable tensorboard to monitor train performance. - - - -## Models - -In CRSLab, we unify the task description of conversational recommendation into three sub-tasks, namely recommendation (recommend user-preferred items), conversation (generate proper responses) and policy (select proper interactive action). The recommendation and conversation sub-tasks are the core of a CRS and have been studied in most of works. The policy sub-task is needed by recent works, by which the CRS can interact with users through purposeful strategy. -As the first release version, we have implemented 18 models in the four categories of CRS model, Recommendation model, Conversation model and Policy model. - -| Category | Model | Graph Neural Network? | Pre-training Model? | -| :------------------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: | -| CRS Model | [ReDial](https://arxiv.org/abs/1812.07617)
[KBRD](https://arxiv.org/abs/1908.05391)
[KGSF](https://arxiv.org/abs/2007.04032)
[TG-ReDial](https://arxiv.org/abs/2010.04125)
[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) | ×


×
× | ×
×
×

√ | -| Recommendation model | Popularity
[GRU4Rec](https://arxiv.org/abs/1511.06939)
[SASRec](https://arxiv.org/abs/1808.09781)
[TextCNN](https://arxiv.org/abs/1408.5882)
[R-GCN](https://arxiv.org/abs/1703.06103)
[BERT](https://arxiv.org/abs/1810.04805) | ×
×
×
×

× | ×
×
×
×
×
√ | -| Conversation model | [HERD](https://arxiv.org/abs/1507.04808)
[Transformer](https://arxiv.org/abs/1706.03762)
[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) | ×
×
× | ×
×
√ | -| Policy model | PMI
[MGCG](https://arxiv.org/abs/2005.03954)
[Conv-BERT](https://arxiv.org/abs/2010.04125)
[Topic-BERT](https://arxiv.org/abs/2010.04125)
[Profile-BERT](https://arxiv.org/abs/2010.04125) | ×
×
×
×
× | ×
×


√ | - -Among them, the four CRS models integrate the recommendation model and the conversation model to improve each other, while others only specify an individual task. - -For Recommendation model and Conversation model, we have respectively implemented the following commonly-used automatic evaluation metrics: - -| Category | Metrics | -| :--------------------: | :----------------------------------------------------------: | -| Recommendation Metrics | Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50} | -| Conversation Metrics | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} | -| Policy Metrics | Accuracy, Hit@{1,3,5} | - - - -## Datasets - -We have collected and preprocessed 6 commonly-used human-annotated datasets, and each dataset was matched with proper KGs as shown below: - -| Dataset | Dialogs | Utterances | Domains | Task Definition | Entity KG | Word KG | -| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: | -| [ReDial](https://redialdata.github.io/website/) | 10,006 | 182,150 | Movie | -- | DBpedia | ConceptNet | -| [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial) | 10,000 | 129,392 | Movie | Topic Guide | CN-DBpedia | HowNet | -| [GoRecDial](https://arxiv.org/abs/1909.03922) | 9,125 | 170,904 | Movie | Action Choice | DBpedia | ConceptNet | -| [DuRecDial](https://arxiv.org/abs/2005.03954) | 10,200 | 156,000 | Movie, Music | Goal Plan | CN-DBpedia | HowNet | -| [INSPIRED](https://github.com/sweetpeach/Inspired) | 1,001 | 35,811 | Movie | Social Strategy | DBpedia | ConceptNet | -| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802 | 91,209 | Movie, Book | Path Generate | DBpedia | ConceptNet | - - - -## Performance - -We have trained and test the integrated models on the TG-Redial dataset, which is split into training, validation and test sets using a ratio of 8:1:1. For each conversation, we start from the first utterance, and generate reply utterances or recommendations in turn by our model. We perform the evaluation on the three sub-tasks. - -### Recommendation Task - -| Model | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | -| SASRec | 0.000446 | 0.00134 | 0.0160 | 0.000446 | 0.000576 | 0.00114 | 0.000445 | 0.00075 | 0.00380 | -| TextCNN | 0.00267 | 0.0103 | 0.0236 | 0.00267 | 0.00434 | 0.00493 | 0.00267 | 0.00570 | 0.00860 | -| BERT | 0.00722 | 0.00490 | 0.0281 | 0.00722 | 0.0106 | 0.0124 | 0.00490 | 0.0147 | 0.0239 | -| KBRD | 0.00401 | 0.0254 | 0.0588 | 0.00401 | 0.00891 | 0.0103 | 0.00401 | 0.0127 | 0.0198 | -| KGSF | 0.00535 | **0.0285** | **0.0771** | 0.00535 | 0.0114 | **0.0135** | 0.00535 | **0.0154** | **0.0259** | -| TG-ReDial | **0.00793** | 0.0251 | 0.0524 | **0.00793** | **0.0122** | 0.0134 | **0.00793** | 0.0152 | 0.0211 | - - -### Conversation Task - -| Model | BLEU@1 | BLEU@2 | BLEU@3 | BLEU@4 | Dist@1 | Dist@2 | Dist@3 | Dist@4 | Average | Extreme | Greedy | PPL | -| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: | -| HERD | 0.120 | 0.0141 | 0.00136 | 0.000350 | 0.181 | 0.369 | 0.847 | 1.30 | 0.697 | 0.382 | 0.639 | 472 | -| Transformer | 0.266 | 0.0440 | 0.0145 | 0.00651 | 0.324 | 0.837 | 2.02 | 3.06 | 0.879 | 0.438 | 0.680 | 30.9 | -| GPT2 | 0.0858 | 0.0119 | 0.00377 | 0.0110 | **2.35** | **4.62** | **8.84** | **12.5** | 0.763 | 0.297 | 0.583 | 9.26 | -| KBRD | 0.267 | 0.0458 | 0.0134 | 0.00579 | 0.469 | 1.50 | 3.40 | 4.90 | 0.863 | 0.398 | 0.710 | 52.5 | -| KGSF | **0.383** | **0.115** | **0.0444** | **0.0200** | 0.340 | 0.910 | 3.50 | 6.20 | **0.888** | **0.477** | **0.767** | 50.1 | -| TG-ReDial | 0.125 | 0.0204 | 0.00354 | 0.000803 | 0.881 | 1.75 | 7.00 | 12.0 | 0.810 | 0.332 | 0.598 | **7.41** | - - -### Policy Task - -| Model | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | -| MGCG | 0.591 | 0.818 | 0.883 | 0.591 | 0.680 | 0.683 | 0.591 | 0.712 | 0.729 | -| Conv-BERT | 0.597 | 0.814 | 0.881 | 0.597 | 0.684 | 0.687 | 0.597 | 0.716 | 0.731 | -| Topic-BERT | 0.598 | 0.828 | 0.885 | 0.598 | 0.690 | 0.693 | 0.598 | 0.724 | 0.737 | -| TG-ReDial | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** | - -The above results were obtained from our CRSLab in preliminary experiments. However, these algorithms were implemented and tuned based on our understanding and experiences, which may not achieve their optimal performance. If you could yield a better result for some specific algorithm, please kindly let us know. We will update this table after the results are verified. - -## Releases - -| Releases | Date | Features | -| :------: | :-----------: | :----------: | -| v0.1.1 | 1 / 4 / 2021 | Basic CRSLab | -| v0.1.2 | 3 / 28 / 2021 | CRSLab | - - - -## Contributions - -Please let us know if you encounter a bug or have any suggestions by [filing an issue](https://github.com/RUCAIBox/CRSLab/issues). - -We welcome all contributions from bug fixes to new features and extensions. - -We expect all contributions discussed in the issue tracker and going through PRs. - -We thank the nice contributions through PRs from [@shubaoyu](https://github.com/shubaoyu), [@ToheartZhang](https://github.com/ToheartZhang). - - - -## Citing - -If you find CRSLab useful for your research or development, please cite our [Paper](https://arxiv.org/pdf/2101.00939.pdf): - -``` -@article{crslab, - title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System}, - author={Kun Zhou, Xiaolei Wang, Yuanhang Zhou, Chenzhan Shang, Yuan Cheng, Wayne Xin Zhao, Yaliang Li, Ji-Rong Wen}, - year={2021}, - journal={arXiv preprint arXiv:2101.00939} -} -``` - - - -## Team - -**CRSLab** was developed and maintained by [AI Box](http://aibox.ruc.edu.cn/) group in RUC. - - - -## License - -**CRSLab** uses [MIT License](./LICENSE). - +bash chatgpt_chat.sh # free-form chit-chat +bash chatgpt_ask.sh # attribute-based question answering +``` \ No newline at end of file diff --git a/README_CN.md b/README_CN.md deleted file mode 100644 index 8290a4a..0000000 --- a/README_CN.md +++ /dev/null @@ -1,290 +0,0 @@ -# CRSLab - -[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab) -[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases) -[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) -[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939) -[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest) - -[论文](https://arxiv.org/pdf/2101.00939.pdf) | [文档](https://crslab.readthedocs.io/en/latest/?badge=latest) -| [English Version](./README.md) - -**CRSLab** 是一个用于构建对话推荐系统(CRS)的开源工具包,其基于 PyTorch 实现、主要面向研究者使用,并具有如下特色: - -- **全面的基准模型和数据集**:我们集成了常用的 6 个数据集和 18 个模型,包括基于图神经网络和预训练模型,比如 GCN,BERT 和 GPT-2;我们还对数据集进行相关处理以支持这些模型,并提供预处理后的版本供大家下载。 -- **大规模的标准评测**:我们支持一系列被广泛认可的评估方式来测试和比较不同的 CRS。 -- **通用和可扩展的结构**:我们设计了通用和可扩展的结构来统一各种对话推荐数据集和模型,并集成了多种内置接口和函数以便于快速开发。 -- **便捷的使用方法**:我们为新手提供了简单而灵活的配置,方便其快速启动集成在 CRSLab 中的模型。 -- **人性化的人机交互接口**:我们提供了人性化的人机交互界面,以供研究者对比和测试不同的模型系统。 - -

- RecBole v0.1 architecture -
- 图片: CRSLab 的总体架构 -

- - - - -- [安装](#安装) -- [快速上手](#快速上手) -- [模型](#模型) -- [数据集](#数据集) -- [评测结果](#评测结果) -- [发行版本](#发行版本) -- [贡献](#贡献) -- [引用](#引用) -- [项目团队](#项目团队) -- [免责声明](#免责声明) - - - -## 安装 - -CRSLab 可以在以下几种系统上运行: - -- Linux -- Windows 10 -- macOS X - -CRSLab 需要在 Python 3.6 或更高的环境下运行。 - -CRSLab 要求 torch 版本在 1.4.0 及以上,如果你想在 GPU 上运行 CRSLab,请确保你的 CUDA 版本或者 CUDAToolkit 版本在 9.2 及以上。为保证 PyTorch Geometric 库的正常运行,请使用[链接](https://pytorch-geometric.com/whl/)所示的安装方式。 - - - -### 安装 PyTorch - -使用 PyTorch [本地安装](https://pytorch.org/get-started/locally/)命令或者[先前版本安装](https://pytorch.org/get-started/previous-versions/)命令安装 PyTorch,比如在 Linux 和 Windows 下: - -```bash -# CUDA 10.1 -pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` - -安装完成后,如果你想在 GPU 上运行 CRSLab,请确保如下命令输出`True`: - -```bash -$ python -c "import torch; print(torch.cuda.is_available())" ->>> True -``` - - - -### 安装 PyTorch Geometric - -确保安装的 PyTorch 版本至少为 1.4.0: - -```bash -$ python -c "import torch; print(torch.__version__)" ->>> 1.6.0 -``` - -找到安装好的 PyTorch 对应的 CUDA 版本: - -```bash -$ python -c "import torch; print(torch.version.cuda)" ->>> 10.1 -``` - -安装相关的包: - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-geometric -``` - -其中`${CUDA}`和`${TORCH}`应使用确定的 CUDA 版本(`cpu`,`cu92`,`cu101`,`cu102`,`cu110`)和 PyTorch 版本(`1.4.0`,`1.5.0`,`1.6.0`,`1.7.0`)来分别替换。比如,对于 PyTorch 1.6.0 和 CUDA 10.1,输入: - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-geometric -``` - - - -### 安装 CRSLab - -你可以通过 pip 来安装: - -```bash -pip install crslab -``` - -也可以通过源文件进行进行安装: - -```bash -git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab -pip install -e . -``` - - - -## 快速上手 - -从 GitHub 下载 CRSLab 后,可以使用提供的脚本快速运行和测试,默认使用CPU: - -```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml -``` - -系统将依次完成数据的预处理,以及各模块的训练、验证和测试,并得到指定的模型评测结果。 - -如果你希望保存数据预处理结果与模型训练结果,可以使用如下命令: - -```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system -``` - -总的来说,`run_crslab.py`有如下参数可供调用: - -- `--config` 或 `-c`:配置文件的相对路径,以指定运行的模型与数据集。 -- `--gpu` or `-g`:指定 GPU id,支持多 GPU,默认使用 CPU(-1)。 -- `--save_data` 或 `-sd`:保存预处理的数据。 -- `--restore_data` 或 `-rd`:从文件读取预处理的数据。 -- `--save_system` 或 `-ss`:保存训练好的 CRS 系统。 -- `--restore_system` 或 `-rs`:从文件载入提前训练好的系统。 -- `--debug` 或 `-d`:用验证集代替训练集以方便调试。 -- `--interact` 或 `-i`:与你的系统进行对话交互,而非进行训练。 -- `--tensorboard` or `-tb`:使用 tensorboardX 组件来监测训练表现。 - - - -## 模型 - -在第一个发行版中,我们实现了 4 类共 18 个模型。这里我们将对话推荐任务主要拆分成三个任务:推荐任务(生成推荐的商品),对话任务(生成对话的回复)和策略任务(规划对话推荐的策略)。其中所有的对话推荐系统都具有对话和推荐任务,他们是对话推荐系统的核心功能。而策略任务是一个辅助任务,其致力于更好的控制对话推荐系统,在不同的模型中的实现也可能不同(如 TG-ReDial 采用一个主题预测模型,DuRecDial 中采用一个对话规划模型等): - - - -| 类别 | 模型 | Graph Neural Network? | Pre-training Model? | -| :------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: | -| CRS 模型 | [ReDial](https://arxiv.org/abs/1812.07617)
[KBRD](https://arxiv.org/abs/1908.05391)
[KGSF](https://arxiv.org/abs/2007.04032)
[TG-ReDial](https://arxiv.org/abs/2010.04125)
[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) | ×


×
× | ×
×
×

√ | -| 推荐模型 | Popularity
[GRU4Rec](https://arxiv.org/abs/1511.06939)
[SASRec](https://arxiv.org/abs/1808.09781)
[TextCNN](https://arxiv.org/abs/1408.5882)
[R-GCN](https://arxiv.org/abs/1703.06103)
[BERT](https://arxiv.org/abs/1810.04805) | ×
×
×
×

× | ×
×
×
×
×
√ | -| 对话模型 | [HERD](https://arxiv.org/abs/1507.04808)
[Transformer](https://arxiv.org/abs/1706.03762)
[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) | ×
×
× | ×
×
√ | -| 策略模型 | PMI
[MGCG](https://arxiv.org/abs/2005.03954)
[Conv-BERT](https://arxiv.org/abs/2010.04125)
[Topic-BERT](https://arxiv.org/abs/2010.04125)
[Profile-BERT](https://arxiv.org/abs/2010.04125) | ×
×
×
×
× | ×
×


√ | - - -其中,CRS 模型是指直接融合推荐模型和对话模型,以相互增强彼此的效果,故其内部往往已经包含了推荐、对话和策略模型。其他如推荐模型、对话模型、策略模型往往只关注以上任务中的某一个。 - -我们对于这几类模型,我们还分别实现了如下的自动评测指标模块: - -| 类别 | 指标 | -| :------: | :----------------------------------------------------------: | -| 推荐指标 | Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50} | -| 对话指标 | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} | -| 策略指标 | Accuracy, Hit@{1,3,5} | - - - - - -## 数据集 - -我们收集了 6 个常用的人工标注数据集,并对它们进行了预处理(包括引入外部知识图谱),以融入统一的 CRS 任务中。如下为相关数据集的统计数据: - -| Dataset | Dialogs | Utterances | Domains | Task Definition | Entity KG | Word KG | -| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: | -| [ReDial](https://redialdata.github.io/website/) | 10,006 | 182,150 | Movie | -- | DBpedia | ConceptNet | -| [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial) | 10,000 | 129,392 | Movie | Topic Guide | CN-DBpedia | HowNet | -| [GoRecDial](https://arxiv.org/abs/1909.03922) | 9,125 | 170,904 | Movie | Action Choice | DBpedia | ConceptNet | -| [DuRecDial](https://arxiv.org/abs/2005.03954) | 10,200 | 156,000 | Movie, Music | Goal Plan | CN-DBpedia | HowNet | -| [INSPIRED](https://github.com/sweetpeach/Inspired) | 1,001 | 35,811 | Movie | Social Strategy | DBpedia | ConceptNet | -| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802 | 91,209 | Movie, Book | Path Generate | DBpedia | ConceptNet | - - - -## 评测结果 - -我们在 TG-ReDial 数据集上对模型进行了训练和测试,这里我们将数据集按照 8:1:1 切分。其中对于每条数据,我们从对话的第一轮开始,一轮一轮的进行推荐、策略生成、回复生成任务。下表记录了相关的评测结果。 - -### 推荐任务 - -| 模型 | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | -| SASRec | 0.000446 | 0.00134 | 0.0160 | 0.000446 | 0.000576 | 0.00114 | 0.000445 | 0.00075 | 0.00380 | -| TextCNN | 0.00267 | 0.0103 | 0.0236 | 0.00267 | 0.00434 | 0.00493 | 0.00267 | 0.00570 | 0.00860 | -| BERT | 0.00722 | 0.00490 | 0.0281 | 0.00722 | 0.0106 | 0.0124 | 0.00490 | 0.0147 | 0.0239 | -| KBRD | 0.00401 | 0.0254 | 0.0588 | 0.00401 | 0.00891 | 0.0103 | 0.00401 | 0.0127 | 0.0198 | -| KGSF | 0.00535 | **0.0285** | **0.0771** | 0.00535 | 0.0114 | **0.0135** | 0.00535 | **0.0154** | **0.0259** | -| TG-ReDial | **0.00793** | 0.0251 | 0.0524 | **0.00793** | **0.0122** | 0.0134 | **0.00793** | 0.0152 | 0.0211 | - - - -### 对话任务 - -| 模型 | BLEU@1 | BLEU@2 | BLEU@3 | BLEU@4 | Dist@1 | Dist@2 | Dist@3 | Dist@4 | Average | Extreme | Greedy | PPL | -| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: | -| HERD | 0.120 | 0.0141 | 0.00136 | 0.000350 | 0.181 | 0.369 | 0.847 | 1.30 | 0.697 | 0.382 | 0.639 | 472 | -| Transformer | 0.266 | 0.0440 | 0.0145 | 0.00651 | 0.324 | 0.837 | 2.02 | 3.06 | 0.879 | 0.438 | 0.680 | 30.9 | -| GPT2 | 0.0858 | 0.0119 | 0.00377 | 0.0110 | **2.35** | **4.62** | **8.84** | **12.5** | 0.763 | 0.297 | 0.583 | 9.26 | -| KBRD | 0.267 | 0.0458 | 0.0134 | 0.00579 | 0.469 | 1.50 | 3.40 | 4.90 | 0.863 | 0.398 | 0.710 | 52.5 | -| KGSF | **0.383** | **0.115** | **0.0444** | **0.0200** | 0.340 | 0.910 | 3.50 | 6.20 | **0.888** | **0.477** | **0.767** | 50.1 | -| TG-ReDial | 0.125 | 0.0204 | 0.00354 | 0.000803 | 0.881 | 1.75 | 7.00 | 12.0 | 0.810 | 0.332 | 0.598 | **7.41** | - - - -### 策略任务 - -| 模型 | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | -| MGCG | 0.591 | 0.818 | 0.883 | 0.591 | 0.680 | 0.683 | 0.591 | 0.712 | 0.729 | -| Conv-BERT | 0.597 | 0.814 | 0.881 | 0.597 | 0.684 | 0.687 | 0.597 | 0.716 | 0.731 | -| Topic-BERT | 0.598 | 0.828 | 0.885 | 0.598 | 0.690 | 0.693 | 0.598 | 0.724 | 0.737 | -| TG-ReDial | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** | - -上述结果是我们使用 CRSLab 进行实验得到的。然而,这些算法是根据我们的经验和理解来实现和调参的,可能还没有达到它们的最佳性能。如果您能在某个具体算法上得到更好的结果,请告知我们。验证结果后,我们会更新该表。 - -## 发行版本 - -| 版本号 | 发行日期 | 特性 | -| :----: | :-----------: | :----------: | -| v0.1.1 | 1 / 4 / 2021 | Basic CRSLab | -| v0.1.2 | 3 / 28 / 2021 | CRSLab | - - - -## 贡献 - -如果您遇到错误或有任何建议,请通过 [Issue](https://github.com/RUCAIBox/CRSLab/issues) 进行反馈 - -我们欢迎关于修复错误、添加新特性的任何贡献。 - -如果想贡献代码,请先在 Issue 中提出问题,然后再提 PR。 - -我们感谢 [@shubaoyu](https://github.com/shubaoyu), [@ToheartZhang](https://github.com/ToheartZhang) 通过 PR 为项目贡献的新特性。 - - - -## 引用 - -如果你觉得 CRSLab 对你的科研工作有帮助,请引用我们的[论文](https://arxiv.org/pdf/2101.00939.pdf): - -``` -@article{crslab, - title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System}, - author={Kun Zhou, Xiaolei Wang, Yuanhang Zhou, Chenzhan Shang, Yuan Cheng, Wayne Xin Zhao, Yaliang Li, Ji-Rong Wen}, - year={2021}, - journal={arXiv preprint arXiv:2101.00939} -} -``` - - - -## 项目团队 - -**CRSLab** 由中国人民大学 [AI Box](http://aibox.ruc.edu.cn/) 小组开发和维护。 - - - -## 免责声明 - -**CRSLab** 基于 [MIT License](./LICENSE) 进行开发,本项目的所有数据和代码只能被用于学术目的。 diff --git a/chatgpt_ask.sh b/chatgpt_ask.sh new file mode 100644 index 0000000..89f3d35 --- /dev/null +++ b/chatgpt_ask.sh @@ -0,0 +1 @@ +python run_crslab.py --config config/iEvaLM/chatgpt/redial.yaml --mode ask \ No newline at end of file diff --git a/chatgpt_chat.sh b/chatgpt_chat.sh new file mode 100644 index 0000000..4e8478c --- /dev/null +++ b/chatgpt_chat.sh @@ -0,0 +1 @@ +python run_crslab.py --config config/iEvaLM/chatgpt/redial.yaml --mode chat \ No newline at end of file diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index 92a5329..5762297 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ b/config/conversation/gpt2/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 # optim diff --git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index ea155c4..7dd0a6d 100644 --- a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 # optim diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index d96e8d6..bf93096 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 # optim diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 3a89ac8..d6db3d0 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 # optim diff --git a/config/conversation/transformer/durecdial.yaml b/config/conversation/transformer/durecdial.yaml index c61973c..a9f92fe 100644 --- a/config/conversation/transformer/durecdial.yaml +++ b/config/conversation/transformer/durecdial.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/gorecdial.yaml b/config/conversation/transformer/gorecdial.yaml index c05ddb4..fd578c5 100644 --- a/config/conversation/transformer/gorecdial.yaml +++ b/config/conversation/transformer/gorecdial.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/inspired.yaml b/config/conversation/transformer/inspired.yaml index 90b86d8..fc6ea05 100644 --- a/config/conversation/transformer/inspired.yaml +++ b/config/conversation/transformer/inspired.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/opendialkg.yaml b/config/conversation/transformer/opendialkg.yaml index 2971704..208adce 100644 --- a/config/conversation/transformer/opendialkg.yaml +++ b/config/conversation/transformer/opendialkg.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/redial.yaml b/config/conversation/transformer/redial.yaml index 25f2ab4..ae1c7c6 100644 --- a/config/conversation/transformer/redial.yaml +++ b/config/conversation/transformer/redial.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/crs/kbrd/durecdial.yaml b/config/crs/kbrd/durecdial.yaml index 4a115f2..fb62d7e 100644 --- a/config/crs/kbrd/durecdial.yaml +++ b/config/crs/kbrd/durecdial.yaml @@ -4,7 +4,7 @@ tokenize: jieba # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.1 +scale: 1 # model model: KBRD token_emb_dim: 300 diff --git a/config/crs/kbrd/gorecdial.yaml b/config/crs/kbrd/gorecdial.yaml index be78096..98360a5 100644 --- a/config/crs/kbrd/gorecdial.yaml +++ b/config/crs/kbrd/gorecdial.yaml @@ -4,7 +4,7 @@ tokenize: nltk # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model model: KBRD token_emb_dim: 300 diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index b5e8eff..481eb1f 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -5,7 +5,7 @@ embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 0.01 +scale: 1 # model model: KGSF token_emb_dim: 300 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index 0e4ba7e..b9b1ad1 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -5,7 +5,7 @@ embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 0.01 +scale: 1 # model model: KGSF token_emb_dim: 300 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index b9a2b06..5d3df93 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -5,7 +5,7 @@ embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 0.01 +scale: 1 # model model: KGSF token_emb_dim: 300 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..9b235da 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -5,7 +5,7 @@ embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 1.0 +scale: 1 # model model: KGSF token_emb_dim: 300 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index a120f98..981dbcf 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -5,7 +5,7 @@ embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 1.0 +scale: 1 # model model: KGSF token_emb_dim: 300 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 44a1c77..91f38c7 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -5,7 +5,7 @@ embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 1.0 +scale: 1 # model model: NTRD token_emb_dim: 300 diff --git a/config/crs/redial/opendialkg.yaml b/config/crs/redial/opendialkg.yaml index 061616b..7cf0637 100644 --- a/config/crs/redial/opendialkg.yaml +++ b/config/crs/redial/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: # dataloader utterance_truncate: 80 conversation_truncate: 40 -scale: 0.01 +scale: 1 # model # rec rec_model: ReDialRec diff --git a/config/crs/redial/redial.yaml b/config/crs/redial/redial.yaml index c6e848d..2cf569f 100644 --- a/config/crs/redial/redial.yaml +++ b/config/crs/redial/redial.yaml @@ -6,7 +6,7 @@ tokenize: # dataloader utterance_truncate: 80 conversation_truncate: 40 -scale: 0.01 +scale: 1 # model # rec rec_model: ReDialRec diff --git a/config/crs/redial/tgredial.yaml b/config/crs/redial/tgredial.yaml index a87d3c5..fa388f9 100644 --- a/config/crs/redial/tgredial.yaml +++ b/config/crs/redial/tgredial.yaml @@ -6,7 +6,7 @@ tokenize: # dataloader utterance_truncate: 80 conversation_truncate: 40 -scale: 0.01 +scale: 1 # model # rec rec_model: ReDialRec diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index 08a96aa..0bd4e9e 100644 --- a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -7,7 +7,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 74382ff..f766def 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -7,7 +7,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index bcfb217..563bbcc 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -7,7 +7,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv diff --git a/config/crs/tgredial/redial.yaml b/config/crs/tgredial/redial.yaml index 8e983a0..60943d8 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -7,7 +7,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv diff --git a/config/iEvaLM/chatgpt/redial.yaml b/config/iEvaLM/chatgpt/redial.yaml new file mode 100644 index 0000000..01367fc --- /dev/null +++ b/config/iEvaLM/chatgpt/redial.yaml @@ -0,0 +1,11 @@ +model: ChatGPT +tokenize: nltk +dataset: ReDial +api_key: your_api_key +turn_num: 5 +rec: + batch_size: 1 +conv: + batch_size: 1 +cache_item: + batch_size: 1000 \ No newline at end of file diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 87bb5e6..f6ba96c 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model policy_model: PMI # optim diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index 0d4250a..051e8d1 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT # optim diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 22ff335..cdbb30e 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT # optim diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index 4b59696..ae9f12f 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT # optim diff --git a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index be5fa53..48d664c 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT # optim diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index 94a5f6a..aa73472 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index 0d80c59..5d48dd5 100644 --- a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.1 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index b4900b9..3aebaf8 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/redial.yaml b/config/recommendation/gru4rec/redial.yaml index 7b707e7..b4f9c75 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index 3131e0a..ca8759a 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index bf77cd6..e6334b0 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index ebaf2c9..3ff691b 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index b0cbec9..27ad004 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index 66c9ef7..95b9247 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ b/config/recommendation/popularity/tgredial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index 15ba15e..2a32693 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ b/config/recommendation/sasrec/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 243a646..98b6dfe 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index ba4c02d..0efb55e 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index add69ec..2dea48e 100644 --- a/config/recommendation/sasrec/redial.yaml +++ b/config/recommendation/sasrec/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/durecdial.yaml b/config/recommendation/textcnn/durecdial.yaml index 3040132..c244237 100644 --- a/config/recommendation/textcnn/durecdial.yaml +++ b/config/recommendation/textcnn/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/gorecdial.yaml b/config/recommendation/textcnn/gorecdial.yaml index 6791031..f2fc0d5 100644 --- a/config/recommendation/textcnn/gorecdial.yaml +++ b/config/recommendation/textcnn/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/opendialkg.yaml b/config/recommendation/textcnn/opendialkg.yaml index 88eba55..5f9972b 100644 --- a/config/recommendation/textcnn/opendialkg.yaml +++ b/config/recommendation/textcnn/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/redial.yaml b/config/recommendation/textcnn/redial.yaml index 0466f51..6c43c0a 100644 --- a/config/recommendation/textcnn/redial.yaml +++ b/config/recommendation/textcnn/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/crslab/config/config.py b/crslab/config/config.py index 3ba6fe7..1bcd061 100644 --- a/crslab/config/config.py +++ b/crslab/config/config.py @@ -38,7 +38,7 @@ def __init__(self, config_file, gpu='-1', debug=False): # gpu os.environ['CUDA_VISIBLE_DEVICES'] = gpu if gpu != '-1': - self.opt['gpu'] = [i for i in range(len(gpu.split(',')))] + self.opt['gpu'] = [i for i in gpu.split(',')] else: self.opt['gpu'] = [-1] # dataset diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 33bea19..299d4e5 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -66,7 +66,8 @@ 'ProfileBERT': TGReDialDataLoader, 'MGCG': TGReDialDataLoader, 'PMI': TGReDialDataLoader, - 'NTRD': NTRDDataLoader + 'NTRD': NTRDDataLoader, + 'ChatGPT': ChatGPTDataLoader } diff --git a/crslab/data/dataloader/__init__.py b/crslab/data/dataloader/__init__.py index 7b4ce12..dc83a89 100644 --- a/crslab/data/dataloader/__init__.py +++ b/crslab/data/dataloader/__init__.py @@ -5,3 +5,4 @@ from .redial import ReDialDataLoader from .tgredial import TGReDialDataLoader from .ntrd import NTRDDataLoader +from .chatgpt import ChatGPTDataLoader diff --git a/crslab/data/dataloader/chatgpt.py b/crslab/data/dataloader/chatgpt.py new file mode 100644 index 0000000..41dd82b --- /dev/null +++ b/crslab/data/dataloader/chatgpt.py @@ -0,0 +1,60 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +from tqdm import tqdm + +from crslab.data.dataloader.base import BaseDataLoader + +class ChatGPTDataLoader(BaseDataLoader): + + def __init__(self, opt, dataset, vocab): + super().__init__(opt, dataset) + + def process_fn(self): + augment_dataset = [] + for conv_dict in tqdm(self.dataset): + if conv_dict['role'] == 'Recommender' and len(conv_dict['items']) > 0: + augment_conv_dict = { + 'dialog_id': conv_dict['dialog_id'], + 'role': conv_dict['role'], + 'entity': conv_dict['context_entities'], + 'context': conv_dict['context'], + 'item': conv_dict['items'] + } + augment_dataset.append(augment_conv_dict) + return augment_dataset + + def batchify(self, batch): + batch_dialog_id = [] + batch_role = [] + batch_context = [] + batch_movies = [] + batch_entities = [] + + for conv_dict in batch: + batch_dialog_id.append(conv_dict['dialog_id']) + batch_role.append(conv_dict['role']) + batch_context.append(conv_dict['context']) + batch_movies.append(conv_dict['item']) + batch_entities.append(conv_dict['entity']) + + return { + 'dialog_id': batch_dialog_id, + 'role': batch_role, + 'context': batch_context, + 'item': batch_movies, + 'entity': batch_entities + } + + def rec_process_fn(self): + return self.process_fn() + + def rec_batchify(self, batch): + return self.rec_batchify(batch) + + def conv_process_fn(self): + return self.process_fn() + + def conv_batchify(self, batch): + return self.batchify(batch) \ No newline at end of file diff --git a/crslab/data/dataloader/kbrd.py b/crslab/data/dataloader/kbrd.py index 2720c5d..182044d 100644 --- a/crslab/data/dataloader/kbrd.py +++ b/crslab/data/dataloader/kbrd.py @@ -58,20 +58,27 @@ def rec_process_fn(self): for conv_dict in tqdm(self.dataset): if conv_dict['role'] == 'Recommender': for movie in conv_dict['items']: - augment_conv_dict = {'context_entities': conv_dict['context_entities'], 'item': movie} + augment_conv_dict = { + 'role': conv_dict['role'], + 'context': conv_dict['context'], + 'item': movie + } augment_dataset.append(augment_conv_dict) return augment_dataset def rec_batchify(self, batch): - batch_context_entities = [] + batch_role = [] + batch_context = [] batch_movies = [] for conv_dict in batch: - batch_context_entities.append(conv_dict['context_entities']) + batch_role.append(conv_dict['role']) + batch_context.append(conv_dict['context']) batch_movies.append(conv_dict['item']) return { - "context_entities": batch_context_entities, - "item": torch.tensor(batch_movies, dtype=torch.long) + "role": batch_role, + 'context': batch_context, + "item": batch_movies } def conv_process_fn(self, *args, **kwargs): diff --git a/crslab/data/dataset/base.py b/crslab/data/dataset/base.py index 6befff5..a04c672 100644 --- a/crslab/data/dataset/base.py +++ b/crslab/data/dataset/base.py @@ -14,6 +14,8 @@ import numpy as np from loguru import logger +import json + from crslab.download import build @@ -134,6 +136,14 @@ def _data_preprocess(self, train_data, valid_data, test_data): """ pass + + @abstractmethod + def get_attr_list(self): + """ + Returns: + (list of str): attributes + """ + pass def _load_from_restore(self, file_name="all_data.pkl"): """Restore saved dataset. diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index ded2da6..9b0db5a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -68,7 +68,7 @@ def __init__(self, opt, tokenize, restore=False, save=False): resource = resources[tokenize] self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 327ccf8..6bb858f 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -14,7 +14,7 @@ 'jieba': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', 'durecdial_jieba.zip', 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', ), @@ -30,7 +30,7 @@ 'bert': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', 'durecdial_bert.zip', '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' ), @@ -49,7 +49,7 @@ 'gpt2': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', 'durecdial_gpt2.zip', 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' ), diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ce9d76..fe3e288 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -68,7 +68,7 @@ def __init__(self, opt, tokenize, restore=False, save=False): resource = resources[tokenize] self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize) + dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index b31e194..030202e 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -12,11 +12,11 @@ resources = { 'nltk': { - 'version': '0.3', + 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESM_Wc7sbAlOgZWo_6lOx34B6mboskdpNdB7FLuyXUET2A?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESIqjwAg0ItAu7WGfukIt3cBXjzi7AZ9L_lcbFT1aS1qYQ?download=1', 'gorecdial_nltk.zip', - '7e523f7ca90bb32ee8f2471ac5736717c45b20822c63bd958d0546de0a9cd863', + '58cd368f8f83c0c8555becc314a0017990545f71aefb7e93a52581c97d1b8e9b', ), 'special_token_idx': { 'pad': 0, @@ -29,11 +29,11 @@ }, }, 'bert': { - 'version': '0.3', + 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcTG05imCYpFiBarVfnsAfkBVsbq1iPw23CYcp9kYE9X4g?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed1HT8gzvRpDosVT83BEj5QBnzKpjR3Zbf5u49yyWP-k6Q?download=1', 'gorecdial_bert.zip', - 'fc7aff18504f750d8974d90f2941a01ff22cc054283124936b778ba91f03554f' + '4fa10c3fe8ba538af0f393c99892739fcb376d832616aa7028334c594b3fec10' ), 'special_token_idx': { 'pad': 0, @@ -48,11 +48,11 @@ } }, 'gpt2': { - 'version': '0.3', + 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Edg4_nbKA49HnQPcd65gPdoBALPADQd4V5qVqOrUub2m9w?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EUJOHmX8v79DkZMq0x5r9d4B0UJlfw85v-VdciwKfAhpng?download=1', 'gorecdial_gpt2.zip', - '7234138dcc27ed00bdac95da4096cd435023c229d227fa494d2bd7a653a492a9' + '44a15637e014b2e6628102ff654e1aef7ec1cbfa34b7ada1a03f294f72ddd4b1' ), 'special_token_idx': { 'pad': 0, diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73930f1..c44f0d7 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -68,7 +68,7 @@ def __init__(self, opt, tokenize, restore=False, save=False): resource = resources[tokenize] self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'inspired', tokenize) + dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index afb0cb1..504a760 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -14,7 +14,7 @@ 'nltk': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', 'inspired_nltk.zip', '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', ), @@ -30,7 +30,7 @@ 'bert': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', 'inspired_bert.zip', '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' ), @@ -48,9 +48,9 @@ 'gpt2': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', 'inspired_gpt2.zip', - '23bb4ce3299186630fdf673e17f43ee43e91573ea786c922e3527e4c341a313c' + '261ad7e5325258d5cb8ffef0751925a58270fb6d9f17490f8552f6b86ef1eed2' ), 'special_token_idx': { 'pad': 0, diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8582705..102fd8b 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -69,7 +69,7 @@ def __init__(self, opt, tokenize, restore=False, save=False): resource = resources[tokenize] self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -271,3 +271,7 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def get_attr_list(self): + attr_list = ['genre', 'actor', 'director', 'writer'] + return attr_list \ No newline at end of file diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e00ddfc..9f7fb62 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -14,9 +14,9 @@ 'nltk': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', 'opendialkg_nltk.zip', - '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f', + '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f' ), 'special_token_idx': { 'pad': 0, @@ -30,7 +30,7 @@ 'bert': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', 'opendialkg_bert.zip', '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' ), @@ -48,7 +48,7 @@ 'gpt2': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', 'opendialkg_gpt2.zip', 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' ), diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..6ceb036 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE: +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" ReDial ====== @@ -19,6 +24,7 @@ """ import json +import re import os from collections import defaultdict from copy import copy @@ -69,21 +75,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): resource = resources[tokenize] self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, "redial", tokenize) + dpath = os.path.join(DATASET_PATH, opt['dataset']) super().__init__(opt, dpath, resource, restore, save) def _load_data(self): train_data, valid_data, test_data = self._load_raw_data() - self._load_vocab() + # self._load_vocab() self._load_other_data() vocab = { - 'tok2ind': self.tok2ind, - 'ind2tok': self.ind2tok, + # 'tok2ind': self.tok2ind, + # 'ind2tok': self.ind2tok, 'entity2id': self.entity2id, 'id2entity': self.id2entity, 'word2id': self.word2id, - 'vocab_size': len(self.tok2ind), + # 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, } @@ -93,14 +99,28 @@ def _load_data(self): def _load_raw_data(self): # load train/valid/test data - with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: - train_data = json.load(f) + train_data = [] + valid_data = [] + test_data = [] + + with open(os.path.join(self.dpath, 'train_data.jsonl'), 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + train_data.append(data) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: - valid_data = json.load(f) + with open(os.path.join(self.dpath, 'valid_data.jsonl'), 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + valid_data.append(data) + valid_data.append(data) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: - test_data = json.load(f) + with open(os.path.join(self.dpath, 'test_data.jsonl'), 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + test_data.append(data) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") return train_data, valid_data, test_data @@ -120,7 +140,7 @@ def _load_other_data(self): self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = json.load(open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8')) + self.entity_kg = json.load(open(os.path.join(self.dpath, 'kg.json'), 'r', encoding='utf-8')) logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]") @@ -132,6 +152,12 @@ def _load_other_data(self): self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]") + with open(os.path.join(self.dpath, 'id2info.json'), 'r', encoding='utf-8') as f: + id2info = json.load(f) + self.id2name = {} + for id, info in id2info.items(): + self.id2name[id] = info['name'] + def _data_preprocess(self, train_data, valid_data, test_data): processed_train_data = self._raw_data_process(train_data) @@ -145,49 +171,69 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._merge_conv_data(conversation["dialog"]) for conversation in tqdm(raw_data)] + augmented_convs = [self._merge_conv_data(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) return augmented_conv_dicts def _merge_conv_data(self, dialog): + movie_pattern = re.compile(r'@\d+') augmented_convs = [] last_role = None - for utt in dialog: - text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] + user_id = dialog['initiatorWorkerId'] + recommender_id = dialog['respondentWorkerId'] + conversation_id = dialog['conversationId'] + + for utt in dialog["messages"]: + # text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + turn_id = utt['turn_id'] + dialog_turn_id = str(conversation_id) + '_' + str(turn_id) + + text = utt['text'] + for pattern in re.findall(movie_pattern, text): + if pattern.strip('@') in self.id2name: + text = text.replace(pattern, self.id2name[pattern.strip('@')]) + movie_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] - - if utt["role"] == last_role: - augmented_convs[-1]["text"] += text_token_ids - augmented_convs[-1]["movie"] += movie_ids + + role_id = utt['senderWorkerId'] + if role_id == recommender_id: + role = 'Recommender' + elif role_id == user_id: + role = 'User' + + if role == last_role: + augmented_convs[-1]["text"] += text + augmented_convs[-1]["item"] += movie_ids augmented_convs[-1]["entity"] += entity_ids augmented_convs[-1]["word"] += word_ids else: augmented_convs.append({ - "role": utt["role"], - "text": text_token_ids, + "dialog_id": dialog_turn_id, + "role": role, + "text": text, "entity": entity_ids, - "movie": movie_ids, + "item": movie_ids, "word": word_ids }) - last_role = utt["role"] + last_role = role return augmented_convs def _augment_and_add(self, raw_conv_dict): augmented_conv_dicts = [] - context_tokens, context_entities, context_words, context_items = [], [], [], [] + context, context_entities, context_words, context_items = [], [], [], [] entity_set, word_set = set(), set() for i, conv in enumerate(raw_conv_dict): - text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"] - if len(context_tokens) > 0: + text, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"] + if len(context) > 0: conv_dict = { + "dialog_id": conv['dialog_id'], "role": conv['role'], - "context_tokens": copy(context_tokens), - "response": text_tokens, + "context": copy(context), + "response": text, "context_entities": copy(context_entities), "context_words": copy(context_words), "context_items": copy(context_items), @@ -195,7 +241,7 @@ def _augment_and_add(self, raw_conv_dict): } augmented_conv_dicts.append(conv_dict) - context_tokens.append(text_tokens) + context.append(text) context_items += movies for entity in entities + movies: if entity not in entity_set: @@ -213,7 +259,7 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + movie_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -266,3 +312,36 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def get_attr_list(self): + attr_list = ['genre', 'star', 'director'] + return attr_list + + def get_ask_instruction(self): + ask_instruction = '''To recommend me items that I will accept, you can choose one of the following options. +A: ask my preference for genre +B: ask my preference for actor +C: ask my preference for director +D: I can directly give recommendations +Please enter the option character. Please only response a character.''' + option2attr = { + 'A': 'genre', + 'B': 'star', + 'C': 'director', + 'D': 'recommend' + } + option2template = { + 'A': 'Which genre do you like?', + 'B': 'Which star do you like?', + 'C': 'Which director do you like?', + } + rec_instruction = 'Please give me 10 recommendations according to my preference (Format: no. title. No other things except the item list in your response). You can recommend mentioned items in our dialog.' + + ask_instruction_dict = { + 'ask_instruction': ask_instruction, + 'option2attr': option2attr, + 'option2template': option2template, + 'rec_instruction': rec_instruction + } + + return ask_instruction_dict \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index b347029..551e7ab 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -14,7 +14,7 @@ 'nltk': { 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', 'redial_nltk.zip', '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', ), @@ -30,7 +30,7 @@ 'bert': { 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', 'redial_bert.zip', 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', ), @@ -48,7 +48,7 @@ 'gpt2': { 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', 'redial_gpt2.zip', '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', ), diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 0f37d97..b46e73b 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -14,7 +14,7 @@ 'pkuseg': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', 'tgredial_pkuseg.zip', '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', ), @@ -31,7 +31,7 @@ 'bert': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', 'tgredial_bert.zip', 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' ), @@ -50,7 +50,7 @@ 'gpt2': { 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', 'tgredial_gpt2.zip', '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' ), diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 90e03e3..e23f1c4 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -73,7 +73,7 @@ def __init__(self, opt, tokenize, restore=False, save=False): self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + dpath = os.path.join(DATASET_PATH, 'tgredial') self.replace_token = opt.get('replace_token',None) self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) diff --git a/crslab/evaluator/ask.py b/crslab/evaluator/ask.py new file mode 100644 index 0000000..c6097ea --- /dev/null +++ b/crslab/evaluator/ask.py @@ -0,0 +1,297 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import argparse +import copy +import json +import os +import re +import random +import time +import typing +import warnings +import tiktoken + +import numpy as np +import openai +import nltk + +from copy import copy +from tqdm import tqdm +from loguru import logger +from thefuzz import fuzz +from tenacity import Retrying, retry_if_not_exception_type, _utils +from tenacity.stop import stop_base +from tenacity.wait import wait_base + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.evaluator.base import BaseEvaluator +from crslab.evaluator.utils import get_entity, get_instruction + +def get_exist_dialog_set(save_dir): + exist_id_set = set() + for file in os.listdir(save_dir): + file_id = os.path.splitext(file)[0] + exist_id_set.add(file_id) + return exist_id_set + +def my_before_sleep(retry_state): + logger.debug( + f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + +class my_wait_exponential(wait_base): + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + def __call__(self, retry_state: "RetryCallState") -> float: + if retry_state.outcome == openai.error.Timeout: + return 0 + + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class my_stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome == openai.error.Timeout: + retry_state.attempt_number -= 1 + return retry_state.attempt_number >= self.max_attempt_number + +def annotate_completion(prompt, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, + retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)) + ): + with attempt: + response = openai.Completion.create( + model='text-davinci-003', prompt=prompt, temperature=0, max_tokens=128, stop='Recommender', + logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['text'] + request_timeout = min(300, request_timeout * 2) + + return response + +class Ask(): + + def __init__(self, turn_num, crs_model, dataset, ask_instruction_dict) -> None: + self.turn_num = turn_num + self.crs_model = crs_model + self.ask_instruction_dict = ask_instruction_dict + self.dataset_path = os.path.join(DATASET_PATH, dataset) + self.item_embedding_path = os.path.join(SAVE_PATH, dataset, 'embed') + item_emb_list = [] + id2item_id = [] + + with open(f"{self.dataset_path}/entity2id.json", 'r', encoding="utf-8") as f: + self.entity2id = json.load(f) + self.id2entity = {} + for entity, idx in self.entity2id.items(): + self.id2entity[idx] = entity + with open(f"{self.dataset_path}/id2info.json", 'r', encoding="utf-8") as f: + self.id2info = json.load(f) + + self.id2entity = {} + for k, v in self.entity2id.items(): + self.id2entity[int(v)] = k + + self.id2entityid = {} + for id, info in self.id2info.items(): + if info['name'] in self.entity2id: + self.id2entityid[id] = self.entity2id[info['name']] + + self.entityid2id = {} + for id, entityid in self.id2entityid.items(): + self.entityid2id[entityid] = id + + for i, file in tqdm(enumerate(os.listdir(self.item_embedding_path))): + item_id = os.path.splitext(file)[0] + if item_id in self.id2entityid: + id2item_id.append(item_id) + + with open(f'{self.item_embedding_path}/{file}', encoding='utf-8') as f: + embed = json.load(f) + item_emb_list.append(embed) + + self.id2item_id_arr = np.asarray(id2item_id) + self.item_emb_arr = np.asarray(item_emb_list) + + def ask(self, batch, turn_num): + + ask_instruction = self.ask_instruction_dict['ask_instruction'] + option2attr = self.ask_instruction_dict['option2attr'] + option2template = self.ask_instruction_dict['option2template'] + rec_instruction = self.ask_instruction_dict['rec_instruction'] + recommendation_template = "I would recommend the following items:\n{}" + + contexts_batch = batch['context'] + items_batch = batch['item'] + entity_batch = batch['entity'] + + for context, items, entities in zip(contexts_batch, items_batch, entity_batch): + + context_list = [] + + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + role_str = 'user' + else: + role_str = 'assistant' + context_list.append({ + 'role': role_str, + 'content': text + }) + + rec_success = False + option2index = { + 'A': 0, + 'B': 1, + 'C': 2, + 'D': 3, + 'E': 4 + } + + options = list(option2attr.keys()) + state = [0 for _ in range(len(options))] + + for i in range(0, turn_num): + # seeker + + context_list.append({ + 'role': 'user', + 'content': ask_instruction + }) + + context.append(ask_instruction) + batch['context'] = [copy(context)] + + # recommender + # choose option + # options (list of str): available options, generate one of them + gen_inputs, recommender_text = self.crs_model.converse(batch) + recommender_choose = self.crs_model.choose(gen_inputs, options, state, batch) + selected_option = recommender_choose + + if selected_option == options[-1]: # choose to rec + # recommender + _, item_rank_arr = self.crs_model.recommend(batch) + pred_items = item_rank_arr[0] + + rec_items_str = '' + for j, rec_item in enumerate(pred_items): + rec_items_str += f"{j + 1}: {self.id2entity[rec_item]}\n" + recommender_text = recommendation_template.format(rec_items_str) + + # judge whether success + for rec_label in items: + if rec_label in pred_items: + rec_success = True + break + + recommender_resp_entity = get_entity(recommender_text, self.entity2id) + + context.append(recommender_text) + entities += recommender_resp_entity + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + context_list.append({ + 'role': 'assistant', + 'content': recommender_text, + 'entity': recommender_resp_entity, + 'pred_items': pred_items, + 'rec_items': items, + 'rec_success': rec_success + }) + + # seeker + if rec_success is True: + seeker_text = "That's perfect, thank you!" + else: + seeker_text = "I don't like them." + + context_list.append({ + 'role': 'user', + 'content': seeker_text + }) + + context.append(seeker_text) + batch['context'] = [copy(context)] + + else: # choose to ask + recommender_text = option2template[selected_option] + context_list.append({ + 'role': 'assistant', + 'content': recommender_text, + }) + context.append(recommender_text) + batch['context'] = [copy(context)] + + # seeker + ask_attr = option2attr[selected_option] + + # update state + state[option2index[selected_option]] = -1e5 + + ans_attr_list = [] + id2info_items = [self.entityid2id[item] for item in items] + for id2info_item in id2info_items: + if str(id2info_item) in self.id2info and ask_attr in self.id2info[str(id2info_item)]: + ans_attr_list.extend(self.id2info[str(id2info_item)][ask_attr]) + if len(ans_attr_list) > 0: + seeker_text = ', '.join(list(set(ans_attr_list))) + else: + seeker_text = 'Sorry, no information about this, please choose another option.' + + context_list.append({ + 'role': 'user', + 'content': seeker_text, + 'entity': ans_attr_list, + }) + + seeker_resp_entities = get_entity(seeker_text, self.entity2id) + + context.append(seeker_text) + entities += seeker_resp_entities + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + if rec_success is True: + break + + return context_list + \ No newline at end of file diff --git a/crslab/evaluator/chat.py b/crslab/evaluator/chat.py new file mode 100644 index 0000000..f81c297 --- /dev/null +++ b/crslab/evaluator/chat.py @@ -0,0 +1,275 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import argparse +import copy +import json +import os +import re +import random +import time +import typing +import warnings +import tiktoken + +import numpy as np +import openai +import nltk + +from copy import copy +from tqdm import tqdm +from loguru import logger +from thefuzz import fuzz +from tenacity import Retrying, retry_if_not_exception_type, _utils +from tenacity.stop import stop_base +from tenacity.wait import wait_base + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.evaluator.base import BaseEvaluator +from crslab.evaluator.utils import get_entity, get_instruction + +def get_exist_dialog_set(save_dir): + exist_id_set = set() + for file in os.listdir(save_dir): + file_id = os.path.splitext(file)[0] + exist_id_set.add(file_id) + return exist_id_set + +def my_before_sleep(retry_state): + logger.debug( + f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + +class my_wait_exponential(wait_base): + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + def __call__(self, retry_state: "RetryCallState") -> float: + if retry_state.outcome == openai.error.Timeout: + return 0 + + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class my_stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome == openai.error.Timeout: + retry_state.attempt_number -= 1 + return retry_state.attempt_number >= self.max_attempt_number + +def annotate_completion(prompt, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, + retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)) + ): + with attempt: + response = openai.Completion.create( + model='text-davinci-003', prompt=prompt, temperature=0, max_tokens=128, stop='Recommender', + logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['text'] + request_timeout = min(300, request_timeout * 2) + + return response + +class Chat(): + + def __init__(self, turn_num, crs_model, dataset) -> None: + self.turn_num = turn_num + self.crs_model = crs_model + self.dataset_path = os.path.join(DATASET_PATH, dataset) + self.item_embedding_path = os.path.join(SAVE_PATH, dataset, 'embed') + item_emb_list = [] + id2item_id = [] + + with open(f"{self.dataset_path}/entity2id.json", 'r', encoding="utf-8") as f: + self.entity2id = json.load(f) + self.id2entity = {} + for entity, idx in self.entity2id.items(): + self.id2entity[idx] = entity + with open(f"{self.dataset_path}/id2info.json", 'r', encoding="utf-8") as f: + self.id2info = json.load(f) + + self.id2entity = {} + for k, v in self.entity2id.items(): + self.id2entity[int(v)] = k + + self.id2entityid = {} + for id, info in self.id2info.items(): + if info['name'] in self.entity2id: + self.id2entityid[id] = self.entity2id[info['name']] + + for i, file in tqdm(enumerate(os.listdir(self.item_embedding_path))): + item_id = os.path.splitext(file)[0] + if item_id in self.id2entityid: + id2item_id.append(item_id) + + with open(f'{self.item_embedding_path}/{file}', encoding='utf-8') as f: + embed = json.load(f) + item_emb_list.append(embed) + + self.id2item_id_arr = np.asarray(id2item_id) + self.item_emb_arr = np.asarray(item_emb_list) + + + def chat(self, batch, turn_num): + + recommender_instruction, seeker_instruction_template = get_instruction() + + contexts_batch = batch['context'] + items_batch = batch['item'] + entity_batch = batch['entity'] + + for context, items, entities in zip(contexts_batch, items_batch, entity_batch): + + goal_item_list = [self.id2entity[item] for item in items] + goal_item_str = ', '.join(goal_item_list) + seeker_prompt = seeker_instruction_template.format(goal_item_str, goal_item_str, goal_item_str, goal_item_str) + + context_list = [] + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + seeker_prompt += f'Seeker: {text}\n' + context_list.append({ + 'role': 'user', + 'content': text, + }) + else: + seeker_prompt += f'Recommender: {text}\n' + context_list.append({ + 'role': 'assistant', + 'content': text + }) + + rec_success = False + recommendation_template = "I would recommend the following items: {}:" + + for i in range(0, turn_num): + # rec only + _, item_rank_arr = self.crs_model.recommend(batch) + + pred_items = item_rank_arr[0] + + for rec_label in items: + if rec_label in pred_items: + rec_success = True + break + + _, recommender_text = self.crs_model.converse(batch) + + if rec_success == True or i == turn_num - 1: + rec_items_str = '' + for j, rec_item in enumerate(pred_items): + rec_items_str += f"{j+1}: {self.id2entity[rec_item]}\n" + recommendation_template = recommendation_template.format(rec_items_str) + recommender_text = recommendation_template + recommender_text + + recommender_resp_entity = get_entity(recommender_text, self.entity2id) + + context.append(recommender_text) + entities += recommender_resp_entity + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + context_list.append({ + 'role': 'assistant', + 'content': recommender_text, + 'entity': recommender_resp_entity, + 'pred_items': pred_items, + 'rec_items': items, + 'rec_success': rec_success + }) + + seeker_prompt += f'Recommender: {recommender_text}\nSeeker:' + + # seeker + year_pattern = re.compile(r'\(\d+\)') + goal_item_no_year_list = [year_pattern.sub('', rec_item).strip() for rec_item in goal_item_list] + seeker_text = annotate_completion(seeker_prompt).strip() + + seeker_response_no_movie_list = [] + for sent in nltk.sent_tokenize(seeker_text): + use_sent = True + for rec_item_str in goal_item_list + goal_item_no_year_list: + if fuzz.partial_ratio(rec_item_str.lower(), sent.lower()) > 90: + use_sent = False + break + if use_sent is True: + seeker_response_no_movie_list.append(sent) + seeker_response = ' '.join(seeker_response_no_movie_list) + if not rec_success: + seeker_response = 'Sorry, ' + seeker_response + seeker_prompt += f' {seeker_response}\n' + + # public + seeker_resp_entity = get_entity(seeker_text, self.entity2id) + + context_list.append({ + 'role': 'user', + 'content': seeker_text, + 'entity': seeker_resp_entity, + }) + + context.append(seeker_text) + entities += seeker_resp_entity + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + if rec_success: + break + + # score persuativeness + encoding = tiktoken.encoding_for_model("text-davinci-003") + logit_bias = {encoding.encode(str(score))[0]: 10 for score in range(3)} + + persuasiveness_template = '''Does the explanation make you want to accept the recommendation? Please give your score. + If mention one of [{}], give 2. + Else if you think recommended items are worse than [{}], give 0. + Else if you think recommended items are comparable to [{}] according to the explanation, give 1. + Else if you think recommended items are better than [{}] according to the explanation, give 2. + Only answer the score number.''' + + persuasiveness_template = persuasiveness_template.format(goal_item_str, goal_item_str, goal_item_str, goal_item_str) + prompt_str_for_persuasiveness = seeker_prompt + persuasiveness_template + prompt_str_for_persuasiveness += "\nSeeker:" + persuasiveness_score = annotate_completion(prompt_str_for_persuasiveness, logit_bias).strip() + + context_list.append({ + 'persuasiveness_score': persuasiveness_score + }) + + return context_list \ No newline at end of file diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b7c30fd..33044a6 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -14,7 +14,7 @@ 'zh': { 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', 'effd9806809a1db106b5166b817aaafaaf3f005846f730d4c49f88c7a28a0ac3' ) @@ -22,7 +22,7 @@ 'en': { 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', '96a06a77da70325997eaa52bfd9acb1359a7c3754cb1c1aed2fc27c04936d53e' ) diff --git a/crslab/evaluator/rec.py b/crslab/evaluator/rec.py index ce2a6c4..dc15c67 100644 --- a/crslab/evaluator/rec.py +++ b/crslab/evaluator/rec.py @@ -25,17 +25,18 @@ class RecEvaluator(BaseEvaluator): optim_metrics: the metrics to optimize in training """ - def __init__(self, tensorboard=False): + def __init__(self, k_list=[1, 10, 50], tensorboard=False): super(RecEvaluator, self).__init__() self.rec_metrics = Metrics() self.optim_metrics = Metrics() self.tensorboard = tensorboard + self.k_list = k_list if self.tensorboard: self.writer = SummaryWriter(log_dir='runs/' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) self.reports_name = ['Recommendation Metrics', 'Optimization Metrics'] def rec_evaluate(self, ranks, label): - for k in [1, 10, 50]: + for k in self.k_list: if len(ranks) >= k: self.rec_metrics.add(f"hit@{k}", HitMetric.compute(ranks, label, k)) self.rec_metrics.add(f"ndcg@{k}", NDCGMetric.compute(ranks, label, k)) diff --git a/crslab/evaluator/utils.py b/crslab/evaluator/utils.py index 0f5f89c..dce24fa 100644 --- a/crslab/evaluator/utils.py +++ b/crslab/evaluator/utils.py @@ -16,6 +16,7 @@ import math import torch from typing import Union, Tuple +from rapidfuzz import fuzz, process from .metrics import Metric @@ -158,3 +159,23 @@ def nice_report(report) -> str: for k, v in output.items() } ) + +def get_entity(text, entity2id): + entity_list = list(entity2id.keys()) + extractions = process.extract(text, entity_list, scorer=fuzz.WRatio, limit=20) + extractions = [extraction[0] for extraction in extractions if extraction[1] >= 90] + extractions_ids = [entity2id[extraction] for extraction in extractions] + return extractions_ids + +def get_instruction(): + + recommender_instruction = '''You are a recommender chatting with the user to provide recommendation. You must follow the instructions below during chat. + If you do not have enough information about user preference, you should ask the user for his preference. + If you have enough information about user preference, you can give recommendation.\n''' + seeker_instruction_template = '''You are a seeker chatting with a recommender for recommendation. Your target items: {}. You must follow the instructions below during chat. + If the recommender recommend {}, you should accept. + If the recommender recommend other items, you should refuse them and provide the information about {}. You should never directly tell the target item title. + If the recommender asks for your preference, you should provide the information about {}. You should never directly tell the target item title\n. + ''' + + return recommender_instruction, seeker_instruction_template \ No newline at end of file diff --git a/crslab/model/__init__.py b/crslab/model/__init__.py index 9b32a78..419b4e3 100644 --- a/crslab/model/__init__.py +++ b/crslab/model/__init__.py @@ -41,7 +41,8 @@ 'GRU4REC': GRU4RECModel, 'Popularity': PopularityModel, 'TextCNN': TextCNNModel, - 'NTRD': NTRDModel + 'NTRD': NTRDModel, + 'ChatGPT': ChatGPTModel, } diff --git a/crslab/model/crs/__init__.py b/crslab/model/crs/__init__.py index 0633769..ab75ed7 100644 --- a/crslab/model/crs/__init__.py +++ b/crslab/model/crs/__init__.py @@ -4,3 +4,4 @@ from .redial import * from .tgredial import * from .ntrd import * +from .chatgpt import * diff --git a/crslab/model/crs/chatgpt/__init__.py b/crslab/model/crs/chatgpt/__init__.py new file mode 100644 index 0000000..0ce0a8c --- /dev/null +++ b/crslab/model/crs/chatgpt/__init__.py @@ -0,0 +1 @@ +from .chatgpt import ChatGPTModel diff --git a/crslab/model/crs/chatgpt/chatgpt.py b/crslab/model/crs/chatgpt/chatgpt.py new file mode 100644 index 0000000..c786b38 --- /dev/null +++ b/crslab/model/crs/chatgpt/chatgpt.py @@ -0,0 +1,279 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import json +import os +import numpy as np +import openai +import typing +import tiktoken + +from tqdm import tqdm +from loguru import logger +from copy import copy +from tenacity import Retrying, retry_if_not_exception_type, _utils +from tenacity.stop import stop_base +from tenacity.wait import wait_base +from sklearn.metrics.pairwise import cosine_similarity + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.model.base import BaseModel + +def my_before_sleep(retry_state): + logger.debug( + f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + +class my_wait_exponential(wait_base): + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + def __call__(self, retry_state: "RetryCallState") -> float: + if retry_state.outcome == openai.error.Timeout: + return 0 + + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class my_stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome == openai.error.Timeout: + retry_state.attempt_number -= 1 + return retry_state.attempt_number >= self.max_attempt_number + +def annotate(item_text_list): + request_timeout = 6 + for attempt in Retrying( + reraise=True, retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)), before_sleep=my_before_sleep + ): + with attempt: + response = openai.Embedding.create( + model='text-embedding-ada-002', input=item_text_list, request_timeout=request_timeout + ) + request_timeout = min(30, request_timeout * 2) + + return response + +def annotate_completion(prompt, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, + retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)) + ): + with attempt: + response = openai.Completion.create( + model='text-davinci-003', prompt=prompt, temperature=0, max_tokens=128, stop='Recommender', + logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['text'] + request_timeout = min(300, request_timeout * 2) + + return response + +def annotate_chat(messages, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)), before_sleep=my_before_sleep + ): + with attempt: + response = openai.ChatCompletion.create( + model='gpt-3.5-turbo', messages=messages, temperature=0, logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['message']['content'] + request_timeout = min(300, request_timeout * 2) + + return response + +class ChatGPTModel(BaseModel): + + def __init__(self, opt, device, vocab=None, side_data=None): + self.dataset = opt['dataset'] + self.dataset_path = os.path.join(DATASET_PATH, self.dataset) + self.item_embedding_path = os.path.join(SAVE_PATH, self.dataset, 'embed') + + with open(f"{self.dataset_path}/entity2id.json", 'r', encoding="utf-8") as f: + self.entity2id = json.load(f) + self.id2entity = {} + for entity, idx in self.entity2id.items(): + self.id2entity[idx] = entity + with open(f"{self.dataset_path}/id2info.json", 'r', encoding="utf-8") as f: + self.id2info = json.load(f) + + self.id2entityid = {} + for id, info in self.id2info.items(): + if info['name'] in self.entity2id: + self.id2entityid[id] = self.entity2id[info['name']] + + self.get_item_embedding() + super(ChatGPTModel, self).__init__(opt, device) + + def build_model(self, *args, **kwargs): + return super().build_model(*args, **kwargs) + + def get_item_embedding(self): + + item_emb_list = [] + id2item_id = [] + + if os.path.exists(self.item_embedding_path): + for i, file in tqdm(enumerate(os.listdir(self.item_embedding_path))): + item_id = os.path.splitext(file)[0] + if item_id in self.id2entityid: + id2item_id.append(item_id) + + with open(f'{self.item_embedding_path}/{file}', encoding='utf-8') as f: + embed = json.load(f) + item_emb_list.append(embed) + + self.id2item_id_arr = np.asarray(id2item_id) + self.item_emb_arr = np.asarray(item_emb_list) + + def get_instruction(self): + + recommender_instruction = '''You are a recommender chatting with the user to provide recommendation. You must follow the instructions below during chat. +If you do not have enough information about user preference, you should ask the user for his preference. +If you have enough information about user preference, you can give recommendation. The recommendation list must contain 10 items that are consistent with user preference. The recommendation list can contain items that the dialog mentioned before. The format of the recommendation list is: no. title. Don't mention anything other than the title of items in your recommendation list. +''' + + seeker_instruction_template = '''You are a seeker chatting with a recommender for recommendation. Your target items: {}. You must follow the instructions below during chat. + If the recommender recommend {}, you should accept. + If the recommender recommend other items, you should refuse them and provide the information about {}. You should never directly tell the target item title. + If the recommender asks for your preference, you should provide the information about {}. You should never directly tell the target item title\n. + ''' + + return recommender_instruction, seeker_instruction_template + + def recommend(self, batch, mode='test'): + + context_batch = batch['context'] + item_rank_arr_batch = [] + + for context in context_batch: + + if len(context) % 2 == 0: + context = [""] + context + + conv_str = "" + + for i, text in enumerate(context[-2:]): + if len(text) == 0: + continue + if i % 2 == 0: + conv_str += f'Seeker: {text}\n' + else: + conv_str += f'Recommender: {text}\n' + conv_embed = annotate(conv_str)['data'][0]['embedding'] + conv_embed = np.asarray(conv_embed).reshape(1, -1) + + sim_mat = cosine_similarity(conv_embed, self.item_emb_arr) + rank_arr = np.argsort(sim_mat, axis=-1).tolist() + rank_arr = np.flip(rank_arr, axis=-1)[:, :50] + item_rank_arr = self.id2item_id_arr[rank_arr].tolist() + # modify item_rank_arr, item_id -> entity_id + item_rank_arr = [self.id2entityid[item_id] for item_id in item_rank_arr[0]] + item_rank_arr_batch.append(item_rank_arr) + + loss = None + + return loss, item_rank_arr_batch + + def converse(self, batch, mode='test'): + recommender_instruction, seeker_instruction_template = self.get_instruction() + + context_batch = batch['context'] + item_batch = batch['item'] + + for context, items in zip(context_batch, item_batch): + + context_list = [{ + 'role': 'system', + 'content': recommender_instruction + }] + + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + context_list.append({ + 'role': 'user', + 'content': text, + }) + else: + context_list.append({ + 'role': 'assistant', + 'content': text + }) + + gen_str = annotate_chat(context_list) + gen_inputs = None + + return gen_inputs, gen_str + + def choose(self, gen_inputs, options, state, batch, mode='test'): + + context_batch = batch['context'] + + for context in context_batch: + + updated_options = [] + for i, st in enumerate(state): + if st >= 0: + updated_options.append(options[i]) + + logger.info(updated_options) + + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + logit_bias = {encoding.encode(option)[0]: 20 for option in updated_options} + + context_list = [] + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + role_str = 'user' + else: + role_str = 'assistant' + context_list.append({ + 'role': role_str, + 'content': text + }) + + logger.info(context_list) + + response_op = annotate_chat(context_list, logit_bias=logit_bias) + return response_op[0] \ No newline at end of file diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py index d484a3f..c32dcd2 100644 --- a/crslab/model/crs/kgsf/resources.py +++ b/crslab/model/crs/kgsf/resources.py @@ -14,7 +14,7 @@ 'ReDial': { 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', 'kgsf_redial.zip', 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', ), @@ -22,7 +22,7 @@ 'TGReDial': { 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', 'kgsf_tgredial.zip', 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', ), @@ -30,15 +30,15 @@ 'GoRecDial': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ER5u2yMmgDNFvHuW6lKZLEkBKZkOkxMtZGK0bBQ-jvfLNw?download=1', 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', + 'f2f57ebb8f688f38a98ee41fe3a87e9362aed945ec9078869407f799da322633', ) }, 'OpenDialKG': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', 'kgsf_opendialkg.zip', '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' ) @@ -46,7 +46,7 @@ 'Inspired': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', 'kgsf_inspired.zip', '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' ) @@ -54,7 +54,7 @@ 'DuRecDial': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', 'kgsf_durecdial.zip', 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' ) diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py index d484a3f..c32dcd2 100644 --- a/crslab/model/crs/ntrd/resources.py +++ b/crslab/model/crs/ntrd/resources.py @@ -14,7 +14,7 @@ 'ReDial': { 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', 'kgsf_redial.zip', 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', ), @@ -22,7 +22,7 @@ 'TGReDial': { 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', 'kgsf_tgredial.zip', 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', ), @@ -30,15 +30,15 @@ 'GoRecDial': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ER5u2yMmgDNFvHuW6lKZLEkBKZkOkxMtZGK0bBQ-jvfLNw?download=1', 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', + 'f2f57ebb8f688f38a98ee41fe3a87e9362aed945ec9078869407f799da322633', ) }, 'OpenDialKG': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', 'kgsf_opendialkg.zip', '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' ) @@ -46,7 +46,7 @@ 'Inspired': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', 'kgsf_inspired.zip', '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' ) @@ -54,7 +54,7 @@ 'DuRecDial': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', 'kgsf_durecdial.zip', 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' ) diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py index 33c20d6..e254bdf 100644 --- a/crslab/model/pretrained_models.py +++ b/crslab/model/pretrained_models.py @@ -29,7 +29,7 @@ 'zh': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', 'bert_zh.zip', 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' ) @@ -37,7 +37,7 @@ 'en': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', 'bert_en.zip', '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' ) @@ -47,7 +47,7 @@ 'zh': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', 'gpt2_zh.zip', '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' ) @@ -55,7 +55,7 @@ 'en': { 'version': '0.1', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', 'gpt2_en.zip', '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' ) diff --git a/crslab/quick_start/__init__.py b/crslab/quick_start/__init__.py index 12b84a5..a03ec24 100644 --- a/crslab/quick_start/__init__.py +++ b/crslab/quick_start/__init__.py @@ -1 +1 @@ -from .quick_start import run_crslab +from .quick_start import quick_start \ No newline at end of file diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..fa31de2 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -13,7 +13,7 @@ from crslab.system import get_system -def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False, +def quick_start(config, mode, save_data=False, restore_data=False, save_system=False, restore_system=False, interact=False, debug=False, tensorboard=False): """A fast running api, which includes the complete process of training and testing models on specified datasets. @@ -70,6 +70,6 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if interact: CRS.interact() else: - CRS.fit() + CRS.fit(mode) if save_system: CRS.save_model() diff --git a/crslab/system/__init__.py b/crslab/system/__init__.py index dbbd699..5c66c85 100644 --- a/crslab/system/__init__.py +++ b/crslab/system/__init__.py @@ -21,6 +21,7 @@ from .redial import ReDialSystem from .ntrd import NTRDSystem from .tgredial import TGReDialSystem +from .chatgpt import ChatGPTSystem system_register_table = { 'ReDialRec_ReDialConv': ReDialSystem, @@ -41,7 +42,8 @@ 'GRU4REC': TGReDialSystem, 'Popularity': TGReDialSystem, 'TextCNN': TGReDialSystem, - 'NTRD': NTRDSystem + 'NTRD': NTRDSystem, + 'ChatGPT': ChatGPTSystem } diff --git a/crslab/system/chatgpt.py b/crslab/system/chatgpt.py new file mode 100644 index 0000000..d9a06ef --- /dev/null +++ b/crslab/system/chatgpt.py @@ -0,0 +1,199 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +import json +import random + +import openai +from tqdm import tqdm +from loguru import logger +from tenacity import _utils, Retrying, retry_if_not_exception_type +from tenacity.stop import stop_base +from tenacity.wait import wait_base + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.system.base import BaseSystem +from crslab.data.dataset import BaseDataset +from crslab.data import dataset_register_table +from crslab.model import Model_register_table +from crslab.evaluator.chat import my_wait_exponential, my_stop_after_attempt, Chat +from crslab.evaluator.ask import Ask +from crslab.evaluator.rec import RecEvaluator +from crslab.system.utils.functions import get_exist_item_set + +def get_exist_dialog_set(save_dir): + exist_id_set = set() + for file in os.listdir(save_dir): + file_id = os.path.splitext(file)[0] + exist_id_set.add(file_id) + return exist_id_set + +class ChatGPTSystem(BaseSystem): + + def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False, interact=False, debug=False, tensorboard=False): + super(ChatGPTSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) + + openai.api_key = opt['api_key'] + self.dataset = opt['dataset'] + self.dpath = os.path.join(DATASET_PATH, opt['dataset']) + self.embed_save_dir = os.path.join(SAVE_PATH, self.dataset, 'embed') + os.makedirs(self.embed_save_dir, exist_ok=True) + self.chat_save_dir = os.path.join(SAVE_PATH, self.dataset, 'chat') + os.makedirs(self.chat_save_dir, exist_ok=True) + self.ask_save_dir = os.path.join(SAVE_PATH, self.dataset, 'ask') + os.makedirs(self.ask_save_dir, exist_ok=True) + with open(os.path.join(self.dpath, 'id2info.json'), 'r', encoding='utf-8') as f: + self.id2info = json.load(f) + self.test_dataloader = test_dataloader + self.rec_optim_opt = opt['rec'] + self.conv_optim_opt = opt['conv'] + self.cache_item_opt = opt['cache_item'] + self.rec_batch_size = self.rec_optim_opt['batch_size'] + self.conv_batch_size = self.conv_optim_opt['batch_size'] + self.cache_item_batch_size= self.cache_item_opt['batch_size'] + self.api_key = opt['api_key'] + self.turn_num = opt['turn_num'] + self.dataset_class = dataset_register_table[self.dataset](opt, opt['tokenize']) + crs_model_name = opt['model_name'] + self.crs_model = Model_register_table[crs_model_name](opt, opt['device']) + + def my_before_sleep(self, retry_state): + logger.debug(f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + def annotate(self, item_text_list): + request_timeout = 6 + for attempt in Retrying( + reraise=True, retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)), before_sleep=self.my_before_sleep + ): + with attempt: + response = openai.Embedding.create( + model='text-embedding-ada-002', input=item_text_list, request_timeout=request_timeout + ) + request_timeout = min(30, request_timeout * 2) + + return response + + def cache_item(self): + attr_list = self.dataset_class.get_attr_list() + id2text = {} + for item_id, info_dict in self.id2info.items(): + attr_str_list = [f'Title: {info_dict["name"]}'] + for attr in attr_list: + if attr not in info_dict: + continue + if isinstance(info_dict[attr], list): + value_str = ', '.join(info_dict[attr]) + else: + value_str = info_dict[attr] + attr_str_list.append(f'{attr.capitalize()}: {value_str}') + item_text = '; '.join(attr_str_list) + id2text[item_id] = item_text + + item_ids = set(self.id2info.keys()) - get_exist_item_set(self.embed_save_dir) + while len(item_ids) > 0: + logger.info(len(item_ids)) + batch_item_ids = random.sample(tuple(item_ids), min(self.cache_item_batch_size, len(item_ids))) + batch_texts = [id2text[item_id] for item_id in batch_item_ids] + + batch_embeds = self.annotate(batch_texts)['data'] + for embed in batch_embeds: + item_id = batch_item_ids[embed['index']] + with open(f'{self.embed_save_dir}/{item_id}.json', 'w', encoding='utf-8') as f: + json.dump(embed['embedding'], f, ensure_ascii=False) + + item_ids -= get_exist_item_set(self.embed_save_dir) + + + def iEvaLM_chat(self): + logger.info('[Test]') + iEvaLM_CHAT = Chat(self.turn_num, self.crs_model, self.dataset) + dataid2data = {} + for i, batch in enumerate(self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False)): + dialog_ids = batch['dialog_id'] + for dialog_id in dialog_ids: + dataid2data[dialog_id] = batch + + dialog_id_set = set(dataid2data.keys()) - get_exist_dialog_set(self.chat_save_dir) + while len(dialog_id_set) > 0: + logger.info(len(dialog_id_set)) + dialog_id = random.choice(tuple(dialog_id_set)) + batch_data = dataid2data[dialog_id] + returned_data = iEvaLM_CHAT.chat(batch_data, self.turn_num) + + with open(f'{self.chat_save_dir}/{dialog_id}.json', 'w', encoding='utf-8') as f: + json.dump(returned_data, f, ensure_ascii=False, indent=2) + + dialog_id_set -= get_exist_dialog_set(self.chat_save_dir) + + def iEvaLM_ask(self): + logger.info('[Test]') + ask_instruction_dict = self.dataset_class.get_ask_instruction() + iEvaLM_ASK = Ask(self.turn_num, self.crs_model, self.dataset, ask_instruction_dict) + dataid2data = {} + for i, batch in enumerate(self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False)): + dialog_ids = batch['dialog_id'] + for dialog_id in dialog_ids: + dataid2data[dialog_id] = batch + + dialog_id_set = set(dataid2data.keys()) - get_exist_dialog_set(self.ask_save_dir) + while len(dialog_id_set) > 0: + logger.info(len(dialog_id_set)) + dialog_id = random.choice(tuple(dialog_id_set)) + batch_data = dataid2data[dialog_id] + returned_data = iEvaLM_ASK.ask(batch_data, self.turn_num) + + with open(f'{self.ask_save_dir}/{dialog_id}.json', 'w', encoding='utf-8') as f: + json.dump(returned_data, f, ensure_ascii=False, indent=2) + + dialog_id_set -= get_exist_dialog_set(self.ask_save_dir) + + def step(self, batch, stage, mode): + return super().step(batch, stage, mode) + + def evaluate_iEvaLM(self, mode): + metric = RecEvaluator(k_list=[1, 10, 25, 50]) + persuatiness_list = [] + if mode == 'chat': + save_path = self.chat_save_dir + elif mode == 'ask': + save_path = self.ask_save_dir + if os.path.exists(save_path) and len(os.listdir(save_path)) > 0: + path_list = os.listdir(save_path) + print(save_path, len(path_list)) + + for path in tqdm(path_list): + with open(f"{save_path}/{path}", 'r', encoding="utf-8") as f: + context_list = json.load(f) + if mode == 'chat': + persuasiveness_score = context_list[-1]['persuasiveness_score'] + persuatiness_list.append(float(persuasiveness_score)) + # TODO: modify chatgpt evaluator + for context in context_list[::-1]: + if 'rec_items' in context: + rec_labels = context['rec_items'] + rec_items = context['pred_items'] + for rec_label in rec_labels: + metric.rec_evaluate(rec_items, rec_label) + break + + metric.report(mode='test') + if mode == 'chat': + avg_persuatiness_score = sum(persuatiness_list) / len(persuatiness_list) + logger.info(avg_persuatiness_score) + + def fit(self, mode): + self.cache_item() + if mode == 'chat': + self.iEvaLM_chat() + elif mode == 'ask': + self.iEvaLM_ask() + else: + raise ValueError(f'Invalid mode: {mode}') + self.evaluate_iEvaLM(mode) + + def interact(self): + pass + \ No newline at end of file diff --git a/crslab/system/utils/functions.py b/crslab/system/utils/functions.py index a622f36..630797f 100644 --- a/crslab/system/utils/functions.py +++ b/crslab/system/utils/functions.py @@ -12,6 +12,7 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +import os import torch @@ -64,3 +65,10 @@ def ind2txt_with_slots(inds,slots,ind2tok, end_token_idx=None, unk_token='unk',s def ind2slot(inds,ind2slot): return [ ind2slot[ind] for ind in inds] + +def get_exist_item_set(save_dir): + exist_item_set = set() + for file in os.listdir(save_dir): + user_id = os.path.splitext(file)[0] + exist_item_set.add(user_id) + return exist_item_set \ No newline at end of file diff --git a/run_crslab.py b/run_crslab.py index b0f7f73..4dd4d40 100644 --- a/run_crslab.py +++ b/run_crslab.py @@ -35,10 +35,12 @@ help='interact with your system instead of training') parser.add_argument('-tb', '--tensorboard', action='store_true', help='enable tensorboard to monitor train performance') + parser.add_argument('--mode', choices=['train', 'chat', 'ask'], + help='Train CRS model / Evaluating the CRS model with iEvaLM free-form chit-chat mode / attribute-based question answering mode') args, _ = parser.parse_known_args() config = Config(args.config, args.gpu, args.debug) - from crslab.quick_start import run_crslab - - run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, + from crslab.quick_start import quick_start + + quick_start(config, args.mode, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, args.debug, args.tensorboard)