From 502463da3574ba3c91d2ad71eb451c81f9ed2e90 Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 28 Sep 2022 22:56:26 +0800 Subject: [PATCH 01/35] txy --- config/crs/kgsf/durecdial.yaml | 1 + config/crs/kgsf/gorecdial.yaml | 1 + config/crs/kgsf/inspired.yaml | 1 + config/crs/kgsf/opendialkg.yaml | 1 + config/crs/kgsf/redial.yaml | 1 + config/crs/kgsf/tgredial.yaml | 1 + config/crs/ntrd/tgredial.yaml | 1 + crslab/config/__init__.py | 9 + crslab/data/dataset/durecdial/durecdial.py | 160 ++++++++++++++- crslab/data/dataset/durecdial/resources.py | 99 +++++----- crslab/data/dataset/gorecdial/gorecdial.py | 178 ++++++++++++++++- crslab/data/dataset/gorecdial/resources.py | 96 +++++---- crslab/data/dataset/inspired/inspired.py | 185 +++++++++++++++++- crslab/data/dataset/inspired/resources.py | 89 ++++----- crslab/data/dataset/opendialkg/opendialkg.py | 178 ++++++++++++++++- crslab/data/dataset/opendialkg/resources.py | 89 ++++----- crslab/data/dataset/redial/redial.py | 176 ++++++++++++++++- crslab/data/dataset/redial/resources.py | 89 ++++----- crslab/data/dataset/tgredial/resources.py | 99 +++++----- crslab/data/dataset/tgredial/tgredial.py | 182 ++++++++++++++++- crslab/data/dataset/tokenize.py | 67 +++++++ crslab/evaluator/embeddings.py | 9 +- crslab/evaluator/standard.py | 12 +- crslab/model/conversation/gpt2/gpt2.py | 26 ++- crslab/model/crs/inspired/inspired_conv.py | 42 +++- crslab/model/crs/inspired/inspired_rec.py | 26 ++- crslab/model/crs/kgsf/kgsf.py | 13 +- crslab/model/crs/kgsf/resources.py | 62 ------ crslab/model/crs/ntrd/ntrd.py | 12 +- crslab/model/crs/ntrd/resources.py | 62 ------ crslab/model/crs/redial/modules.py | 11 +- crslab/model/crs/tgredial/tg_conv.py | 26 ++- crslab/model/crs/tgredial/tg_policy.py | 26 ++- crslab/model/crs/tgredial/tg_rec.py | 26 ++- crslab/model/policy/conv_bert/conv_bert.py | 26 ++- .../model/policy/profile_bert/profile_bert.py | 26 ++- crslab/model/policy/topic_bert/topic_bert.py | 26 ++- crslab/model/pretrained_models.py | 64 ------ crslab/model/recommendation/bert/bert.py | 26 ++- crslab/system/kgsf.py | 7 + crslab/system/tgredial.py | 7 + requirements.txt | 1 + 42 files changed, 1633 insertions(+), 606 deletions(-) create mode 100644 crslab/data/dataset/tokenize.py delete mode 100644 crslab/model/crs/kgsf/resources.py delete mode 100644 crslab/model/crs/ntrd/resources.py delete mode 100644 crslab/model/pretrained_models.py diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index b5e8eff..9ad0a9d 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index 0e4ba7e..ab00260 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index f087ca3..c3608e5 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index b9a2b06..09d47c3 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ 
b/config/crs/kgsf/opendialkg.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..5d11ca1 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 3 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index a120f98..33b2e1a 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 44a1c77..49c7940 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -24,6 +24,7 @@ n_positions: 1024 gen_loss_weight: 5 n_movies: 62287 replace_token: '[ITEM]' +copy: true # optim pretrain: epoch: 50 diff --git a/crslab/config/__init__.py b/crslab/config/__init__.py index f7556ef..d7cdb09 100644 --- a/crslab/config/__init__.py +++ b/crslab/config/__init__.py @@ -8,6 +8,11 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + """Config module which loads parameters for the whole system. Attributes: @@ -30,3 +35,7 @@ MODEL_PATH = os.path.join(DATA_PATH, 'model') PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain') EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding') +BERT_EN_PATH = os.path.join(PRETRAIN_PATH, 'bert', 'en') +BERT_ZH_PATH = os.path.join(PRETRAIN_PATH, 'bert', 'zh') +GPT2_EN_PATH = os.path.join(PRETRAIN_PATH, 'gpt2', 'en') +GPT2_ZH_PATH = os.path.join(PRETRAIN_PATH, 'gpt2', 'zh') \ No newline at end of file diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index ded2da6..fb3745d 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" DuRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, BERT_ZH_PATH, GPT2_ZH_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class DuRecDialDataset(BaseDataset): """ @@ -65,10 +72,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
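Note on the constructor changes that follow: they read two new pieces of configuration — the `copy: true` flag added to the KGSF/NTRD YAML files above, and the language-specific pretrained-model directories now exported from crslab/config/__init__.py. A minimal standalone sketch of that flow (a plain dict stands in for the real Config object; paths mirror the constants added above):

import os

PRETRAIN_PATH = os.path.join('data', 'model', 'pretrain')     # mirrors crslab.config
BERT_ZH_PATH = os.path.join(PRETRAIN_PATH, 'bert', 'zh')
GPT2_ZH_PATH = os.path.join(PRETRAIN_PATH, 'gpt2', 'zh')

def resolve_dataset_options(opt, tokenize):
    # copy-mask generation is enabled simply by the key being present,
    # which is what the `if 'copy' in opt` check below does
    use_copy = 'copy' in opt
    # only the pretrained tokenizers need a local model directory
    path = {'bert': BERT_ZH_PATH, 'gpt2': GPT2_ZH_PATH}.get(tokenize)
    return use_copy, path

opt = {'copy': True}                           # stands in for the parsed YAML config
print(resolve_dataset_options(opt, 'bert'))    # (True, 'data/model/pretrain/bert/zh')
print(resolve_dataset_options(opt, 'jieba'))   # (True, None)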
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + self.tokenize = tokenize + self.path = None + if tokenize == 'bert': + self.path = BERT_ZH_PATH + elif tokenize == 'gpt2': + self.path = GPT2_ZH_PATH + self.crstokenizer = CrsTokenize('zh', tokenize, self.path) + dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -94,14 +113,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -262,3 +302,111 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + cnt = 0 + tok2ind = {} + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 
'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + tokenizer = self.tokenize + crstokenize = self.crstokenizer + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + match_list = list(set(match_list)) + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 327ccf8..bd2348b 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -8,63 +8,58 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'jieba': { - 'version': '0.3', + 'resource':{ + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', - 'durecdial_jieba.zip', - 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', + 'durecdial.zip', + '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', ), - 'special_token_idx': { - 'pad': 0, - 'start': 
1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'jieba': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', - 'durecdial_bert.zip', - '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', - 'durecdial_gpt2.zip', - 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } + }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ce9d76..dd6f7f7 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" GoRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, BERT_EN_PATH, GPT2_EN_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class GoRecDialDataset(BaseDataset): """ @@ -65,10 +72,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
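Note on split_token, which this patch adds to every dataset class: it replaces each utterance's raw `text` string with a token list produced by CrsTokenize before any further processing. A rough standalone equivalent using NLTK directly (made-up dialog data, and NLTK's 'punkt' data is assumed to be installed), just to show the in-place transformation:

from nltk import word_tokenize   # requires the 'punkt' tokenizer data

def split_token(data):
    # mirrors the per-conversation / per-utterance loop: tokenize 'text' in place
    all_data = []
    for conv in data:
        dialog = []
        for turn in conv['dialog']:
            turn = dict(turn, text=word_tokenize(turn['text']))
            dialog.append(turn)
        all_data.append({'dialog': dialog})
    return all_data

sample = [{'dialog': [{'text': 'Have you seen The Godfather?'}]}]
print(split_token(sample)[0]['dialog'][0]['text'])
# ['Have', 'you', 'seen', 'The', 'Godfather', '?']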
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize) + self.tokenize = tokenize + self.path = None + if tokenize == 'bert': + self.path = BERT_EN_PATH + elif tokenize == 'gpt2': + self.path = GPT2_EN_PATH + self.crstokenizer = CrsTokenize('en', tokenize, self.path) + dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +114,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +306,129 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'gorecdial', 
'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index b31e194..5ea42c1 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -8,61 +8,57 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESM_Wc7sbAlOgZWo_6lOx34B6mboskdpNdB7FLuyXUET2A?download=1', - 'gorecdial_nltk.zip', - '7e523f7ca90bb32ee8f2471ac5736717c45b20822c63bd958d0546de0a9cd863', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EYmobnFBox1LnGKGW4TMCk8BW6rnjdAZNVsNo8uJ8ZsJLg?download=1', + 'gorecdial.zip', + '66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', ), - 
'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'nltk': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcTG05imCYpFiBarVfnsAfkBVsbq1iPw23CYcp9kYE9X4g?download=1', - 'gorecdial_bert.zip', - 'fc7aff18504f750d8974d90f2941a01ff22cc054283124936b778ba91f03554f' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Edg4_nbKA49HnQPcd65gPdoBALPADQd4V5qVqOrUub2m9w?download=1', - 'gorecdial_gpt2.zip', - '7234138dcc27ed00bdac95da4096cd435023c229d227fa494d2bd7a653a492a9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + }, + } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73930f1..25cc7e6 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" Inspired ======== @@ -21,13 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, BERT_EN_PATH, GPT2_EN_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class InspiredDataset(BaseDataset): @@ -65,10 +73,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
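Note on generate_tok2ind, defined below for each dataset: the vocabulary is now built from the tokenized training split instead of being shipped in the downloaded archive. Special tokens get the lowest ids, every unseen token then gets the next index in order of first appearance, and the nltk variant appends '_split_' at the end. A condensed sketch of the same logic:

def generate_tok2ind(processed_train_data, tokenize='nltk'):
    tok2ind = {}
    specials = (['__pad__', '__start__', '__end__', '__unk__']
                if tokenize in ('nltk', 'jieba') else ['[PAD]'])
    for tok in specials:                  # special tokens get ids 0..k-1
        tok2ind[tok] = len(tok2ind)
    for conv in processed_train_data:     # then first-seen order over training data
        for turn in conv['dialog']:
            for word in turn['text']:
                if word not in tok2ind:
                    tok2ind[word] = len(tok2ind)
    if tokenize == 'nltk':                # sentence-separator token used by the nltk variant
        tok2ind['_split_'] = len(tok2ind)
    return tok2ind

train = [{'dialog': [{'text': ['hello', 'there']}, {'text': ['hello', 'again']}]}]
print(generate_tok2ind(train))
# {'__pad__': 0, '__start__': 1, '__end__': 2, '__unk__': 3,
#  'hello': 4, 'there': 5, 'again': 6, '_split_': 7}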
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'inspired', tokenize) + dpath = os.path.join(DATASET_PATH, 'inspired') + self.tokenize = tokenize + self.path = None + if tokenize == 'bert': + self.path = BERT_EN_PATH + elif tokenize == 'gpt2': + self.path = GPT2_EN_PATH + self.crstokenizer = CrsTokenize('en', tokenize, self.path) super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: @@ -268,3 +309,137 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'inspired', 'token2id.json') 
+ with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for genre in dialog['genre']: + list_word = crstokenize.tokenize(genre, tokenizer) + match_list += list_word + + for people in dialog['people']: + list_word = crstokenize.tokenize(people, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index afb0cb1..c2d1e75 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', - 'inspired_nltk.zip', - '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', + 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXv8zwgCOY1EstHNjjs194cBqMIrdg4yxcyNsHKltTzyig?download=1', + 'inspired.zip', + '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', - 'inspired_bert.zip', - '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', - 'inspired_gpt2.zip', - '23bb4ce3299186630fdf673e17f43ee43e91573ea786c922e3527e4c341a313c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8582705..e02ae7e 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" OpenDialKG ========== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, BERT_EN_PATH, GPT2_EN_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class OpenDialKGDataset(BaseDataset): @@ -66,10 +74,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
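Note on generate_word2vec, added further below for each dataset: it trains a 300-dimensional gensim Word2Vec model on the tokenized training utterances and writes word2vec.npy, padding the front of the embedding matrix with zero rows so that row indices line up with the special-token ids (four rows for nltk/jieba, one for bert, none for gpt2, plus a trailing zero row for nltk's '_split_'). A sketch of the nltk case, assuming gensim 4.x:

import numpy as np
import gensim

def generate_word2vec_nltk(corpus):
    # corpus: list of token lists, e.g. the tokenized training utterances
    model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1)
    vocab = model.wv.index_to_key
    # rows 0-3 stay zero for __pad__/__start__/__end__/__unk__;
    # the final zero row is for the '_split_' token appended to the vocabulary
    matrix = [[0.0] * 300] * 4 + [model.wv[w] for w in vocab] + [[0.0] * 300]
    return np.asarray(matrix)

corpus = [['hello', 'there'], ['hello', 'again']]
emb = generate_word2vec_nltk(corpus)
print(emb.shape)   # (8, 300): 4 special rows + 3 words + 1 '_split_' row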
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + dpath = os.path.join(DATASET_PATH, 'opendialkg') + self.tokenize = tokenize + self.path = None + if tokenize == 'bert': + self.path = BERT_EN_PATH + elif tokenize == 'gpt2': + self.path = GPT2_EN_PATH + self.crstokenizer = CrsTokenize('en', tokenize, self.path) super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +116,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -271,3 +312,130 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 
'opendialkg', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'opendialkg', 'word2vec.npy') + np.save(word2vec_path, word2embedding) + diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e00ddfc..e5682fe 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', - 'opendialkg_nltk.zip', - '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUknGWqDp15OoI2U7DE6EHkBoZVaK273DJfxCdXuluqQjA?download=1', + 'opendialkg.zip', + '73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', 
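Note on the third DownloadableFile argument just above: it is the SHA-256 digest of the dataset archive. A quick way to check a downloaded copy by hand could look like the sketch below; this is illustrative only, not the crslab downloader itself, and the local filename is an assumption:

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# expected = '73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671'
# assert sha256_of('opendialkg.zip') == expected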
), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', - 'opendialkg_bert.zip', - '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', - 'opendialkg_gpt2.zip', - 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..6557618 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" ReDial ====== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, BERT_EN_PATH, GPT2_EN_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class ReDialDataset(BaseDataset): @@ -66,10 +74,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
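Note on the constructor changes that follow: all datasets now delegate tokenization to the new crslab/data/dataset/tokenize.py (shown near the end of this patch). CrsTokenize picks a backend from the language and tokenizer name and exposes a single tokenize(text, tokenizer) entry point. Assuming the module is importable and NLTK's 'punkt' data is available, usage looks roughly like:

from crslab.data.dataset.tokenize import CrsTokenize

# English datasets (ReDial, GoRecDial, Inspired, OpenDialKG); nltk needs no model path
en_tok = CrsTokenize('en', 'nltk')
print(en_tok.tokenize('I loved The Matrix!', 'nltk'))
# ['I', 'loved', 'The', 'Matrix', '!']

# Chinese datasets (DuRecDial, TGReDial); 'bert' loads from BERT_ZH_PATH if it exists,
# otherwise falls back to downloading 'bert-base-chinese'
# zh_tok = CrsTokenize('zh', 'bert', BERT_ZH_PATH)
# print(zh_tok.tokenize('我喜欢这部电影', 'bert'))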
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, "redial", tokenize) + dpath = os.path.join(DATASET_PATH, "redial") + self.tokenize = tokenize + self.path = None + if tokenize == 'bert': + self.path = BERT_EN_PATH + elif tokenize == 'gpt2': + self.path = GPT2_EN_PATH + self.crstokenizer = CrsTokenize('en', tokenize, self.path) super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +116,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +307,128 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'redial', 
'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index b347029..170dd3b 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.31', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', - 'redial_nltk.zip', - '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ea4PEMnyyqxAl6tiAC17BcgBW8fZ6eveNKAbAU5sYt8-PQ?download=1', + 'redial.zip', + '9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 
'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0 + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, }, - }, - 'bert': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', - 'redial_bert.zip', - 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', - 'redial_gpt2.zip', - '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', - ), - 'special_token_idx': { - 'pad': -100, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 0f37d97..92506f7 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -8,64 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'pkuseg': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', - 'tgredial_pkuseg.zip', - '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUmmYbQ6BytMrQjmgRWuElMBZ2yv7v10wLzuwxHe9wxnYg?download=1', + 'tgredial.zip', + '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'pkuseg':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', - 'tgredial_bert.zip', - 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 
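Note on the restructured resources dicts in this patch: each dataset now has a single downloadable archive, with the per-tokenizer special-token ids nested inside it, and the dataset constructors index it as resource = resources['resource']; token = resource[tokenize]. A minimal lookup against a structure shaped like the one above (the DownloadableFile entry is omitted here):

resources = {
    'resource': {
        'version': '1.0',
        # 'file': DownloadableFile(...) omitted in this sketch
        'pkuseg': {'special_token_idx': {'pad': 0, 'start': 1, 'end': 2, 'unk': 3}},
        'bert':   {'special_token_idx': {'pad': 0, 'start': 101, 'end': 102, 'unk': 100}},
    },
}

tokenize = 'bert'
resource = resources['resource']
special_token_idx = resource[tokenize]['special_token_idx']
print(special_token_idx['unk'])   # 100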
'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', - 'tgredial_gpt2.zip', - '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 90e03e3..1f52044 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" TGReDial ======== @@ -23,12 +28,15 @@ from collections import defaultdict from copy import copy import numpy as np +import gensim + from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, BERT_ZH_PATH, GPT2_ZH_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class TGReDialDataset(BaseDataset): @@ -69,11 +77,25 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + + self.tokenize = tokenize + self.path = None + if tokenize == 'bert': + self.path = BERT_ZH_PATH + elif tokenize == 'gpt2': + self.path = GPT2_ZH_PATH + self.crstokenizer = CrsTokenize('zh', tokenize, self.path) + dpath = os.path.join(DATASET_PATH, 'tgredial') + self.replace_token = opt.get('replace_token',None) self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) @@ -111,14 +133,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = 
self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -340,3 +383,132 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + each_dict['conv_id'] = each['conv_id'] + for one in each['messages']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['messages'] = each_data + each_dict['user_id'] = each['user_id'] + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['messages'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['messages']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + + for movie in dialog['movie']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['messages']: + text = dialog['text'] + corpus.append(text) + + 
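The copy_mask.npy written above is a boolean vector with one entry per vocabulary id, marking tokens that occur in a training utterance and also in that utterance's word/movie/entity annotations. A minimal sketch of loading such a mask and applying it to a batch of generation logits; the path follows the MODEL_PATH layout used above, and the masking step itself is illustrative, not part of this patch:

import numpy as np
import torch

# load the vocabulary-level copy mask saved by generate_copy_mask()
copy_mask = np.load('data/model/kgsf/TGReDial/copy_mask.npy')   # shape: (vocab_size,), dtype=bool
copy_mask = torch.from_numpy(copy_mask)

# illustrative logits for two decoding steps over the same vocabulary
logits = torch.randn(2, copy_mask.shape[0])

# keep scores only for tokens that are allowed to be copied
copy_logits = logits.masked_fill(~copy_mask, float('-inf'))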
model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/tokenize.py b/crslab/data/dataset/tokenize.py new file mode 100644 index 0000000..1cc302e --- /dev/null +++ b/crslab/data/dataset/tokenize.py @@ -0,0 +1,67 @@ +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +from nltk import word_tokenize +from transformers import AutoTokenizer +import pkuseg +import nltk +import jieba + +class CrsTokenize: + + def __init__(self, language, tokenizer=None, path=None) -> None: + self.language = language + self.path = path + + if tokenizer == 'bert': + if language == 'zh': + if os.path.exists(path): + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + else: + os.environ['TORCH_HOME'] = path + self.my_tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese') + elif language == 'en': + if os.path.exists(self.path): + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + else: + os.environ['TORCH_HOME'] = path + self.my_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') + elif tokenizer == 'gpt2': + if language == 'zh': + if os.path.exists(path): + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + else: + os.environ['TORCH_HOME'] = path + self.my_tokenizer = AutoTokenizer.from_pretrained('GPT2-chitchat') + elif language == 'en': + if os.path.exists(path): + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + else: + os.environ['TORCH_HOME'] = path + self.my_tokenizer = AutoTokenizer.from_pretrained('gpt2') + + def tokenize(self, text, tokenizer): + tokenize_fun = getattr(self, tokenizer + '_tokenize') + return tokenize_fun(text) + + def nltk_tokenize(self, text): + # nltk.download('punkt') + return word_tokenize(text) + + def bert_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def gpt2_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def pkuseg_tokenize(self, text): + if not hasattr(self, 'pkuseg_tokenizer'): + self.pkuseg_tokenizer = pkuseg.pkuseg() + return self.pkuseg_tokenizer.cut(text) + + def jieba_tokenize(self, text): + split_text = jieba.cut(text) + text_list = ' '.join(split_text).split() + return text_list \ No newline at end of file diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b7c30fd..b682e42 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -8,11 +8,16 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { 'zh': { - 'version': 
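As a usage reference for the new crslab/data/dataset/tokenize.py above, a short sketch that drives CrsTokenize the same way split_token() does; the pretrained-model directory is illustrative and is only consulted by the bert/gpt2 backends:

from crslab.data.dataset.tokenize import CrsTokenize

# rule-based backends need no pretrained files
nltk_tok = CrsTokenize('en')
print(nltk_tok.tokenize('I love science fiction movies.', 'nltk'))
# -> ['I', 'love', 'science', 'fiction', 'movies', '.']

# subword backends load a tokenizer from a local directory if it exists,
# otherwise they fall back to downloading the default checkpoint
bert_tok = CrsTokenize('zh', 'bert', 'data/model/pretrain/bert/zh')
print(bert_tok.tokenize('我想看科幻电影', 'bert'))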
'0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', @@ -20,7 +25,7 @@ ) }, 'en': { - 'version': '0.2', + 'version': '.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..f08d121 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import time from collections import defaultdict @@ -83,9 +88,10 @@ def gen_evaluate(self, hyp, refs): hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] - self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + if len(ref_embs[0]) > 0: + self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index c93badb..069c6c2 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" GPT2 ==== @@ -24,10 +29,9 @@ from torch.nn import CrossEntropyLoss from transformers import GPT2LMHeadModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, GPT2_ZH_PATH, GPT2_EN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class GPT2Model(BaseModel): @@ -54,14 +58,22 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(GPT2Model, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "gpt2", self.language) + super(GPT2Model, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" - self.model = GPT2LMHeadModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.model = GPT2LMHeadModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = GPT2_ZH_PATH + self.model = GPT2LMHeadModel.from_pretrained('GPT2-chitchat') + elif self.language == 'en': + os.environ['TORCH_HOME'] = GPT2_EN_PATH + self.model = GPT2LMHeadModel.from_pretrained('gpt2') self.loss = CrossEntropyLoss(ignore_index=self.pad_id) def forward(self, batch, mode): diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 
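The GPT2Model change above drops the downloadable-resource table and instead loads from the local pretrain directory when it exists, falling back to a public checkpoint otherwise; the BERT- and GPT-2-based models below repeat the same pattern. A condensed sketch of that pattern as a single helper, with an assumed checkpoint map (helper name and map are illustrative, not part of the patch):

import os
from transformers import BertModel, GPT2LMHeadModel

# default checkpoints per (architecture, language); the GPT-2 Chinese entry
# mirrors the identifier used in the patch and assumes it is resolvable locally
HUB_CHECKPOINTS = {
    ('bert', 'zh'): 'bert-base-chinese',
    ('bert', 'en'): 'bert-base-uncased',
    ('gpt2', 'zh'): 'GPT2-chitchat',
    ('gpt2', 'en'): 'gpt2',
}

def load_pretrained(arch, language, local_dir, cache_dir):
    """Load weights from local_dir if present, otherwise pull the default checkpoint."""
    model_cls = BertModel if arch == 'bert' else GPT2LMHeadModel
    if os.path.exists(local_dir):
        return model_cls.from_pretrained(local_dir)
    os.makedirs(local_dir, exist_ok=True)
    os.environ['TORCH_HOME'] = cache_dir  # the patch sets this before downloading
    return model_cls.from_pretrained(HUB_CHECKPOINTS[(arch, language)])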
99e7ca9..5644365 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -2,15 +2,19 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -import os +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com +import os +import json import torch from transformers import GPT2LMHeadModel -from crslab.config import PRETRAIN_PATH +from crslab.config import BERT_EN_PATH, BERT_ZH_PATH, PRETRAIN_PATH, GPT2_ZH_PATH, GPT2_EN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from .modules import SequenceCrossEntropyLoss @@ -39,14 +43,22 @@ def __init__(self, opt, device, vocab, side_data): self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(InspiredConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "gpt2", self.language) + super(InspiredConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model for seeker and recommender separately""" - self.model_sk = GPT2LMHeadModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.model_sk = GPT2LMHeadModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = GPT2_ZH_PATH + self.model_sk = GPT2LMHeadModel.from_pretrained('GPT2-chitchat') + elif self.language == 'en': + os.environ['TORCH_HOME'] = GPT2_EN_PATH + self.model_sk = GPT2LMHeadModel.from_pretrained('gpt2') self.model_rm = GPT2LMHeadModel.from_pretrained(self.dpath) self.loss = SequenceCrossEntropyLoss(self.pad_id, self.label_smoothing) @@ -68,17 +80,27 @@ def converse(self, batch, mode): past = None lm_logits_all = [] + if self.language == 'zh': + config_json = os.path.join(GPT2_ZH_PATH, 'config.json') + elif self.language == 'en': + config_json = os.path.join(GPT2_EN_PATH, 'config.json') + + with open(config_json, 'r', encoding='utf-8') as f: + json_config = json.load(f) + + support_up_limits = json_config['n_ctx'] + if mode != 'test': for turn, iter in enumerate(input_ids_iters): if (roles[turn] == 0): # considering that gpt2 only supports up to 1024 tokens - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_sk(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) else: - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_rm(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 67948f5..ffe6052 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -24,10 +29,9 @@ from 
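The converse() change above reads the context limit from the checkpoint's config.json (n_ctx) and measures the cached length with past[0][0].shape[-2]. With the tuple cache returned by GPT2LMHeadModel, past holds one (key, value) pair per layer and each tensor is shaped (batch, n_heads, seq_len, head_dim), so the second-to-last dimension is the number of tokens already attended to. A toy check with dummy tensors (shapes are illustrative):

import torch

batch, n_heads, cached_len, head_dim = 1, 12, 1000, 64
# fake per-layer (key, value) cache in the tuple format returned by GPT-2
past = tuple(
    (torch.zeros(batch, n_heads, cached_len, head_dim),
     torch.zeros(batch, n_heads, cached_len, head_dim))
    for _ in range(12)
)

next_chunk = torch.zeros(batch, 40, dtype=torch.long)  # 40 new input ids
support_up_limits = 1024                                # n_ctx of the base GPT-2 config

if past is not None and past[0][0].shape[-2] + next_chunk.shape[1] > support_up_limits:
    past = None  # drop the cache instead of overflowing the position embeddings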
torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class InspiredRecModel(BaseModel): @@ -50,14 +54,22 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(InspiredRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(InspiredRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters - self.bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.bert = BertModel.from_pretrained('bert-base-uncased') # print(self.item_size) self.bert_hidden_size = self.bert.config.hidden_size self.mlp = nn.Linear(self.bert_hidden_size, self.item_size) diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 57590f4..1230ec8 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" KGSF ==== @@ -33,7 +38,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG -from .resources import resources class KGSFModel(BaseModel): @@ -116,10 +120,9 @@ def __init__(self, opt, device, vocab, side_data): self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] - super(KGSFModel, self).__init__(opt, device, dpath, resource) + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) + super(KGSFModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/kgsf/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 
'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 0f971b4..ef85782 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -3,6 +3,10 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com r""" NTRD @@ -29,7 +33,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection -from .resources import resources class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -87,12 +90,11 @@ def __init__(self, opt, device, vocab, side_data): # self.n_movies_label = opt['n_movies_label'] self.n_movies_label = 64362 # the number of entity2id # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) # loss weight self.gen_loss_weight = opt['gen_loss_weight'] - super(NTRDModel, self).__init__(opt, device, dpath, resource) + super(NTRDModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/ntrd/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 
'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index a726524..f202dcb 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +76,7 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True) + packed_utterances = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden states of the last layer (two directions of the GRU) @@ -104,7 +109,7 @@ def forward(self, context, utterance_lengths, dialog_lengths): # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True) + packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -144,7 +149,7 @@ def forward(self, request, request_lengths, context_state): sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) sorted_request = request.index_select(0, sorted_idx) embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) - packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True) + packed_request = pack_padded_sequence(embedded_request, sorted_lengths.cpu(), batch_first=True) sorted_context_state = context_state.index_select(0, 
sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 9e505d5..8d42dc4 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Conv ============= @@ -24,10 +29,9 @@ from torch.nn import CrossEntropyLoss from transformers import GPT2LMHeadModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, GPT2_ZH_PATH, GPT2_EN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGConvModel(BaseModel): @@ -54,14 +58,22 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language) - super(TGConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "gpt2", self.language) + super(TGConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" - self.model = GPT2LMHeadModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.model = GPT2LMHeadModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = GPT2_ZH_PATH + self.model = GPT2LMHeadModel.from_pretrained('GPT2-chitchat') + elif self.language == 'en': + os.environ['TORCH_HOME'] = GPT2_EN_PATH + self.model = GPT2LMHeadModel.from_pretrained('gpt2') self.loss = CrossEntropyLoss(ignore_index=self.pad_id) def forward(self, batch, mode): diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 708b7f9..8088e5a 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Policy =============== @@ -24,10 +29,9 @@ from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGPolicyModel(BaseModel): @@ -44,14 +48,22 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGPolicyModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(TGPolicyModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" - self.context_bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + 
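The modules.py changes above pass sorted_lengths.cpu() into pack_padded_sequence because recent PyTorch releases require the lengths argument to be a CPU tensor (or a plain list) even when the packed data lives on the GPU; passing a CUDA tensor raises a RuntimeError. A small self-contained example of the call, with illustrative tensor sizes:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

embedded = torch.randn(3, 5, 8)        # (batch, max_len, embed_dim), possibly on GPU in practice
lengths = torch.tensor([5, 3, 2])      # sorted in decreasing order

# lengths must stay on the CPU in current PyTorch
packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True)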
self.context_bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.context_bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.context_bert = BertModel.from_pretrained('bert-base-uncased') self.topic_bert = BertModel.from_pretrained(self.dpath) self.profile_bert = BertModel.from_pretrained(self.dpath) diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index a02ac5b..f5dd02f 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Rec ============ @@ -25,10 +30,9 @@ from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from crslab.model.recommendation.sasrec.modules import SASRec @@ -68,14 +72,22 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(TGRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters - self.bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.bert = BertModel.from_pretrained('bert-base-uncased') self.bert_hidden_size = self.bert.config.hidden_size self.concat_embed_size = self.bert_hidden_size + self.hidden_size self.fusion = nn.Linear(self.concat_embed_size, self.item_size) diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 76101cc..0580a9a 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Conv_BERT ========= @@ -23,10 +28,9 @@ from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from ...pretrained_models import resources class ConvBERTModel(BaseModel): @@ -48,14 +52,22 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - resource = 
resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ConvBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(ConvBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" - self.context_bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.context_bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.context_bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.context_bert = BertModel.from_pretrained('bert-base-uncased') self.bert_hidden_size = self.context_bert.config.hidden_size self.state2topic_id = nn.Linear(self.bert_hidden_size, diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index 65b400f..dd43d45 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Profile_BERT ============ @@ -24,10 +29,9 @@ from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class ProfileBERTModel(BaseModel): @@ -52,14 +56,22 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ProfileBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(ProfileBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" - self.profile_bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.profile_bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.profile_bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.profile_bert = BertModel.from_pretrained('bert-base-uncased') self.bert_hidden_size = self.profile_bert.config.hidden_size self.state2topic_id = nn.Linear(self.bert_hidden_size, diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index 400eaeb..d2f5944 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Topic_BERT ========== @@ -23,10 +28,9 @@ from torch import nn from transformers import BertModel -from 
crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TopicBERTModel(BaseModel): @@ -50,14 +54,22 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - resource = resources['bert'][language] - super(TopicBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(TopicBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" - self.topic_bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.topic_bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.topic_bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.topic_bert = BertModel.from_pretrained('bert-base-uncased') self.bert_hidden_size = self.topic_bert.config.hidden_size self.state2topic_id = nn.Linear(self.bert_hidden_size, diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py deleted file mode 100644 index 33c20d6..0000000 --- a/crslab/model/pretrained_models.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2021/1/6 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2021/1/7 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -"""Download links of pretrain models. - -Now we provide the following models: - -- `BERT`_: zh, en -- `GPT2`_: zh, en - -.. _BERT: - https://www.aclweb.org/anthology/N19-1423/ -.. 
_GPT2: - https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf - -""" - -resources = { - 'bert': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', - 'bert_zh.zip', - 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', - 'bert_en.zip', - '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' - ) - }, - }, - 'gpt2': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', - 'gpt2_zh.zip', - '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', - 'gpt2_en.zip', - '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' - ) - } - } -} diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index cb78a7b..c70342b 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -24,10 +29,9 @@ from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, BERT_EN_PATH, BERT_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class BERTModel(BaseModel): @@ -50,14 +54,22 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(BERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = os.path.join(PRETRAIN_PATH, "bert", self.language) + super(BERTModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters - self.bert = BertModel.from_pretrained(self.dpath) + if os.path.exists(self.dpath): + self.bert = BertModel.from_pretrained(self.dpath) + else: + os.makedirs(self.dpath) + if self.language == 'zh': + os.environ['TORCH_HOME'] = BERT_ZH_PATH + self.bert = BertModel.from_pretrained('base-base-chinese') + elif self.language == 'en': + os.environ['TORCH_HOME'] = BERT_EN_PATH + self.bert = BertModel.from_pretrained('bert-base-uncased') # print(self.item_size) self.bert_hidden_size = self.bert.config.hidden_size self.mlp = nn.Linear(self.bert_hidden_size, self.item_size) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..bb839a5 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE: +# @Time : 
2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -154,6 +159,8 @@ def train_recommender(self): def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b..96251c5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -169,6 +174,8 @@ def train_recommender(self): if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + bert_param = list(self.rec_model.bert.named_parameters()) else: bert_param = list(self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] diff --git a/requirements.txt b/requirements.txt index f7fba73..05950a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 +gensim From ee1d80a56bec3a1d1c979ebb06973b7fade65782 Mon Sep 17 00:00:00 2001 From: txy77 Date: Thu, 29 Sep 2022 17:54:24 +0800 Subject: [PATCH 02/35] txy77 --- config/conversation/gpt2/durecdial.yaml | 4 + config/conversation/gpt2/gorecdial.yaml | 4 + config/conversation/gpt2/inspired.yaml | 4 + config/conversation/gpt2/opendialkg.yaml | 4 + config/conversation/gpt2/redial.yaml | 4 + config/conversation/gpt2/tgredial.yaml | 4 + config/crs/inspired/durecdial.yaml | 7 + config/crs/inspired/gorecdial.yaml | 7 + config/crs/inspired/inspired.yaml | 7 + config/crs/inspired/opendialkg.yaml | 7 + config/crs/inspired/redial.yaml | 7 + config/crs/inspired/tgredial.yaml | 7 + config/crs/kgsf/durecdial.yaml | 1 + config/crs/kgsf/gorecdial.yaml | 1 + config/crs/kgsf/inspired.yaml | 1 + config/crs/kgsf/opendialkg.yaml | 1 + config/crs/kgsf/redial.yaml | 1 + config/crs/kgsf/tgredial.yaml | 1 + config/crs/ntrd/tgredial.yaml | 1 + config/crs/tgredial/durecdial.yaml | 6 + config/crs/tgredial/gorecdial.yaml | 6 + config/crs/tgredial/inspired.yaml | 6 + config/crs/tgredial/opendialkg.yaml | 6 + config/crs/tgredial/redial.yaml | 6 + config/crs/tgredial/tgredial.yaml | 8 + config/policy/conv_bert/tgredial.yaml | 4 + config/policy/mgcg/tgredial.yaml | 2 + config/policy/pmi/tgredial.yaml | 2 + config/policy/profile_bert/tgredial.yaml | 4 + config/policy/topic_bert/tgredial.yaml | 4 + config/recommendation/bert/durecdial.yaml | 4 + config/recommendation/bert/gorecdial.yaml | 4 + config/recommendation/bert/inspired.yaml | 4 + config/recommendation/bert/opendialkg.yaml | 4 + config/recommendation/bert/redial.yaml | 4 + config/recommendation/bert/tgredial.yaml | 4 + config/recommendation/gru4rec/durecdial.yaml | 2 + config/recommendation/gru4rec/gorecdial.yaml | 2 + config/recommendation/gru4rec/inspired.yaml | 2 + config/recommendation/gru4rec/opendialkg.yaml | 2 + config/recommendation/gru4rec/redial.yaml | 2 + config/recommendation/gru4rec/tgredial.yaml | 2 + .../recommendation/popularity/durecdial.yaml | 2 + .../recommendation/popularity/gorecdial.yaml | 2 + .../recommendation/popularity/inspired.yaml | 
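The system-level changes above call freeze_parameters()/named_parameters() on the bare model whenever CUDA_VISIBLE_DEVICES names a single device, and on model.module otherwise (i.e. when the model is wrapped for multi-GPU training). Note that len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1 only matches single-digit device ids; counting comma-separated entries is a more general way to detect the single-GPU case. A hedged alternative, not what the patch itself does:

import os

def visible_gpu_count() -> int:
    """Number of devices named in CUDA_VISIBLE_DEVICES ('' or '-1' means none)."""
    devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    if devices in ('', '-1'):
        return 0
    return len([d for d in devices.split(',') if d.strip() != ''])

# e.g. operate on the bare model unless more than one GPU is visible:
# target = model if visible_gpu_count() <= 1 else model.module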
2 + .../recommendation/popularity/opendialkg.yaml | 2 + config/recommendation/popularity/redial.yaml | 2 + .../recommendation/popularity/tgredial.yaml | 2 + config/recommendation/sasrec/durecdial.yaml | 2 + config/recommendation/sasrec/gorecdial.yaml | 2 + config/recommendation/sasrec/inspired.yaml | 2 + config/recommendation/sasrec/opendialkg.yaml | 2 + config/recommendation/sasrec/redial.yaml | 2 + config/recommendation/sasrec/tgredial.yaml | 2 + config/recommendation/textcnn/tgredial.yaml | 2 +- crslab/data/__init__.py | 4 +- crslab/data/dataset/durecdial/durecdial.py | 161 ++++++++++++++- crslab/data/dataset/durecdial/resources.py | 99 +++++----- crslab/data/dataset/gorecdial/gorecdial.py | 179 ++++++++++++++++- crslab/data/dataset/gorecdial/resources.py | 96 +++++---- crslab/data/dataset/inspired/inspired.py | 186 +++++++++++++++++- crslab/data/dataset/inspired/resources.py | 89 ++++----- crslab/data/dataset/opendialkg/opendialkg.py | 179 ++++++++++++++++- crslab/data/dataset/opendialkg/resources.py | 89 ++++----- crslab/data/dataset/redial/redial.py | 177 ++++++++++++++++- crslab/data/dataset/redial/resources.py | 89 ++++----- crslab/data/dataset/tgredial/resources.py | 99 +++++----- crslab/data/dataset/tgredial/tgredial.py | 183 ++++++++++++++++- crslab/data/dataset/tokenize.py | 42 ++++ crslab/evaluator/embeddings.py | 9 +- crslab/evaluator/standard.py | 12 +- crslab/model/conversation/gpt2/gpt2.py | 13 +- crslab/model/crs/inspired/inspired_conv.py | 26 ++- crslab/model/crs/inspired/inspired_rec.py | 13 +- crslab/model/crs/kgsf/kgsf.py | 13 +- crslab/model/crs/kgsf/resources.py | 62 ------ crslab/model/crs/ntrd/ntrd.py | 12 +- crslab/model/crs/ntrd/resources.py | 62 ------ crslab/model/crs/redial/modules.py | 11 +- crslab/model/crs/tgredial/tg_conv.py | 13 +- crslab/model/crs/tgredial/tg_policy.py | 13 +- crslab/model/crs/tgredial/tg_rec.py | 13 +- crslab/model/policy/conv_bert/conv_bert.py | 13 +- .../model/policy/profile_bert/profile_bert.py | 13 +- crslab/model/policy/topic_bert/topic_bert.py | 13 +- crslab/model/pretrained_models.py | 64 ------ crslab/model/recommendation/bert/bert.py | 13 +- crslab/quick_start/quick_start.py | 4 +- crslab/system/kgsf.py | 7 + crslab/system/tgredial.py | 7 + requirements.txt | 1 + 91 files changed, 1671 insertions(+), 597 deletions(-) create mode 100644 crslab/data/dataset/tokenize.py delete mode 100644 crslab/model/crs/kgsf/resources.py delete mode 100644 crslab/model/crs/ntrd/resources.py delete mode 100644 crslab/model/pretrained_models.py diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index 92a5329..05f568e 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ b/config/conversation/gpt2/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index ea155c4..abedfcb 100644 --- a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model 
conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/inspired.yaml b/config/conversation/gpt2/inspired.yaml index b620579..69a2208 100644 --- a/config/conversation/gpt2/inspired.yaml +++ b/config/conversation/gpt2/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index d96e8d6..20e0020 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 3a89ac8..69756b3 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/tgredial.yaml b/config/conversation/gpt2/tgredial.yaml index 378d9af..1566760 100644 --- a/config/conversation/gpt2/tgredial.yaml +++ b/config/conversation/gpt2/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 50 diff --git a/config/crs/inspired/durecdial.yaml b/config/crs/inspired/durecdial.yaml index 6984c40..6068285 100644 --- a/config/crs/inspired/durecdial.yaml +++ b/config/crs/inspired/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/gorecdial.yaml b/config/crs/inspired/gorecdial.yaml index e44800b..77647e1 100644 --- a/config/crs/inspired/gorecdial.yaml +++ b/config/crs/inspired/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 
# model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/inspired.yaml b/config/crs/inspired/inspired.yaml index a992737..3b22889 100644 --- a/config/crs/inspired/inspired.yaml +++ b/config/crs/inspired/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim rec: epoch: 1 diff --git a/config/crs/inspired/opendialkg.yaml b/config/crs/inspired/opendialkg.yaml index ff3c13a..8e4b879 100644 --- a/config/crs/inspired/opendialkg.yaml +++ b/config/crs/inspired/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +conv_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/redial.yaml b/config/crs/inspired/redial.yaml index df25019..8e6d4ff 100644 --- a/config/crs/inspired/redial.yaml +++ b/config/crs/inspired/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +conv_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/tgredial.yaml b/config/crs/inspired/tgredial.yaml index 892eb20..34684a1 100644 --- a/config/crs/inspired/tgredial.yaml +++ b/config/crs/inspired/tgredial.yaml @@ -3,6 +3,9 @@ dataset: TGReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index b5e8eff..9ad0a9d 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/gorecdial.yaml 
b/config/crs/kgsf/gorecdial.yaml index 0e4ba7e..ab00260 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index f087ca3..c3608e5 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index b9a2b06..09d47c3 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..5d11ca1 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 3 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index a120f98..33b2e1a 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 44a1c77..49c7940 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -24,6 +24,7 @@ n_positions: 1024 gen_loss_weight: 5 n_movies: 62287 replace_token: '[ITEM]' +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index 08a96aa..cfd5cf9 100644 --- a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 74382ff..67c6411 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/inspired.yaml b/config/crs/tgredial/inspired.yaml index f4ace12..87edf15 100644 --- a/config/crs/tgredial/inspired.yaml +++ b/config/crs/tgredial/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: 
bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 1 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index bcfb217..cba24ed 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/redial.yaml b/config/crs/tgredial/redial.yaml index 8e983a0..31dd6ad 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/tgredial.yaml b/config/crs/tgredial/tgredial.yaml index 0e1c956..ef5d57a 100644 --- a/config/crs/tgredial/tgredial.yaml +++ b/config/crs/tgredial/tgredial.yaml @@ -4,6 +4,10 @@ tokenize: rec: bert conv: gpt2 policy: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -13,6 +17,10 @@ scale: 1 rec_model: TGRec conv_model: TGConv policy_model: TGPolicy +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +policy_pretrained_path: 'data/model/pretrain/bert/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/policy/conv_bert/tgredial.yaml b/config/policy/conv_bert/tgredial.yaml index 78e5c58..284aa86 100644 --- a/config/policy/conv_bert/tgredial.yaml +++ b/config/policy/conv_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ConvBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/policy/mgcg/tgredial.yaml b/config/policy/mgcg/tgredial.yaml index 7cd78ec..8726cec 100644 --- a/config/policy/mgcg/tgredial.yaml +++ b/config/policy/mgcg/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader 
context_truncate: 256 response_truncate: 30 diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 87bb5e6..8e8b50b 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/profile_bert/tgredial.yaml b/config/policy/profile_bert/tgredial.yaml index 39f9ae8..08068a9 100644 --- a/config/policy/profile_bert/tgredial.yaml +++ b/config/policy/profile_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ProfileBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' n_sent: 10 # optim policy: diff --git a/config/policy/topic_bert/tgredial.yaml b/config/policy/topic_bert/tgredial.yaml index c3a5253..aed3b69 100644 --- a/config/policy/topic_bert/tgredial.yaml +++ b/config/policy/topic_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: TopicBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index 0d4250a..fcb981c 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 22ff335..864ed06 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/inspired.yaml b/config/recommendation/bert/inspired.yaml index d2d9d18..9a854fd 100644 --- a/config/recommendation/bert/inspired.yaml +++ b/config/recommendation/bert/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index 4b59696..fcc40f5 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: 
bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index be5fa53..820d894 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/tgredial.yaml b/config/recommendation/bert/tgredial.yaml index 717a2ab..3ac3319 100644 --- a/config/recommendation/bert/tgredial.yaml +++ b/config/recommendation/bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 20 diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index 94a5f6a..233f43f 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index 0d80c59..ca66dd7 100644 --- a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/inspired.yaml b/config/recommendation/gru4rec/inspired.yaml index 8ef81fe..5488b5e 100644 --- a/config/recommendation/gru4rec/inspired.yaml +++ b/config/recommendation/gru4rec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index b4900b9..809202b 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/redial.yaml b/config/recommendation/gru4rec/redial.yaml index 7b707e7..21fc6ca 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/tgredial.yaml 
b/config/recommendation/gru4rec/tgredial.yaml index 7caf3d0..14fa628 100644 --- a/config/recommendation/gru4rec/tgredial.yaml +++ b/config/recommendation/gru4rec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index 3131e0a..f1b03c2 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index bf77cd6..768d369 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/inspired.yaml b/config/recommendation/popularity/inspired.yaml index 4c9a821..cea0dce 100644 --- a/config/recommendation/popularity/inspired.yaml +++ b/config/recommendation/popularity/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index ebaf2c9..c88d0c1 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index b0cbec9..2afc85e 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index 66c9ef7..c8e6230 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ b/config/recommendation/popularity/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index 15ba15e..bcf5e8b 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ b/config/recommendation/sasrec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 243a646..3ec5786 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -2,6 
+2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/inspired.yaml b/config/recommendation/sasrec/inspired.yaml index d79ff24..51f5e6c 100644 --- a/config/recommendation/sasrec/inspired.yaml +++ b/config/recommendation/sasrec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index ba4c02d..42a8edf 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index add69ec..7df885a 100644 --- a/config/recommendation/sasrec/redial.yaml +++ b/config/recommendation/sasrec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/tgredial.yaml b/config/recommendation/sasrec/tgredial.yaml index 9888002..c8c3353 100644 --- a/config/recommendation/sasrec/tgredial.yaml +++ b/config/recommendation/sasrec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/textcnn/tgredial.yaml b/config/recommendation/textcnn/tgredial.yaml index 0d5c708..0de66df 100644 --- a/config/recommendation/textcnn/tgredial.yaml +++ b/config/recommendation/textcnn/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: - rec: sougou + rec: jieba # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 33bea19..7a4ad30 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -70,7 +70,7 @@ } -def get_dataset(opt, tokenize, restore, save) -> BaseDataset: +def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: """get and process dataset Args: @@ -85,7 +85,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return dataset_register_table[dataset](opt, tokenize, restore, save) + return dataset_register_table[dataset](opt, tokenize, restore, save, task) else: raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index ded2da6..d06727a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" DuRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config 
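get_dataset now forwards an optional task name to the dataset class; as the DuRecDial changes below show, the dataset concatenates str(task) + '_tokenize_path' to pick the matching tokenizer directory out of the config. A usage sketch, assuming opt is the parsed YAML carrying the keys shown in the configs above:

from crslab.config import Config
from crslab.data import get_dataset

opt = Config('config/crs/tgredial/tgredial.yaml')   # parsed YAML shown above
# task='rec' makes the dataset read rec_tokenize_path ('data/model/pretrain/bert/zh');
# task='conv' makes it read conv_tokenize_path instead.
rec_dataset = get_dataset(opt, opt['tokenize']['rec'], restore=False, save=False, task='rec')
conv_dataset = get_dataset(opt, opt['tokenize']['conv'], restore=False, save=False, task='conv')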
import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class DuRecDialDataset(BaseDataset): """ @@ -55,7 +62,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ Args: @@ -65,10 +72,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -94,14 +112,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -262,3 +301,111 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + cnt = 0 + tok2ind = {} + if self.tokenize == 
'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + tokenizer = self.tokenize + crstokenize = self.crstokenizer + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + match_list = list(set(match_list)) + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 327ccf8..bd2348b 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -8,63 +8,58 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 
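All of the tokenization above goes through the new CrsTokenize helper (crslab/data/dataset/tokenize.py, added by this patch but not reproduced in this hunk), which is constructed with an optional local vocabulary path and called as crstokenize.tokenize(text, tokenizer_name). Its actual body is not shown here; the following is only a sketch of such a dispatcher for the four tokenizer names the configs use, not the real implementation:

import jieba
from nltk import word_tokenize
from transformers import AutoTokenizer

class CrsTokenizeSketch:
    """Illustrative stand-in for crslab.data.dataset.tokenize.CrsTokenize."""

    def __init__(self, path=None):
        self.path = path              # local dir with BERT/GPT-2 vocab files, may be None
        self._hf_tokenizer = None     # lazily loaded HuggingFace tokenizer

    def tokenize(self, text, tokenizer):
        if tokenizer == 'nltk':
            return word_tokenize(text)
        if tokenizer == 'jieba':
            return jieba.lcut(text)
        if tokenizer in ('bert', 'gpt2'):
            if self._hf_tokenizer is None:
                self._hf_tokenizer = AutoTokenizer.from_pretrained(self.path)
            return self._hf_tokenizer.tokenize(text)
        raise ValueError('unknown tokenizer: ' + tokenizer)

Whatever the real implementation does, the contract split_token relies on is simply string in, list of tokens out.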
'jieba': { - 'version': '0.3', + 'resource':{ + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', - 'durecdial_jieba.zip', - 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', + 'durecdial.zip', + '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'jieba': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', - 'durecdial_bert.zip', - '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', - 'durecdial_gpt2.zip', - 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } + }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ce9d76..07b553d 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" GoRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class GoRecDialDataset(BaseDataset): """ @@ -55,7 +62,7 @@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +72,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +113,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +305,129 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'gorecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index b31e194..5ea42c1 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -8,61 +8,57 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESM_Wc7sbAlOgZWo_6lOx34B6mboskdpNdB7FLuyXUET2A?download=1', - 'gorecdial_nltk.zip', - '7e523f7ca90bb32ee8f2471ac5736717c45b20822c63bd958d0546de0a9cd863', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EYmobnFBox1LnGKGW4TMCk8BW6rnjdAZNVsNo8uJ8ZsJLg?download=1', + 'gorecdial.zip', + 
'66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'nltk': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcTG05imCYpFiBarVfnsAfkBVsbq1iPw23CYcp9kYE9X4g?download=1', - 'gorecdial_bert.zip', - 'fc7aff18504f750d8974d90f2941a01ff22cc054283124936b778ba91f03554f' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Edg4_nbKA49HnQPcd65gPdoBALPADQd4V5qVqOrUub2m9w?download=1', - 'gorecdial_gpt2.zip', - '7234138dcc27ed00bdac95da4096cd435023c229d227fa494d2bd7a653a492a9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + }, + } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73930f1..190826d 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" Inspired ======== @@ -21,13 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class InspiredDataset(BaseDataset): @@ -55,7 +63,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +73,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
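generate_word2vec trains a 300-dimensional gensim Word2Vec model on the tokenized training dialogs and writes an embedding matrix aligned with the generated vocabulary to <dataset>/word2vec.npy, padding the front rows with zeros for the special tokens that generate_tok2ind reserves. A small sketch of how a model could pick that file up to initialise its token embedding; the path and the freeze choice are illustrative:

import numpy as np
import torch
import torch.nn as nn

pretrained = np.load('data/dataset/gorecdial/word2vec.npy')   # (vocab_size, 300)
token_embedding = nn.Embedding.from_pretrained(
    torch.as_tensor(pretrained, dtype=torch.float),
    freeze=False,      # keep fine-tuning the vectors during training
    padding_idx=0,     # row 0 is the pad token for the nltk/jieba/bert variants
)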
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'inspired', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +114,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: @@ -268,3 +308,137 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 
'inspired', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for genre in dialog['genre']: + list_word = crstokenize.tokenize(genre, tokenizer) + match_list += list_word + + for people in dialog['people']: + list_word = crstokenize.tokenize(people, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index afb0cb1..c2d1e75 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', - 'inspired_nltk.zip', - '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', + 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXv8zwgCOY1EstHNjjs194cBqMIrdg4yxcyNsHKltTzyig?download=1', + 'inspired.zip', + '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', - 'inspired_bert.zip', - '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', - 'inspired_gpt2.zip', - '23bb4ce3299186630fdf673e17f43ee43e91573ea786c922e3527e4c341a313c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8582705..66fadb9 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" OpenDialKG ========== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class OpenDialKGDataset(BaseDataset): @@ -56,7 +64,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +74,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
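generate_copy_mask flags every vocabulary token that also appears in a turn's word/entity/movie (or item, genre, people) annotations and saves the boolean vector to data/model/kgsf/<Dataset>/copy_mask.npy, replacing the precomputed copy_mask files that the deleted kgsf/ntrd resources.py used to download. How the mask is consumed lives in the KGSF/NTRD model code rather than in this hunk; the snippet below is only a hedged sketch of the usual copy-network pattern, where the copy head may raise the score of masked-in tokens only, and the combination rule should be checked against crslab/model/crs/kgsf/kgsf.py:

import numpy as np
import torch

def copy_enhanced_logits(gen_logits, copy_logits,
                         mask_path='data/model/kgsf/ReDial/copy_mask.npy'):
    # gen_logits, copy_logits: (batch, seq_len, vocab_size)
    copy_mask = torch.as_tensor(np.load(mask_path), dtype=torch.float,
                                device=gen_logits.device)
    return gen_logits + copy_logits * copy_mask   # copy scores only for maskable tokens

The point of the generated file is simply that the mask is now derived from the dataset itself instead of being shipped as a fixed resource.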
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -271,3 +311,130 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'opendialkg', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'opendialkg', 'word2vec.npy') + np.save(word2vec_path, word2embedding) + diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e00ddfc..e5682fe 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', - 'opendialkg_nltk.zip', - '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUknGWqDp15OoI2U7DE6EHkBoZVaK273DJfxCdXuluqQjA?download=1', + 'opendialkg.zip', + 
'73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', - 'opendialkg_bert.zip', - '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', - 'opendialkg_gpt2.zip', - 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..e75c52c 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" ReDial ====== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class ReDialDataset(BaseDataset): @@ -56,7 +64,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +74,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
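generate_tok2ind builds the vocabulary from the tokenized training split and writes it to <dataset>/token2id.json, reserving __pad__/__start__/__end__/__unk__ at the front for nltk and jieba, prepending only [PAD] for bert, and appending _split_ for nltk; _load_vocab later reads the file back into self.tok2ind. A small sketch of the token-to-id step this file supports, mirroring the unk fallback the dataset conversion code applies via self.unk_token_idx (the path assumes the default data/dataset layout):

import json

with open('data/dataset/redial/token2id.json', encoding='utf-8') as f:
    tok2ind = json.load(f)

unk_idx = tok2ind.get('__unk__', 3)        # nltk/jieba place __unk__ at index 3; bert/gpt2 take unk from special_token_idx
tokens = ['i', 'liked', 'the', 'movie']    # output of the tokenizer for one utterance
token_ids = [tok2ind.get(tok, unk_idx) for tok in tokens]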
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, "redial", tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +306,128 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'redial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index b347029..170dd3b 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.31', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', - 'redial_nltk.zip', - '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ea4PEMnyyqxAl6tiAC17BcgBW8fZ6eveNKAbAU5sYt8-PQ?download=1', + 'redial.zip', + 
'9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0 + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, }, - }, - 'bert': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', - 'redial_bert.zip', - 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', - 'redial_gpt2.zip', - '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', - ), - 'special_token_idx': { - 'pad': -100, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 0f37d97..92506f7 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -8,64 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'pkuseg': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', - 'tgredial_pkuseg.zip', - '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUmmYbQ6BytMrQjmgRWuElMBZ2yv7v10wLzuwxHe9wxnYg?download=1', + 'tgredial.zip', + '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'pkuseg':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', - 'tgredial_bert.zip', - 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, 
+ 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', - 'tgredial_gpt2.zip', - '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 90e03e3..c935029 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" TGReDial ======== @@ -23,12 +28,15 @@ from collections import defaultdict from copy import copy import numpy as np +import gensim + from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class TGReDialDataset(BaseDataset): @@ -59,7 +67,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -69,11 +77,24 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'tgredial') + self.replace_token = opt.get('replace_token',None) self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) @@ -111,14 +132,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -340,3 +382,132 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + each_dict['conv_id'] = each['conv_id'] + for one in each['messages']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['messages'] = each_data + each_dict['user_id'] = each['user_id'] + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = 
i['messages'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['messages']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + + for movie in dialog['movie']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['messages']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/tokenize.py b/crslab/data/dataset/tokenize.py new file mode 100644 index 0000000..c63352f --- /dev/null +++ b/crslab/data/dataset/tokenize.py @@ -0,0 +1,42 @@ +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +from nltk import word_tokenize +from transformers import AutoTokenizer +import pkuseg +import nltk +import jieba + +class CrsTokenize: + + def __init__(self, path=None) -> None: + self.path = path + + if path is not None: + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + + def tokenize(self, text, tokenizer): + tokenize_fun = getattr(self, tokenizer + '_tokenize') + return tokenize_fun(text) + + def 
nltk_tokenize(self, text): + # nltk.download('punkt') + return word_tokenize(text) + + def bert_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def gpt2_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def pkuseg_tokenize(self, text): + if not hasattr(self, 'pkuseg_tokenizer'): + self.pkuseg_tokenizer = pkuseg.pkuseg() + return self.pkuseg_tokenizer.cut(text) + + def jieba_tokenize(self, text): + split_text = jieba.cut(text) + text_list = ' '.join(split_text).split() + return text_list \ No newline at end of file diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b7c30fd..b682e42 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -8,11 +8,16 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { 'zh': { - 'version': '0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', @@ -20,7 +25,7 @@ ) }, 'en': { - 'version': '0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..f08d121 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import time from collections import defaultdict @@ -83,9 +88,10 @@ def gen_evaluate(self, hyp, refs): hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] - self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + if len(ref_embs[0]) > 0: + self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index c93badb..5e84a8c 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" GPT2 ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class GPT2Model(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(GPT2Model, self).__init__(opt, device, dpath, resource) + self.language = 
dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(GPT2Model, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 99e7ca9..30fe87e 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -2,15 +2,19 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -import os +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com +import os +import json import torch from transformers import GPT2LMHeadModel from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from .modules import SequenceCrossEntropyLoss @@ -39,10 +43,9 @@ def __init__(self, opt, device, vocab, side_data): self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(InspiredConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(InspiredConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model for seeker and recommender separately""" @@ -68,17 +71,24 @@ def converse(self, batch, mode): past = None lm_logits_all = [] + config_json = os.path.join(self.dpath, 'config.json') + + with open(config_json, 'r', encoding='utf-8') as f: + json_config = json.load(f) + + support_up_limits = json_config['n_ctx'] + if mode != 'test': for turn, iter in enumerate(input_ids_iters): if (roles[turn] == 0): # considering that gpt2 only supports up to 1024 tokens - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_sk(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) else: - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_rm(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 67948f5..2b2e94b 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class InspiredRecModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(InspiredRecModel, self).__init__(opt, device, dpath, resource) + 
self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(InspiredRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 57590f4..1230ec8 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" KGSF ==== @@ -33,7 +38,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG -from .resources import resources class KGSFModel(BaseModel): @@ -116,10 +120,9 @@ def __init__(self, opt, device, vocab, side_data): self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] - super(KGSFModel, self).__init__(opt, device, dpath, resource) + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) + super(KGSFModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/kgsf/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 0f971b4..ef85782 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -3,6 +3,10 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com r""" NTRD @@ -29,7 +33,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection -from .resources import resources class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -87,12 +90,11 @@ def __init__(self, opt, device, vocab, side_data): # self.n_movies_label = opt['n_movies_label'] self.n_movies_label = 64362 # the number of entity2id # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) # loss weight self.gen_loss_weight = opt['gen_loss_weight'] - super(NTRDModel, self).__init__(opt, device, dpath, resource) + super(NTRDModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/ntrd/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': 
{ - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index a726524..f202dcb 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +76,7 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True) + packed_utterances = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden states of the last layer (two directions of the GRU) @@ -104,7 +109,7 @@ def forward(self, context, utterance_lengths, dialog_lengths): # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True) + packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -144,7 +149,7 @@ def forward(self, request, request_lengths, context_state): sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) sorted_request = request.index_select(0, sorted_idx) embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) - packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True) + packed_request = pack_padded_sequence(embedded_request, sorted_lengths.cpu(), batch_first=True) sorted_context_state = context_state.index_select(0, sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 9e505d5..7a6a81c 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Conv ============= @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGConvModel(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language) - super(TGConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(TGConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" 
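The refactor of the GPT2- and BERT-based modules above follows one pattern: instead of resolving pretrained weights through a hard-coded resources table, each constructor now reads the local directory of the pretrained model from the experiment config (conv_pretrained_path, rec_pretrained_path, or policy_pretrained_path) and hands it to BaseModel as dpath for build_model to load from. A minimal sketch of that pattern, assuming the weights have already been placed under data/model/pretrain/ and using an illustrative helper name (load_conv_pretrained is not part of this patch):

# Illustrative sketch only (not part of the patch): the loading pattern the
# refactored GPT2/BERT modules rely on, assuming the pretrained weights have
# already been downloaded into the directory named in the experiment config.
from transformers import AutoTokenizer, GPT2LMHeadModel

def load_conv_pretrained(opt):
    # opt is the parsed yaml config; the same directory is typically reused
    # as conv_tokenize_path for the dataset-side CrsTokenize.
    dpath = opt['conv_pretrained_path']               # e.g. 'data/model/pretrain/gpt2/en'
    tokenizer = AutoTokenizer.from_pretrained(dpath)  # load local files, no resource download
    model = GPT2LMHeadModel.from_pretrained(dpath)
    return tokenizer, model

# hypothetical usage with a config fragment
tokenizer, model = load_conv_pretrained({'conv_pretrained_path': 'data/model/pretrain/gpt2/en'})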
diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 708b7f9..6986be5 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Policy =============== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGPolicyModel(BaseModel): @@ -44,10 +48,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGPolicyModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TGPolicyModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index a02ac5b..ad185e7 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Rec ============ @@ -28,7 +33,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from crslab.model.recommendation.sasrec.modules import SASRec @@ -68,10 +72,9 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(TGRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 76101cc..117d760 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Conv_BERT ========= @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from ...pretrained_models import resources class ConvBERTModel(BaseModel): @@ -48,10 +52,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ConvBERTModel, 
self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ConvBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index 65b400f..d7cbced 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Profile_BERT ============ @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class ProfileBERTModel(BaseModel): @@ -52,10 +56,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ProfileBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ProfileBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index 400eaeb..b20d11a 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Topic_BERT ========== @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TopicBERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - resource = resources['bert'][language] - super(TopicBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TopicBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py deleted file mode 100644 index 33c20d6..0000000 --- a/crslab/model/pretrained_models.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2021/1/6 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2021/1/7 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -"""Download links of pretrain models. - -Now we provide the following models: - -- `BERT`_: zh, en -- `GPT2`_: zh, en - -.. _BERT: - https://www.aclweb.org/anthology/N19-1423/ -.. 
_GPT2: - https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf - -""" - -resources = { - 'bert': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', - 'bert_zh.zip', - 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', - 'bert_en.zip', - '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' - ) - }, - }, - 'gpt2': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', - 'gpt2_zh.zip', - '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', - 'gpt2_en.zip', - '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' - ) - } - } -} diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index cb78a7b..a053eea 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class BERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(BERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(BERTModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..2199396 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -34,7 +34,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data) + CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data, task=None) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab @@ -53,7 +53,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: - dataset = get_dataset(config, tokenize, restore_data, save_data) + dataset = get_dataset(config, tokenize, restore_data, save_data, task) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = 
dataset.valid_data diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..bb839a5 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -154,6 +159,8 @@ def train_recommender(self): def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b..96251c5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -169,6 +174,8 @@ def train_recommender(self): if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + bert_param = list(self.rec_model.bert.named_parameters()) else: bert_param = list(self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] diff --git a/requirements.txt b/requirements.txt index f7fba73..05950a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 +gensim From 10c36fe796fa73b8e7732312a09da2514cec4e0b Mon Sep 17 00:00:00 2001 From: txy77 Date: Thu, 29 Sep 2022 18:05:15 +0800 Subject: [PATCH 03/35] txy77 --- config/conversation/gpt2/durecdial.yaml | 4 + config/conversation/gpt2/gorecdial.yaml | 4 + config/conversation/gpt2/inspired.yaml | 4 + config/conversation/gpt2/opendialkg.yaml | 4 + config/conversation/gpt2/redial.yaml | 4 + config/conversation/gpt2/tgredial.yaml | 4 + config/crs/inspired/durecdial.yaml | 7 + config/crs/inspired/gorecdial.yaml | 7 + config/crs/inspired/inspired.yaml | 7 + config/crs/inspired/opendialkg.yaml | 7 + config/crs/inspired/redial.yaml | 7 + config/crs/inspired/tgredial.yaml | 7 + config/crs/kgsf/durecdial.yaml | 1 + config/crs/kgsf/gorecdial.yaml | 1 + config/crs/kgsf/inspired.yaml | 1 + config/crs/kgsf/opendialkg.yaml | 1 + config/crs/kgsf/redial.yaml | 1 + config/crs/kgsf/tgredial.yaml | 1 + config/crs/ntrd/tgredial.yaml | 1 + config/crs/tgredial/durecdial.yaml | 6 + config/crs/tgredial/gorecdial.yaml | 6 + config/crs/tgredial/inspired.yaml | 6 + config/crs/tgredial/opendialkg.yaml | 6 + config/crs/tgredial/redial.yaml | 6 + config/crs/tgredial/tgredial.yaml | 8 + config/policy/conv_bert/tgredial.yaml | 4 + config/policy/mgcg/tgredial.yaml | 2 + config/policy/pmi/tgredial.yaml | 2 + config/policy/profile_bert/tgredial.yaml | 4 + config/policy/topic_bert/tgredial.yaml | 4 + config/recommendation/bert/durecdial.yaml | 4 + config/recommendation/bert/gorecdial.yaml | 4 + config/recommendation/bert/inspired.yaml | 4 + config/recommendation/bert/opendialkg.yaml | 4 + config/recommendation/bert/redial.yaml | 4 + config/recommendation/bert/tgredial.yaml | 4 + config/recommendation/gru4rec/durecdial.yaml | 2 + config/recommendation/gru4rec/gorecdial.yaml | 2 + config/recommendation/gru4rec/inspired.yaml | 2 + 
config/recommendation/gru4rec/opendialkg.yaml | 2 + config/recommendation/gru4rec/redial.yaml | 2 + config/recommendation/gru4rec/tgredial.yaml | 2 + .../recommendation/popularity/durecdial.yaml | 2 + .../recommendation/popularity/gorecdial.yaml | 2 + .../recommendation/popularity/inspired.yaml | 2 + .../recommendation/popularity/opendialkg.yaml | 2 + config/recommendation/popularity/redial.yaml | 2 + .../recommendation/popularity/tgredial.yaml | 2 + config/recommendation/sasrec/durecdial.yaml | 2 + config/recommendation/sasrec/gorecdial.yaml | 2 + config/recommendation/sasrec/inspired.yaml | 2 + config/recommendation/sasrec/opendialkg.yaml | 2 + config/recommendation/sasrec/redial.yaml | 2 + config/recommendation/sasrec/tgredial.yaml | 2 + config/recommendation/textcnn/tgredial.yaml | 2 +- crslab/data/__init__.py | 4 +- crslab/data/dataset/durecdial/durecdial.py | 161 ++++++++++++++- crslab/data/dataset/durecdial/resources.py | 99 +++++----- crslab/data/dataset/gorecdial/gorecdial.py | 179 ++++++++++++++++- crslab/data/dataset/gorecdial/resources.py | 96 +++++---- crslab/data/dataset/inspired/inspired.py | 186 +++++++++++++++++- crslab/data/dataset/inspired/resources.py | 89 ++++----- crslab/data/dataset/opendialkg/opendialkg.py | 179 ++++++++++++++++- crslab/data/dataset/opendialkg/resources.py | 89 ++++----- crslab/data/dataset/redial/redial.py | 177 ++++++++++++++++- crslab/data/dataset/redial/resources.py | 89 ++++----- crslab/data/dataset/tgredial/resources.py | 99 +++++----- crslab/data/dataset/tgredial/tgredial.py | 183 ++++++++++++++++- crslab/data/dataset/tokenize.py | 42 ++++ crslab/evaluator/embeddings.py | 9 +- crslab/evaluator/standard.py | 12 +- crslab/model/conversation/gpt2/gpt2.py | 13 +- crslab/model/crs/inspired/inspired_conv.py | 29 ++- crslab/model/crs/inspired/inspired_rec.py | 13 +- crslab/model/crs/kgsf/kgsf.py | 13 +- crslab/model/crs/kgsf/resources.py | 62 ------ crslab/model/crs/ntrd/ntrd.py | 12 +- crslab/model/crs/ntrd/resources.py | 62 ------ crslab/model/crs/redial/modules.py | 11 +- crslab/model/crs/tgredial/tg_conv.py | 13 +- crslab/model/crs/tgredial/tg_policy.py | 13 +- crslab/model/crs/tgredial/tg_rec.py | 13 +- crslab/model/policy/conv_bert/conv_bert.py | 13 +- .../model/policy/profile_bert/profile_bert.py | 13 +- crslab/model/policy/topic_bert/topic_bert.py | 13 +- crslab/model/pretrained_models.py | 64 ------ crslab/model/recommendation/bert/bert.py | 13 +- crslab/quick_start/quick_start.py | 4 +- crslab/system/kgsf.py | 7 + crslab/system/tgredial.py | 7 + requirements.txt | 1 + 91 files changed, 1674 insertions(+), 597 deletions(-) create mode 100644 crslab/data/dataset/tokenize.py delete mode 100644 crslab/model/crs/kgsf/resources.py delete mode 100644 crslab/model/crs/ntrd/resources.py delete mode 100644 crslab/model/pretrained_models.py diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index 92a5329..05f568e 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ b/config/conversation/gpt2/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index ea155c4..abedfcb 100644 --- 
a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/inspired.yaml b/config/conversation/gpt2/inspired.yaml index b620579..69a2208 100644 --- a/config/conversation/gpt2/inspired.yaml +++ b/config/conversation/gpt2/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index d96e8d6..20e0020 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 3a89ac8..69756b3 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/tgredial.yaml b/config/conversation/gpt2/tgredial.yaml index 378d9af..1566760 100644 --- a/config/conversation/gpt2/tgredial.yaml +++ b/config/conversation/gpt2/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 50 diff --git a/config/crs/inspired/durecdial.yaml b/config/crs/inspired/durecdial.yaml index 6984c40..6068285 100644 --- a/config/crs/inspired/durecdial.yaml +++ b/config/crs/inspired/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/gorecdial.yaml b/config/crs/inspired/gorecdial.yaml index e44800b..77647e1 100644 --- 
a/config/crs/inspired/gorecdial.yaml +++ b/config/crs/inspired/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/inspired.yaml b/config/crs/inspired/inspired.yaml index a992737..3b22889 100644 --- a/config/crs/inspired/inspired.yaml +++ b/config/crs/inspired/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim rec: epoch: 1 diff --git a/config/crs/inspired/opendialkg.yaml b/config/crs/inspired/opendialkg.yaml index ff3c13a..8e4b879 100644 --- a/config/crs/inspired/opendialkg.yaml +++ b/config/crs/inspired/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/redial.yaml b/config/crs/inspired/redial.yaml index df25019..8e6d4ff 100644 --- a/config/crs/inspired/redial.yaml +++ b/config/crs/inspired/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/tgredial.yaml b/config/crs/inspired/tgredial.yaml index 892eb20..34684a1 100644 --- a/config/crs/inspired/tgredial.yaml +++ b/config/crs/inspired/tgredial.yaml @@ -3,6 +3,9 @@ dataset: TGReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git
a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index b5e8eff..9ad0a9d 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index 0e4ba7e..ab00260 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index f087ca3..c3608e5 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index b9a2b06..09d47c3 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..5d11ca1 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 3 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index a120f98..33b2e1a 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 44a1c77..49c7940 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -24,6 +24,7 @@ n_positions: 1024 gen_loss_weight: 5 n_movies: 62287 replace_token: '[ITEM]' +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index 08a96aa..cfd5cf9 100644 --- a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 74382ff..67c6411 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 
'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/inspired.yaml b/config/crs/tgredial/inspired.yaml index f4ace12..87edf15 100644 --- a/config/crs/tgredial/inspired.yaml +++ b/config/crs/tgredial/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 1 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index bcfb217..cba24ed 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/redial.yaml b/config/crs/tgredial/redial.yaml index 8e983a0..31dd6ad 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/tgredial.yaml b/config/crs/tgredial/tgredial.yaml index 0e1c956..ef5d57a 100644 --- a/config/crs/tgredial/tgredial.yaml +++ b/config/crs/tgredial/tgredial.yaml @@ -4,6 +4,10 @@ tokenize: rec: bert conv: gpt2 policy: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -13,6 +17,10 @@ scale: 1 rec_model: TGRec conv_model: TGConv policy_model: TGPolicy +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +policy_pretrained_path: 'data/model/pretrain/bert/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/policy/conv_bert/tgredial.yaml b/config/policy/conv_bert/tgredial.yaml index 78e5c58..284aa86 100644 --- a/config/policy/conv_bert/tgredial.yaml +++ b/config/policy/conv_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ConvBERT +# pretrained path 
+policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/policy/mgcg/tgredial.yaml b/config/policy/mgcg/tgredial.yaml index 7cd78ec..8726cec 100644 --- a/config/policy/mgcg/tgredial.yaml +++ b/config/policy/mgcg/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 87bb5e6..8e8b50b 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/profile_bert/tgredial.yaml b/config/policy/profile_bert/tgredial.yaml index 39f9ae8..08068a9 100644 --- a/config/policy/profile_bert/tgredial.yaml +++ b/config/policy/profile_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ProfileBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' n_sent: 10 # optim policy: diff --git a/config/policy/topic_bert/tgredial.yaml b/config/policy/topic_bert/tgredial.yaml index c3a5253..aed3b69 100644 --- a/config/policy/topic_bert/tgredial.yaml +++ b/config/policy/topic_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: TopicBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index 0d4250a..fcb981c 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 22ff335..864ed06 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/inspired.yaml b/config/recommendation/bert/inspired.yaml index d2d9d18..9a854fd 100644 --- a/config/recommendation/bert/inspired.yaml +++ b/config/recommendation/bert/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 
scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index 4b59696..fcc40f5 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index be5fa53..820d894 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/tgredial.yaml b/config/recommendation/bert/tgredial.yaml index 717a2ab..3ac3319 100644 --- a/config/recommendation/bert/tgredial.yaml +++ b/config/recommendation/bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 20 diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index 94a5f6a..233f43f 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index 0d80c59..ca66dd7 100644 --- a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/inspired.yaml b/config/recommendation/gru4rec/inspired.yaml index 8ef81fe..5488b5e 100644 --- a/config/recommendation/gru4rec/inspired.yaml +++ b/config/recommendation/gru4rec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index b4900b9..809202b 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/redial.yaml 
b/config/recommendation/gru4rec/redial.yaml index 7b707e7..21fc6ca 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/tgredial.yaml b/config/recommendation/gru4rec/tgredial.yaml index 7caf3d0..14fa628 100644 --- a/config/recommendation/gru4rec/tgredial.yaml +++ b/config/recommendation/gru4rec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index 3131e0a..f1b03c2 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index bf77cd6..768d369 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/inspired.yaml b/config/recommendation/popularity/inspired.yaml index 4c9a821..cea0dce 100644 --- a/config/recommendation/popularity/inspired.yaml +++ b/config/recommendation/popularity/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index ebaf2c9..c88d0c1 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index b0cbec9..2afc85e 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index 66c9ef7..c8e6230 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ b/config/recommendation/popularity/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index 15ba15e..bcf5e8b 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ b/config/recommendation/sasrec/durecdial.yaml @@ -2,6 +2,8 @@ 
dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 243a646..3ec5786 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/inspired.yaml b/config/recommendation/sasrec/inspired.yaml index d79ff24..51f5e6c 100644 --- a/config/recommendation/sasrec/inspired.yaml +++ b/config/recommendation/sasrec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index ba4c02d..42a8edf 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index add69ec..7df885a 100644 --- a/config/recommendation/sasrec/redial.yaml +++ b/config/recommendation/sasrec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/tgredial.yaml b/config/recommendation/sasrec/tgredial.yaml index 9888002..c8c3353 100644 --- a/config/recommendation/sasrec/tgredial.yaml +++ b/config/recommendation/sasrec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/textcnn/tgredial.yaml b/config/recommendation/textcnn/tgredial.yaml index 0d5c708..0de66df 100644 --- a/config/recommendation/textcnn/tgredial.yaml +++ b/config/recommendation/textcnn/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: - rec: sougou + rec: jieba # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 33bea19..7a4ad30 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -70,7 +70,7 @@ } -def get_dataset(opt, tokenize, restore, save) -> BaseDataset: +def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: """get and process dataset Args: @@ -85,7 +85,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return dataset_register_table[dataset](opt, tokenize, restore, save) + return dataset_register_table[dataset](opt, tokenize, restore, save, task) else: raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index ded2da6..d06727a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -7,6 +7,11 @@ # 
@Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" DuRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class DuRecDialDataset(BaseDataset): """ @@ -55,7 +62,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ Args: @@ -65,10 +72,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -94,14 +112,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -262,3 +301,111 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] 
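+ # re-tokenize every utterance on the fly: the raw 'text' string is split into a token list by CrsTokenize, using the tokenizer configured for this run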
+ for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + cnt = 0 + tok2ind = {} + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + tokenizer = self.tokenize + crstokenize = self.crstokenizer + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + match_list = list(set(match_list)) + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git 
a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 327ccf8..bd2348b 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -8,63 +8,58 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'jieba': { - 'version': '0.3', + 'resource':{ + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', - 'durecdial_jieba.zip', - 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', + 'durecdial.zip', + '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'jieba': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', - 'durecdial_bert.zip', - '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', - 'durecdial_gpt2.zip', - 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } + }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ce9d76..07b553d 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" GoRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class GoRecDialDataset(BaseDataset): """ @@ -55,7 +62,7 
@@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +72,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +113,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +305,129 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt 
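+ # note: ids 0-3 are reserved for __pad__/__start__/__end__/__unk__ when an nltk/jieba vocabulary is built; the remaining tokens are numbered below in order of first appearance in the training dialogs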
+ cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'gorecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index b31e194..5ea42c1 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -8,61 +8,57 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESM_Wc7sbAlOgZWo_6lOx34B6mboskdpNdB7FLuyXUET2A?download=1', - 'gorecdial_nltk.zip', - '7e523f7ca90bb32ee8f2471ac5736717c45b20822c63bd958d0546de0a9cd863', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EYmobnFBox1LnGKGW4TMCk8BW6rnjdAZNVsNo8uJ8ZsJLg?download=1', + 'gorecdial.zip', + '66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'nltk': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcTG05imCYpFiBarVfnsAfkBVsbq1iPw23CYcp9kYE9X4g?download=1', - 'gorecdial_bert.zip', - 'fc7aff18504f750d8974d90f2941a01ff22cc054283124936b778ba91f03554f' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Edg4_nbKA49HnQPcd65gPdoBALPADQd4V5qVqOrUub2m9w?download=1', - 'gorecdial_gpt2.zip', - '7234138dcc27ed00bdac95da4096cd435023c229d227fa494d2bd7a653a492a9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + }, + } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73930f1..190826d 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" Inspired ======== @@ -21,13 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class InspiredDataset(BaseDataset): @@ -55,7 +63,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +73,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
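+ task (str, optional): task name (e.g. 'rec' or 'conv') used to look up the corresponding '<task>_tokenize_path' entry in opt. Defaults to None.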
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'inspired', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +114,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: @@ -268,3 +308,137 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 
'inspired', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for genre in dialog['genre']: + list_word = crstokenize.tokenize(genre, tokenizer) + match_list += list_word + + for people in dialog['people']: + list_word = crstokenize.tokenize(people, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index afb0cb1..c2d1e75 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', - 'inspired_nltk.zip', - '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', + 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXv8zwgCOY1EstHNjjs194cBqMIrdg4yxcyNsHKltTzyig?download=1', + 'inspired.zip', + '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', - 'inspired_bert.zip', - '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', - 'inspired_gpt2.zip', - '23bb4ce3299186630fdf673e17f43ee43e91573ea786c922e3527e4c341a313c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8582705..66fadb9 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" OpenDialKG ========== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class OpenDialKGDataset(BaseDataset): @@ -56,7 +64,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +74,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
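+ task (str, optional): task name (e.g. 'rec' or 'conv') used to look up the corresponding '<task>_tokenize_path' entry in opt. Defaults to None.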
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -271,3 +311,130 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'opendialkg', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'opendialkg', 'word2vec.npy') + np.save(word2vec_path, word2embedding) + diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e00ddfc..e5682fe 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', - 'opendialkg_nltk.zip', - '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUknGWqDp15OoI2U7DE6EHkBoZVaK273DJfxCdXuluqQjA?download=1', + 'opendialkg.zip', + 
'73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', - 'opendialkg_bert.zip', - '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', - 'opendialkg_gpt2.zip', - 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..e75c52c 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" ReDial ====== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class ReDialDataset(BaseDataset): @@ -56,7 +64,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +74,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, "redial", tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +306,128 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'redial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index b347029..170dd3b 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.31', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', - 'redial_nltk.zip', - '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ea4PEMnyyqxAl6tiAC17BcgBW8fZ6eveNKAbAU5sYt8-PQ?download=1', + 'redial.zip', + 
'9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0 + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, }, - }, - 'bert': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', - 'redial_bert.zip', - 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', - 'redial_gpt2.zip', - '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', - ), - 'special_token_idx': { - 'pad': -100, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 0f37d97..92506f7 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -8,64 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'pkuseg': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', - 'tgredial_pkuseg.zip', - '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUmmYbQ6BytMrQjmgRWuElMBZ2yv7v10wLzuwxHe9wxnYg?download=1', + 'tgredial.zip', + '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'pkuseg':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', - 'tgredial_bert.zip', - 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, 
+ 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', - 'tgredial_gpt2.zip', - '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 90e03e3..c935029 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" TGReDial ======== @@ -23,12 +28,15 @@ from collections import defaultdict from copy import copy import numpy as np +import gensim + from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class TGReDialDataset(BaseDataset): @@ -59,7 +67,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -69,11 +77,24 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'tgredial') + self.replace_token = opt.get('replace_token',None) self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) @@ -111,14 +132,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -340,3 +382,132 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + each_dict['conv_id'] = each['conv_id'] + for one in each['messages']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['messages'] = each_data + each_dict['user_id'] = each['user_id'] + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = 
i['messages'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['messages']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + + for movie in dialog['movie']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['messages']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/tokenize.py b/crslab/data/dataset/tokenize.py new file mode 100644 index 0000000..c63352f --- /dev/null +++ b/crslab/data/dataset/tokenize.py @@ -0,0 +1,42 @@ +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +from nltk import word_tokenize +from transformers import AutoTokenizer +import pkuseg +import nltk +import jieba + +class CrsTokenize: + + def __init__(self, path=None) -> None: + self.path = path + + if path is not None: + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + + def tokenize(self, text, tokenizer): + tokenize_fun = getattr(self, tokenizer + '_tokenize') + return tokenize_fun(text) + + def 
nltk_tokenize(self, text): + # nltk.download('punkt') + return word_tokenize(text) + + def bert_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def gpt2_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def pkuseg_tokenize(self, text): + if not hasattr(self, 'pkuseg_tokenizer'): + self.pkuseg_tokenizer = pkuseg.pkuseg() + return self.pkuseg_tokenizer.cut(text) + + def jieba_tokenize(self, text): + split_text = jieba.cut(text) + text_list = ' '.join(split_text).split() + return text_list \ No newline at end of file diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b7c30fd..b682e42 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -8,11 +8,16 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { 'zh': { - 'version': '0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', @@ -20,7 +25,7 @@ ) }, 'en': { - 'version': '0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..f08d121 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import time from collections import defaultdict @@ -83,9 +88,10 @@ def gen_evaluate(self, hyp, refs): hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] - self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + if len(ref_embs[0]) > 0: + self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index c93badb..5e84a8c 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" GPT2 ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class GPT2Model(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(GPT2Model, self).__init__(opt, device, dpath, resource) + self.language =
dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(GPT2Model, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 99e7ca9..4286f9a 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -2,15 +2,19 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -import os +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com +import os +import json import torch from transformers import GPT2LMHeadModel -from crslab.config import PRETRAIN_PATH +from crslab.config import PRETRAIN_PATH, GPT2_EN_PATH, GPT2_ZH_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from .modules import SequenceCrossEntropyLoss @@ -39,10 +43,9 @@ def __init__(self, opt, device, vocab, side_data): self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(InspiredConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(InspiredConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model for seeker and recommender separately""" @@ -68,17 +71,27 @@ def converse(self, batch, mode): past = None lm_logits_all = [] + if self.language == 'zh': + config_json = os.path.join(GPT2_ZH_PATH, 'config.json') + elif self.language == 'en': + config_json = os.path.join(GPT2_EN_PATH, 'config.json') + + with open(config_json, 'r', encoding='utf-8') as f: + json_config = json.load(f) + + support_up_limits = json_config['n_ctx'] + if mode != 'test': for turn, iter in enumerate(input_ids_iters): if (roles[turn] == 0): # considering that gpt2 only supports up to 1024 tokens - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_sk(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) else: - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_rm(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 67948f5..2b2e94b 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class InspiredRecModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath =
os.path.join(PRETRAIN_PATH, "bert", language) - super(InspiredRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(InspiredRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 57590f4..1230ec8 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" KGSF ==== @@ -33,7 +38,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG -from .resources import resources class KGSFModel(BaseModel): @@ -116,10 +120,9 @@ def __init__(self, opt, device, vocab, side_data): self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] - super(KGSFModel, self).__init__(opt, device, dpath, resource) + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) + super(KGSFModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/kgsf/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - 
'23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 0f971b4..ef85782 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -3,6 +3,10 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com r""" NTRD @@ -29,7 +33,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection -from .resources import resources class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -87,12 +90,11 @@ def __init__(self, opt, device, vocab, side_data): # self.n_movies_label = opt['n_movies_label'] self.n_movies_label = 64362 # the number of entity2id # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) # loss weight self.gen_loss_weight = opt['gen_loss_weight'] - super(NTRDModel, self).__init__(opt, device, dpath, resource) + super(NTRDModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/ntrd/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index a726524..f202dcb 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +76,7 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True) + packed_utterances = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden states of the last layer (two directions of the GRU) @@ -104,7 +109,7 @@ def forward(self, context, utterance_lengths, dialog_lengths): # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True) + packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -144,7 +149,7 @@ def forward(self, request, request_lengths, context_state): sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) sorted_request = request.index_select(0, sorted_idx) embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) - packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True) + packed_request = pack_padded_sequence(embedded_request, sorted_lengths.cpu(), batch_first=True) sorted_context_state = context_state.index_select(0, sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 9e505d5..7a6a81c 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Conv ============= @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGConvModel(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language) - 
super(TGConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(TGConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 708b7f9..6986be5 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Policy =============== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGPolicyModel(BaseModel): @@ -44,10 +48,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGPolicyModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TGPolicyModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index a02ac5b..ad185e7 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Rec ============ @@ -28,7 +33,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from crslab.model.recommendation.sasrec.modules import SASRec @@ -68,10 +72,9 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(TGRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 76101cc..117d760 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Conv_BERT ========= @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from ...pretrained_models import resources class ConvBERTModel(BaseModel): @@ -48,10 +52,9 @@ def 
__init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ConvBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ConvBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index 65b400f..d7cbced 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Profile_BERT ============ @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class ProfileBERTModel(BaseModel): @@ -52,10 +56,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ProfileBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ProfileBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index 400eaeb..b20d11a 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Topic_BERT ========== @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TopicBERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - resource = resources['bert'][language] - super(TopicBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TopicBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py deleted file mode 100644 index 33c20d6..0000000 --- a/crslab/model/pretrained_models.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2021/1/6 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2021/1/7 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -"""Download links of pretrain 
models. - -Now we provide the following models: - -- `BERT`_: zh, en -- `GPT2`_: zh, en - -.. _BERT: - https://www.aclweb.org/anthology/N19-1423/ -.. _GPT2: - https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf - -""" - -resources = { - 'bert': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', - 'bert_zh.zip', - 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', - 'bert_en.zip', - '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' - ) - }, - }, - 'gpt2': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', - 'gpt2_zh.zip', - '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', - 'gpt2_en.zip', - '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' - ) - } - } -} diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index cb78a7b..a053eea 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class BERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(BERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(BERTModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..2199396 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -34,7 +34,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data) + CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data, task=None) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab @@ -53,7 +53,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: - dataset = get_dataset(config, tokenize, restore_data, save_data) + dataset = 
get_dataset(config, tokenize, restore_data, save_data, task) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = dataset.valid_data diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..bb839a5 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -154,6 +159,8 @@ def train_recommender(self): def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b..96251c5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -169,6 +174,8 @@ def train_recommender(self): if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + bert_param = list(self.rec_model.bert.named_parameters()) else: bert_param = list(self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] diff --git a/requirements.txt b/requirements.txt index f7fba73..05950a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 +gensim From c8f1385539c68d3ece53e80a7ebe14101e2d1ca7 Mon Sep 17 00:00:00 2001 From: txy77 Date: Thu, 29 Sep 2022 19:03:12 +0800 Subject: [PATCH 04/35] txy77 --- config/conversation/gpt2/durecdial.yaml | 4 + config/conversation/gpt2/gorecdial.yaml | 4 + config/conversation/gpt2/inspired.yaml | 4 + config/conversation/gpt2/opendialkg.yaml | 4 + config/conversation/gpt2/redial.yaml | 4 + config/conversation/gpt2/tgredial.yaml | 4 + config/crs/inspired/durecdial.yaml | 7 + config/crs/inspired/gorecdial.yaml | 7 + config/crs/inspired/inspired.yaml | 7 + config/crs/inspired/opendialkg.yaml | 7 + config/crs/inspired/redial.yaml | 7 + config/crs/inspired/tgredial.yaml | 7 + config/crs/kgsf/durecdial.yaml | 1 + config/crs/kgsf/gorecdial.yaml | 1 + config/crs/kgsf/inspired.yaml | 1 + config/crs/kgsf/opendialkg.yaml | 1 + config/crs/kgsf/redial.yaml | 1 + config/crs/kgsf/tgredial.yaml | 1 + config/crs/ntrd/tgredial.yaml | 1 + config/crs/tgredial/durecdial.yaml | 6 + config/crs/tgredial/gorecdial.yaml | 6 + config/crs/tgredial/inspired.yaml | 6 + config/crs/tgredial/opendialkg.yaml | 6 + config/crs/tgredial/redial.yaml | 6 + config/crs/tgredial/tgredial.yaml | 8 + config/policy/conv_bert/tgredial.yaml | 4 + config/policy/mgcg/tgredial.yaml | 2 + config/policy/pmi/tgredial.yaml | 2 + config/policy/profile_bert/tgredial.yaml | 4 + config/policy/topic_bert/tgredial.yaml | 4 + config/recommendation/bert/durecdial.yaml | 4 + config/recommendation/bert/gorecdial.yaml | 4 + config/recommendation/bert/inspired.yaml | 4 + config/recommendation/bert/opendialkg.yaml | 4 + config/recommendation/bert/redial.yaml | 4 + config/recommendation/bert/tgredial.yaml | 4 + 
config/recommendation/gru4rec/durecdial.yaml | 2 + config/recommendation/gru4rec/gorecdial.yaml | 2 + config/recommendation/gru4rec/inspired.yaml | 2 + config/recommendation/gru4rec/opendialkg.yaml | 2 + config/recommendation/gru4rec/redial.yaml | 2 + config/recommendation/gru4rec/tgredial.yaml | 2 + .../recommendation/popularity/durecdial.yaml | 2 + .../recommendation/popularity/gorecdial.yaml | 2 + .../recommendation/popularity/inspired.yaml | 2 + .../recommendation/popularity/opendialkg.yaml | 2 + config/recommendation/popularity/redial.yaml | 2 + .../recommendation/popularity/tgredial.yaml | 2 + config/recommendation/sasrec/durecdial.yaml | 2 + config/recommendation/sasrec/gorecdial.yaml | 2 + config/recommendation/sasrec/inspired.yaml | 2 + config/recommendation/sasrec/opendialkg.yaml | 2 + config/recommendation/sasrec/redial.yaml | 2 + config/recommendation/sasrec/tgredial.yaml | 2 + config/recommendation/textcnn/tgredial.yaml | 2 +- crslab/data/__init__.py | 4 +- crslab/data/dataset/durecdial/durecdial.py | 161 ++++++++++++++- crslab/data/dataset/durecdial/resources.py | 99 +++++----- crslab/data/dataset/gorecdial/gorecdial.py | 179 ++++++++++++++++- crslab/data/dataset/gorecdial/resources.py | 96 +++++---- crslab/data/dataset/inspired/inspired.py | 186 +++++++++++++++++- crslab/data/dataset/inspired/resources.py | 89 ++++----- crslab/data/dataset/opendialkg/opendialkg.py | 179 ++++++++++++++++- crslab/data/dataset/opendialkg/resources.py | 89 ++++----- crslab/data/dataset/redial/redial.py | 177 ++++++++++++++++- crslab/data/dataset/redial/resources.py | 89 ++++----- crslab/data/dataset/tgredial/resources.py | 99 +++++----- crslab/data/dataset/tgredial/tgredial.py | 183 ++++++++++++++++- crslab/data/dataset/tokenize.py | 42 ++++ crslab/evaluator/embeddings.py | 9 +- crslab/evaluator/standard.py | 12 +- crslab/model/conversation/gpt2/gpt2.py | 13 +- crslab/model/crs/inspired/inspired_conv.py | 29 ++- crslab/model/crs/inspired/inspired_rec.py | 13 +- crslab/model/crs/kgsf/kgsf.py | 13 +- crslab/model/crs/kgsf/resources.py | 62 ------ crslab/model/crs/ntrd/ntrd.py | 12 +- crslab/model/crs/ntrd/resources.py | 62 ------ crslab/model/crs/redial/modules.py | 11 +- crslab/model/crs/tgredial/tg_conv.py | 13 +- crslab/model/crs/tgredial/tg_policy.py | 13 +- crslab/model/crs/tgredial/tg_rec.py | 13 +- crslab/model/policy/conv_bert/conv_bert.py | 13 +- .../model/policy/profile_bert/profile_bert.py | 13 +- crslab/model/policy/topic_bert/topic_bert.py | 13 +- crslab/model/pretrained_models.py | 64 ------ crslab/model/recommendation/bert/bert.py | 13 +- crslab/quick_start/quick_start.py | 4 +- crslab/system/kgsf.py | 7 + crslab/system/tgredial.py | 7 + requirements.txt | 1 + 91 files changed, 1674 insertions(+), 597 deletions(-) create mode 100644 crslab/data/dataset/tokenize.py delete mode 100644 crslab/model/crs/kgsf/resources.py delete mode 100644 crslab/model/crs/ntrd/resources.py delete mode 100644 crslab/model/pretrained_models.py diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index 92a5329..05f568e 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ b/config/conversation/gpt2/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 1 diff 
--git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index ea155c4..abedfcb 100644 --- a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/inspired.yaml b/config/conversation/gpt2/inspired.yaml index b620579..69a2208 100644 --- a/config/conversation/gpt2/inspired.yaml +++ b/config/conversation/gpt2/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index d96e8d6..20e0020 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 3a89ac8..69756b3 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/tgredial.yaml b/config/conversation/gpt2/tgredial.yaml index 378d9af..1566760 100644 --- a/config/conversation/gpt2/tgredial.yaml +++ b/config/conversation/gpt2/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 50 diff --git a/config/crs/inspired/durecdial.yaml b/config/crs/inspired/durecdial.yaml index 6984c40..6068285 100644 --- a/config/crs/inspired/durecdial.yaml +++ b/config/crs/inspired/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git 
a/config/crs/inspired/gorecdial.yaml b/config/crs/inspired/gorecdial.yaml index e44800b..77647e1 100644 --- a/config/crs/inspired/gorecdial.yaml +++ b/config/crs/inspired/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/inspired.yaml b/config/crs/inspired/inspired.yaml index a992737..3b22889 100644 --- a/config/crs/inspired/inspired.yaml +++ b/config/crs/inspired/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim rec: epoch: 1 diff --git a/config/crs/inspired/opendialkg.yaml b/config/crs/inspired/opendialkg.yaml index ff3c13a..8e4b879 100644 --- a/config/crs/inspired/opendialkg.yaml +++ b/config/crs/inspired/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/redial.yaml b/config/crs/inspired/redial.yaml index df25019..8e6d4ff 100644 --- a/config/crs/inspired/redial.yaml +++ b/config/crs/inspired/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/tgredial.yaml b/config/crs/inspired/tgredial.yaml index 892eb20..34684a1 100644 --- a/config/crs/inspired/tgredial.yaml +++ b/config/crs/inspired/tgredial.yaml @@ -3,6 +3,9 @@ dataset: TGReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path:
'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index b5e8eff..9ad0a9d 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index 0e4ba7e..ab00260 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index f087ca3..c3608e5 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index b9a2b06..09d47c3 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..5d11ca1 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 3 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index a120f98..33b2e1a 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 44a1c77..49c7940 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -24,6 +24,7 @@ n_positions: 1024 gen_loss_weight: 5 n_movies: 62287 replace_token: '[ITEM]' +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index 08a96aa..cfd5cf9 100644 --- a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 74382ff..67c6411 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model 
rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/inspired.yaml b/config/crs/tgredial/inspired.yaml index f4ace12..87edf15 100644 --- a/config/crs/tgredial/inspired.yaml +++ b/config/crs/tgredial/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 1 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index bcfb217..cba24ed 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/redial.yaml b/config/crs/tgredial/redial.yaml index 8e983a0..31dd6ad 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/tgredial.yaml b/config/crs/tgredial/tgredial.yaml index 0e1c956..ef5d57a 100644 --- a/config/crs/tgredial/tgredial.yaml +++ b/config/crs/tgredial/tgredial.yaml @@ -4,6 +4,10 @@ tokenize: rec: bert conv: gpt2 policy: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -13,6 +17,10 @@ scale: 1 rec_model: TGRec conv_model: TGConv policy_model: TGPolicy +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +policy_pretrained_path: 'data/model/pretrain/bert/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/policy/conv_bert/tgredial.yaml b/config/policy/conv_bert/tgredial.yaml index 78e5c58..284aa86 100644 --- a/config/policy/conv_bert/tgredial.yaml +++ b/config/policy/conv_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 
scale: 1 # model policy_model: ConvBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/policy/mgcg/tgredial.yaml b/config/policy/mgcg/tgredial.yaml index 7cd78ec..8726cec 100644 --- a/config/policy/mgcg/tgredial.yaml +++ b/config/policy/mgcg/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 87bb5e6..8e8b50b 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/profile_bert/tgredial.yaml b/config/policy/profile_bert/tgredial.yaml index 39f9ae8..08068a9 100644 --- a/config/policy/profile_bert/tgredial.yaml +++ b/config/policy/profile_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ProfileBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' n_sent: 10 # optim policy: diff --git a/config/policy/topic_bert/tgredial.yaml b/config/policy/topic_bert/tgredial.yaml index c3a5253..aed3b69 100644 --- a/config/policy/topic_bert/tgredial.yaml +++ b/config/policy/topic_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: TopicBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index 0d4250a..fcb981c 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 22ff335..864ed06 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/inspired.yaml b/config/recommendation/bert/inspired.yaml index d2d9d18..9a854fd 100644 --- a/config/recommendation/bert/inspired.yaml +++ b/config/recommendation/bert/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 
256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index 4b59696..fcc40f5 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index be5fa53..820d894 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/tgredial.yaml b/config/recommendation/bert/tgredial.yaml index 717a2ab..3ac3319 100644 --- a/config/recommendation/bert/tgredial.yaml +++ b/config/recommendation/bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 20 diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index 94a5f6a..233f43f 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index 0d80c59..ca66dd7 100644 --- a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/inspired.yaml b/config/recommendation/gru4rec/inspired.yaml index 8ef81fe..5488b5e 100644 --- a/config/recommendation/gru4rec/inspired.yaml +++ b/config/recommendation/gru4rec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index b4900b9..809202b 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git 
a/config/recommendation/gru4rec/redial.yaml b/config/recommendation/gru4rec/redial.yaml index 7b707e7..21fc6ca 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/tgredial.yaml b/config/recommendation/gru4rec/tgredial.yaml index 7caf3d0..14fa628 100644 --- a/config/recommendation/gru4rec/tgredial.yaml +++ b/config/recommendation/gru4rec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index 3131e0a..f1b03c2 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index bf77cd6..768d369 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/inspired.yaml b/config/recommendation/popularity/inspired.yaml index 4c9a821..cea0dce 100644 --- a/config/recommendation/popularity/inspired.yaml +++ b/config/recommendation/popularity/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index ebaf2c9..c88d0c1 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index b0cbec9..2afc85e 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index 66c9ef7..c8e6230 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ b/config/recommendation/popularity/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index 15ba15e..bcf5e8b 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ 
b/config/recommendation/sasrec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 243a646..3ec5786 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/inspired.yaml b/config/recommendation/sasrec/inspired.yaml index d79ff24..51f5e6c 100644 --- a/config/recommendation/sasrec/inspired.yaml +++ b/config/recommendation/sasrec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index ba4c02d..42a8edf 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index add69ec..7df885a 100644 --- a/config/recommendation/sasrec/redial.yaml +++ b/config/recommendation/sasrec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/tgredial.yaml b/config/recommendation/sasrec/tgredial.yaml index 9888002..c8c3353 100644 --- a/config/recommendation/sasrec/tgredial.yaml +++ b/config/recommendation/sasrec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/textcnn/tgredial.yaml b/config/recommendation/textcnn/tgredial.yaml index 0d5c708..0de66df 100644 --- a/config/recommendation/textcnn/tgredial.yaml +++ b/config/recommendation/textcnn/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: - rec: sougou + rec: jieba # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 33bea19..7a4ad30 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -70,7 +70,7 @@ } -def get_dataset(opt, tokenize, restore, save) -> BaseDataset: +def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: """get and process dataset Args: @@ -85,7 +85,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return dataset_register_table[dataset](opt, tokenize, restore, save) + return dataset_register_table[dataset](opt, tokenize, restore, save, task) else: raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index ded2da6..d06727a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ 
b/crslab/data/dataset/durecdial/durecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" DuRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class DuRecDialDataset(BaseDataset): """ @@ -55,7 +62,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ Args: @@ -65,10 +72,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -94,14 +112,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -262,3 +301,111 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = 
[] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + cnt = 0 + tok2ind = {} + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + tokenizer = self.tokenize + crstokenize = self.crstokenizer + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + match_list = list(set(match_list)) + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 'word2vec.npy') + np.save(word2vec_path, 
word2embedding) diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 327ccf8..bd2348b 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -8,63 +8,58 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'jieba': { - 'version': '0.3', + 'resource':{ + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', - 'durecdial_jieba.zip', - 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', + 'durecdial.zip', + '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'jieba': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', - 'durecdial_bert.zip', - '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', - 'durecdial_gpt2.zip', - 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } + }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ce9d76..07b553d 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" GoRecDial ========= @@ -21,14 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.data.dataset.tokenize import CrsTokenize class 
GoRecDialDataset(BaseDataset): """ @@ -55,7 +62,7 @@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +72,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +113,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +305,129 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + 
tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'gorecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index b31e194..5ea42c1 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -8,61 +8,57 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': 
DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESM_Wc7sbAlOgZWo_6lOx34B6mboskdpNdB7FLuyXUET2A?download=1', - 'gorecdial_nltk.zip', - '7e523f7ca90bb32ee8f2471ac5736717c45b20822c63bd958d0546de0a9cd863', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EYmobnFBox1LnGKGW4TMCk8BW6rnjdAZNVsNo8uJ8ZsJLg?download=1', + 'gorecdial.zip', + '66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'nltk': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcTG05imCYpFiBarVfnsAfkBVsbq1iPw23CYcp9kYE9X4g?download=1', - 'gorecdial_bert.zip', - 'fc7aff18504f750d8974d90f2941a01ff22cc054283124936b778ba91f03554f' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Edg4_nbKA49HnQPcd65gPdoBALPADQd4V5qVqOrUub2m9w?download=1', - 'gorecdial_gpt2.zip', - '7234138dcc27ed00bdac95da4096cd435023c229d227fa494d2bd7a653a492a9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + }, + } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73930f1..190826d 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" Inspired ======== @@ -21,13 +26,16 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class InspiredDataset(BaseDataset): @@ -55,7 +63,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +73,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'inspired', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +114,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: @@ -268,3 +308,137 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 
'inspired', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for genre in dialog['genre']: + list_word = crstokenize.tokenize(genre, tokenizer) + match_list += list_word + + for people in dialog['people']: + list_word = crstokenize.tokenize(people, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index afb0cb1..c2d1e75 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', - 'inspired_nltk.zip', - '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', + 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXv8zwgCOY1EstHNjjs194cBqMIrdg4yxcyNsHKltTzyig?download=1', + 'inspired.zip', + '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', - 'inspired_bert.zip', - '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', - 'inspired_gpt2.zip', - '23bb4ce3299186630fdf673e17f43ee43e91573ea786c922e3527e4c341a313c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8582705..66fadb9 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" OpenDialKG ========== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class OpenDialKGDataset(BaseDataset): @@ -56,7 +64,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +74,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
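# A minimal sketch of how the flattened `resources` layout above is consumed by the
# dataset classes in this patch: there is now a single downloadable zip per dataset,
# and the per-tokenizer special-token ids are selected at runtime instead of being
# shipped as separate pre-tokenized archives. `dataset_resources` and `opt` are
# stand-in names for the objects used in the real __init__; this is a sketch, not
# the exact implementation.
def select_tokenizer_resource(dataset_resources, tokenize, opt):
    resource = dataset_resources['resource']                    # one file entry per dataset
    special_token_idx = resource[tokenize]['special_token_idx']
    unk_token_idx = special_token_idx['unk']
    use_copy = 'copy' in opt                                    # the yaml switch is presence-based
    return resource, special_token_idx, unk_token_idx, use_copy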
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -271,3 +311,130 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'opendialkg', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + for item in dialog['item']: + list_word = crstokenize.tokenize(item, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'opendialkg', 'word2vec.npy') + np.save(word2vec_path, word2embedding) + diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e00ddfc..e5682fe 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', - 'opendialkg_nltk.zip', - '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUknGWqDp15OoI2U7DE6EHkBoZVaK273DJfxCdXuluqQjA?download=1', + 'opendialkg.zip', + 
'73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', - 'opendialkg_bert.zip', - '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', - 'opendialkg_gpt2.zip', - 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..e75c52c 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" ReDial ====== @@ -22,13 +27,16 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class ReDialDataset(BaseDataset): @@ -56,7 +64,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +74,21 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
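# A condensed sketch of the preprocessing pipeline the loaders now run on the raw
# training split (split_token -> generate_tok2ind -> generate_word2vec -> optional
# generate_copy_mask). The key invariant is that the number of embedding rows
# reserved for special tokens must match the ids reserved in token2id.json:
# 4 for nltk/jieba/pkuseg, 1 ([PAD]) for bert, 0 for gpt2. The helper name, the
# plain-list return type, and the omission of the extra nltk `_split_` row are
# illustrative simplifications, not the exact code.
import gensim

def build_vocab_and_embeddings(token_lists, tokenize, dim=300):
    reserved = {'nltk': 4, 'jieba': 4, 'pkuseg': 4, 'bert': 1}.get(tokenize, 0)
    tok2ind = {}
    for sent in token_lists:
        for tok in sent:
            tok2ind.setdefault(tok, len(tok2ind) + reserved)
    w2v = gensim.models.word2vec.Word2Vec(token_lists, vector_size=dim, min_count=1)
    embeddings = [[0.0] * dim] * reserved + [w2v.wv[tok].tolist() for tok in tok2ind]
    return tok2ind, embeddings   # len(embeddings) == reserved + len(tok2ind)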
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, "redial", tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +115,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +306,128 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = 
os.path.join(DATASET_PATH, 'redial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index b347029..170dd3b 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.31', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', - 'redial_nltk.zip', - '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ea4PEMnyyqxAl6tiAC17BcgBW8fZ6eveNKAbAU5sYt8-PQ?download=1', + 'redial.zip', + 
'9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0 + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, }, - }, - 'bert': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', - 'redial_bert.zip', - 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', - 'redial_gpt2.zip', - '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', - ), - 'special_token_idx': { - 'pad': -100, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 0f37d97..92506f7 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -8,64 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'pkuseg': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', - 'tgredial_pkuseg.zip', - '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUmmYbQ6BytMrQjmgRWuElMBZ2yv7v10wLzuwxHe9wxnYg?download=1', + 'tgredial.zip', + '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'pkuseg':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', - 'tgredial_bert.zip', - 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, 
+ 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', - 'tgredial_gpt2.zip', - '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 90e03e3..c935029 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" TGReDial ======== @@ -23,12 +28,15 @@ from collections import defaultdict from copy import copy import numpy as np +import gensim + from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.data.dataset.tokenize import CrsTokenize class TGReDialDataset(BaseDataset): @@ -59,7 +67,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -69,11 +77,24 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
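# A minimal sketch of the copy-mask construction performed by generate_copy_mask in
# the dataset classes of this patch: the mask is a |vocab|-sized boolean vector that
# is True for every token id whose surface form also appears in the tokenized
# knowledge annotations (word/entity/movie-style mentions) of some training
# utterance; KGSF and NTRD later load it from <MODEL_PATH>/kgsf/<Dataset>/copy_mask.npy.
# `build_copy_mask`, `dialogs`, and `tokenize_fn` are illustrative names; only numpy
# is assumed.
import numpy as np

def build_copy_mask(tok2ind, dialogs, tokenize_fn, fields=('word', 'entity', 'movie')):
    mask = np.zeros(len(tok2ind), dtype=bool)
    for utterance in dialogs:
        matchable = set()
        for field in fields:                      # annotation keys differ per dataset
            for mention in utterance.get(field, []):
                matchable.update(tokenize_fn(mention))
        for tok in utterance['text']:             # text is already a token list here
            if tok in matchable and tok in tok2ind:
                mask[tok2ind[tok]] = True
    return mask                                   # persisted with np.save(path, mask)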
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.crstokenizer = CrsTokenize(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'tgredial') + self.replace_token = opt.get('replace_token',None) self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) @@ -111,14 +132,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -340,3 +382,132 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + each_dict['conv_id'] = each['conv_id'] + for one in each['messages']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text, tokenizer) + one['text'] = list_text + each_data.append(one) + each_dict['messages'] = each_data + each_dict['user_id'] = each['user_id'] + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = 
i['messages'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['messages']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word, tokenizer) + match_list += list_word + + for movie in dialog['movie']: + list_word = crstokenize.tokenize(movie, tokenizer) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity, tokenizer) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['messages']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/tokenize.py b/crslab/data/dataset/tokenize.py new file mode 100644 index 0000000..c63352f --- /dev/null +++ b/crslab/data/dataset/tokenize.py @@ -0,0 +1,42 @@ +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +from nltk import word_tokenize +from transformers import AutoTokenizer +import pkuseg +import nltk +import jieba + +class CrsTokenize: + + def __init__(self, path=None) -> None: + self.path = path + + if path is not None: + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + + def tokenize(self, text, tokenizer): + tokenize_fun = getattr(self, tokenizer + '_tokenize') + return tokenize_fun(text) + + def 
nltk_tokenize(self, text): + # nltk.download('punkt') + return word_tokenize(text) + + def bert_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def gpt2_tokenize(self, text): + return self.my_tokenizer.tokenize(text) + + def pkuseg_tokenize(self, text): + if not hasattr(self, 'pkuseg_tokenizer'): + self.pkuseg_tokenizer = pkuseg.pkuseg() + return self.pkuseg_tokenizer.cut(text) + + def jieba_tokenize(self, text): + split_text = jieba.cut(text) + text_list = ' '.join(split_text).split() + return text_list \ No newline at end of file diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b7c30fd..b682e42 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -8,11 +8,16 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { 'zh': { - 'version': '0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', @@ -20,7 +25,7 @@ ) }, 'en': { - 'version': '0.2', + 'version': '.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..f08d121 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import time from collections import defaultdict @@ -83,9 +88,10 @@ def gen_evaluate(self, hyp, refs): hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] - self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + if len(ref_embs[0]) > 0: + self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index c93badb..5e84a8c 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" GPT2 ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class GPT2Model(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(GPT2Model, self).__init__(opt, device, dpath, resource) + self.language = 
dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(GPT2Model, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 99e7ca9..4286f9a 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -2,15 +2,19 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -import os +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com +import os +import json import torch from transformers import GPT2LMHeadModel from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from .modules import SequenceCrossEntropyLoss @@ -39,10 +43,9 @@ def __init__(self, opt, device, vocab, side_data): self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(InspiredConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(InspiredConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model for seeker and recommender separately""" @@ -68,17 +71,27 @@ def converse(self, batch, mode): past = None lm_logits_all = [] + if self.language == 'zh': + config_json = os.path.join(GPT2_ZH_PATH, 'config.json') + elif self.language == 'en': + config_json = os.path.join(GPT2_EN_PATH, 'config.json') + + with open(config_json, 'r', encoding='utf-8') as f: + json_config = json.load(f) + + support_up_limits = json_config['n_ctx'] + if mode != 'test': for turn, iter in enumerate(input_ids_iters): if (roles[turn] == 0): # considering that gpt2 only supports up to 1024 tokens - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_sk(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) else: - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_rm(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 67948f5..2b2e94b 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class InspiredRecModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = 
os.path.join(PRETRAIN_PATH, "bert", language) - super(InspiredRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(InspiredRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 57590f4..1230ec8 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" KGSF ==== @@ -33,7 +38,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG -from .resources import resources class KGSFModel(BaseModel): @@ -116,10 +120,9 @@ def __init__(self, opt, device, vocab, side_data): self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] - super(KGSFModel, self).__init__(opt, device, dpath, resource) + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) + super(KGSFModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/kgsf/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - 
'23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 0f971b4..ef85782 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -3,6 +3,10 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com r""" NTRD @@ -29,7 +33,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection -from .resources import resources class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -87,12 +90,11 @@ def __init__(self, opt, device, vocab, side_data): # self.n_movies_label = opt['n_movies_label'] self.n_movies_label = 64362 # the number of entity2id # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) # loss weight self.gen_loss_weight = opt['gen_loss_weight'] - super(NTRDModel, self).__init__(opt, device, dpath, resource) + super(NTRDModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/ntrd/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index a726524..f202dcb 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +76,7 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True) + packed_utterances = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden states of the last layer (two directions of the GRU) @@ -104,7 +109,7 @@ def forward(self, context, utterance_lengths, dialog_lengths): # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True) + packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -144,7 +149,7 @@ def forward(self, request, request_lengths, context_state): sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) sorted_request = request.index_select(0, sorted_idx) embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) - packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True) + packed_request = pack_padded_sequence(embedded_request, sorted_lengths.cpu(), batch_first=True) sorted_context_state = context_state.index_select(0, sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 9e505d5..7a6a81c 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Conv ============= @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGConvModel(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language) - 
super(TGConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(TGConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 708b7f9..6986be5 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Policy =============== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGPolicyModel(BaseModel): @@ -44,10 +48,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGPolicyModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TGPolicyModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index a02ac5b..ad185e7 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Rec ============ @@ -28,7 +33,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from crslab.model.recommendation.sasrec.modules import SASRec @@ -68,10 +72,9 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(TGRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 76101cc..117d760 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Conv_BERT ========= @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from ...pretrained_models import resources class ConvBERTModel(BaseModel): @@ -48,10 +52,9 @@ def 
__init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ConvBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ConvBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index 65b400f..d7cbced 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Profile_BERT ============ @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class ProfileBERTModel(BaseModel): @@ -52,10 +56,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ProfileBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ProfileBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index 400eaeb..b20d11a 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Topic_BERT ========== @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TopicBERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - resource = resources['bert'][language] - super(TopicBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TopicBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py deleted file mode 100644 index 33c20d6..0000000 --- a/crslab/model/pretrained_models.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2021/1/6 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2021/1/7 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -"""Download links of pretrain 
models. - -Now we provide the following models: - -- `BERT`_: zh, en -- `GPT2`_: zh, en - -.. _BERT: - https://www.aclweb.org/anthology/N19-1423/ -.. _GPT2: - https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf - -""" - -resources = { - 'bert': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', - 'bert_zh.zip', - 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', - 'bert_en.zip', - '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' - ) - }, - }, - 'gpt2': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', - 'gpt2_zh.zip', - '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', - 'gpt2_en.zip', - '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' - ) - } - } -} diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index cb78a7b..a053eea 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class BERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(BERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(BERTModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..2199396 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -34,7 +34,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data) + CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data, task=None) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab @@ -53,7 +53,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: - dataset = get_dataset(config, tokenize, restore_data, save_data) + dataset = 
get_dataset(config, tokenize, restore_data, save_data, task) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = dataset.valid_data diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..bb839a5 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -154,6 +159,8 @@ def train_recommender(self): def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b..96251c5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -169,6 +174,8 @@ def train_recommender(self): if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + bert_param = list(self.rec_model.bert.named_parameters()) else: bert_param = list(self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] diff --git a/requirements.txt b/requirements.txt index f7fba73..05950a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 +gensim From 661824a9aa9a415905abe737f123602f356b5a5e Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 30 Sep 2022 12:53:49 +0800 Subject: [PATCH 05/35] txy77 --- config/conversation/gpt2/durecdial.yaml | 4 + config/conversation/gpt2/gorecdial.yaml | 4 + config/conversation/gpt2/inspired.yaml | 4 + config/conversation/gpt2/opendialkg.yaml | 4 + config/conversation/gpt2/redial.yaml | 4 + config/conversation/gpt2/tgredial.yaml | 4 + config/crs/inspired/durecdial.yaml | 7 + config/crs/inspired/gorecdial.yaml | 7 + config/crs/inspired/inspired.yaml | 7 + config/crs/inspired/opendialkg.yaml | 7 + config/crs/inspired/redial.yaml | 7 + config/crs/inspired/tgredial.yaml | 7 + config/crs/kgsf/durecdial.yaml | 1 + config/crs/kgsf/gorecdial.yaml | 1 + config/crs/kgsf/inspired.yaml | 1 + config/crs/kgsf/opendialkg.yaml | 1 + config/crs/kgsf/redial.yaml | 1 + config/crs/kgsf/tgredial.yaml | 1 + config/crs/ntrd/tgredial.yaml | 1 + config/crs/tgredial/durecdial.yaml | 6 + config/crs/tgredial/gorecdial.yaml | 6 + config/crs/tgredial/inspired.yaml | 6 + config/crs/tgredial/opendialkg.yaml | 6 + config/crs/tgredial/redial.yaml | 6 + config/crs/tgredial/tgredial.yaml | 8 + config/policy/conv_bert/tgredial.yaml | 4 + config/policy/mgcg/tgredial.yaml | 2 + config/policy/pmi/tgredial.yaml | 2 + config/policy/profile_bert/tgredial.yaml | 4 + config/policy/topic_bert/tgredial.yaml | 4 + config/recommendation/bert/durecdial.yaml | 4 + config/recommendation/bert/gorecdial.yaml | 4 + config/recommendation/bert/inspired.yaml | 4 + config/recommendation/bert/opendialkg.yaml | 4 + config/recommendation/bert/redial.yaml | 4 + config/recommendation/bert/tgredial.yaml | 4 + 
config/recommendation/gru4rec/durecdial.yaml | 2 + config/recommendation/gru4rec/gorecdial.yaml | 2 + config/recommendation/gru4rec/inspired.yaml | 2 + config/recommendation/gru4rec/opendialkg.yaml | 2 + config/recommendation/gru4rec/redial.yaml | 2 + config/recommendation/gru4rec/tgredial.yaml | 2 + .../recommendation/popularity/durecdial.yaml | 2 + .../recommendation/popularity/gorecdial.yaml | 2 + .../recommendation/popularity/inspired.yaml | 2 + .../recommendation/popularity/opendialkg.yaml | 2 + config/recommendation/popularity/redial.yaml | 2 + .../recommendation/popularity/tgredial.yaml | 2 + config/recommendation/sasrec/durecdial.yaml | 2 + config/recommendation/sasrec/gorecdial.yaml | 2 + config/recommendation/sasrec/inspired.yaml | 2 + config/recommendation/sasrec/opendialkg.yaml | 2 + config/recommendation/sasrec/redial.yaml | 2 + config/recommendation/sasrec/tgredial.yaml | 2 + config/recommendation/textcnn/tgredial.yaml | 2 +- crslab/data/__init__.py | 4 +- crslab/data/dataset/durecdial/durecdial.py | 166 ++++++++++++++- crslab/data/dataset/durecdial/resources.py | 99 +++++---- crslab/data/dataset/gorecdial/gorecdial.py | 184 ++++++++++++++++- crslab/data/dataset/gorecdial/resources.py | 96 +++++---- crslab/data/dataset/inspired/inspired.py | 191 +++++++++++++++++- crslab/data/dataset/inspired/resources.py | 89 ++++---- crslab/data/dataset/opendialkg/opendialkg.py | 184 ++++++++++++++++- crslab/data/dataset/opendialkg/resources.py | 89 ++++---- crslab/data/dataset/redial/redial.py | 182 ++++++++++++++++- crslab/data/dataset/redial/resources.py | 89 ++++---- crslab/data/dataset/tgredial/resources.py | 99 +++++---- crslab/data/dataset/tgredial/tgredial.py | 188 ++++++++++++++++- crslab/evaluator/embeddings.py | 9 +- crslab/evaluator/standard.py | 12 +- crslab/model/conversation/gpt2/gpt2.py | 13 +- crslab/model/crs/inspired/inspired_conv.py | 26 ++- crslab/model/crs/inspired/inspired_rec.py | 13 +- crslab/model/crs/kgsf/kgsf.py | 13 +- crslab/model/crs/kgsf/resources.py | 62 ------ crslab/model/crs/ntrd/ntrd.py | 12 +- crslab/model/crs/ntrd/resources.py | 62 ------ crslab/model/crs/redial/modules.py | 11 +- crslab/model/crs/tgredial/tg_conv.py | 13 +- crslab/model/crs/tgredial/tg_policy.py | 13 +- crslab/model/crs/tgredial/tg_rec.py | 13 +- crslab/model/policy/conv_bert/conv_bert.py | 13 +- .../model/policy/profile_bert/profile_bert.py | 13 +- crslab/model/policy/topic_bert/topic_bert.py | 13 +- crslab/model/pretrained_models.py | 64 ------ crslab/model/recommendation/bert/bert.py | 13 +- crslab/quick_start/quick_start.py | 4 +- crslab/system/kgsf.py | 7 + crslab/system/tgredial.py | 7 + crslab/tokenizer/base.py | 17 ++ crslab/tokenizer/bert.py | 16 ++ crslab/tokenizer/gpt2.py | 16 ++ crslab/tokenizer/jieba.py | 17 ++ crslab/tokenizer/nltk.py | 16 ++ crslab/tokenizer/pkuseg.py | 16 ++ requirements.txt | 1 + 96 files changed, 1757 insertions(+), 597 deletions(-) delete mode 100644 crslab/model/crs/kgsf/resources.py delete mode 100644 crslab/model/crs/ntrd/resources.py delete mode 100644 crslab/model/pretrained_models.py create mode 100644 crslab/tokenizer/base.py create mode 100644 crslab/tokenizer/bert.py create mode 100644 crslab/tokenizer/gpt2.py create mode 100644 crslab/tokenizer/jieba.py create mode 100644 crslab/tokenizer/nltk.py create mode 100644 crslab/tokenizer/pkuseg.py diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index 92a5329..05f568e 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ 
b/config/conversation/gpt2/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index ea155c4..abedfcb 100644 --- a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/inspired.yaml b/config/conversation/gpt2/inspired.yaml index b620579..69a2208 100644 --- a/config/conversation/gpt2/inspired.yaml +++ b/config/conversation/gpt2/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index d96e8d6..20e0020 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 3a89ac8..69756b3 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/tgredial.yaml b/config/conversation/gpt2/tgredial.yaml index 378d9af..1566760 100644 --- a/config/conversation/gpt2/tgredial.yaml +++ b/config/conversation/gpt2/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: conv: gpt2 +# tokenize path +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # optim conv: epoch: 50 diff --git a/config/crs/inspired/durecdial.yaml b/config/crs/inspired/durecdial.yaml index 6984c40..6068285 100644 --- a/config/crs/inspired/durecdial.yaml +++ b/config/crs/inspired/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 
'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/gorecdial.yaml b/config/crs/inspired/gorecdial.yaml index e44800b..77647e1 100644 --- a/config/crs/inspired/gorecdial.yaml +++ b/config/crs/inspired/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/inspired.yaml b/config/crs/inspired/inspired.yaml index a992737..3b22889 100644 --- a/config/crs/inspired/inspired.yaml +++ b/config/crs/inspired/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # optim rec: epoch: 1 diff --git a/config/crs/inspired/opendialkg.yaml b/config/crs/inspired/opendialkg.yaml index ff3c13a..8e4b879 100644 --- a/config/crs/inspired/opendialkg.yaml +++ b/config/crs/inspired/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/redial.yaml b/config/crs/inspired/redial.yaml index df25019..8e6d4ff 100644 --- a/config/crs/inspired/redial.yaml +++ b/config/crs/inspired/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/en' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/tgredial.yaml b/config/crs/inspired/tgredial.yaml index 892eb20..34684a1 100644 --- a/config/crs/inspired/tgredial.yaml +++ b/config/crs/inspired/tgredial.yaml @@ -3,6 +3,9 @@ dataset: TGReDial tokenize: rec: bert conv: gpt2 +#
tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +14,12 @@ scale: 1 # model # rec rec_model: InspiredRec +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # conv conv_model: InspiredConv +# pretrained path +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index b5e8eff..9ad0a9d 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index 0e4ba7e..ab00260 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index f087ca3..c3608e5 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index b9a2b06..09d47c3 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..5d11ca1 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 3 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index a120f98..33b2e1a 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -21,6 +21,7 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 44a1c77..49c7940 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -24,6 +24,7 @@ n_positions: 1024 gen_loss_weight: 5 n_movies: 62287 replace_token: '[ITEM]' +copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index 08a96aa..cfd5cf9 100644 --- a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -3,6 +3,9 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 
74382ff..67c6411 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -3,6 +3,9 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/inspired.yaml b/config/crs/tgredial/inspired.yaml index f4ace12..87edf15 100644 --- a/config/crs/tgredial/inspired.yaml +++ b/config/crs/tgredial/inspired.yaml @@ -3,6 +3,9 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 1 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index bcfb217..cba24ed 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -3,6 +3,9 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/redial.yaml b/config/crs/tgredial/redial.yaml index 8e983a0..31dd6ad 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -3,6 +3,9 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' +conv_tokenize_path: 'data/model/pretrain/gpt2/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,6 +14,9 @@ scale: 0.01 # model rec_model: TGRec conv_model: TGConv +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'data/model/pretrain/gpt2/en' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/tgredial.yaml b/config/crs/tgredial/tgredial.yaml index 0e1c956..ef5d57a 100644 --- a/config/crs/tgredial/tgredial.yaml +++ b/config/crs/tgredial/tgredial.yaml @@ -4,6 +4,10 @@ tokenize: rec: bert conv: gpt2 policy: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' +conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -13,6 +17,10 @@ scale: 1 rec_model: TGRec conv_model: TGConv policy_model: TGPolicy +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' +conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +policy_pretrained_path: 'data/model/pretrain/bert/zh' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git 
a/config/policy/conv_bert/tgredial.yaml b/config/policy/conv_bert/tgredial.yaml index 78e5c58..284aa86 100644 --- a/config/policy/conv_bert/tgredial.yaml +++ b/config/policy/conv_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ConvBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/policy/mgcg/tgredial.yaml b/config/policy/mgcg/tgredial.yaml index 7cd78ec..8726cec 100644 --- a/config/policy/mgcg/tgredial.yaml +++ b/config/policy/mgcg/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 87bb5e6..8e8b50b 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/profile_bert/tgredial.yaml b/config/policy/profile_bert/tgredial.yaml index 39f9ae8..08068a9 100644 --- a/config/policy/profile_bert/tgredial.yaml +++ b/config/policy/profile_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: ProfileBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' n_sent: 10 # optim policy: diff --git a/config/policy/topic_bert/tgredial.yaml b/config/policy/topic_bert/tgredial.yaml index c3a5253..aed3b69 100644 --- a/config/policy/topic_bert/tgredial.yaml +++ b/config/policy/topic_bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: policy: bert +# tokenize path +policy_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model policy_model: TopicBERT +# pretrained path +policy_pretrained_path: 'data/model/pretrain/bert/zh' # optim policy: epoch: 50 diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index 0d4250a..fcb981c 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 22ff335..864ed06 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 
'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/inspired.yaml b/config/recommendation/bert/inspired.yaml index d2d9d18..9a854fd 100644 --- a/config/recommendation/bert/inspired.yaml +++ b/config/recommendation/bert/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index 4b59696..fcc40f5 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index be5fa53..820d894 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 0.01 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/en' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/tgredial.yaml b/config/recommendation/bert/tgredial.yaml index 717a2ab..3ac3319 100644 --- a/config/recommendation/bert/tgredial.yaml +++ b/config/recommendation/bert/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 @@ -9,6 +11,8 @@ item_truncate: 100 scale: 1 # model rec_model: BERT +# pretrained path +rec_pretrained_path: 'data/model/pretrain/bert/zh' # optim rec: epoch: 20 diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index 94a5f6a..233f43f 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index 0d80c59..ca66dd7 100644 --- a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/inspired.yaml b/config/recommendation/gru4rec/inspired.yaml index 8ef81fe..5488b5e 100644 --- a/config/recommendation/gru4rec/inspired.yaml +++ b/config/recommendation/gru4rec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff 
--git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index b4900b9..809202b 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/redial.yaml b/config/recommendation/gru4rec/redial.yaml index 7b707e7..21fc6ca 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/tgredial.yaml b/config/recommendation/gru4rec/tgredial.yaml index 7caf3d0..14fa628 100644 --- a/config/recommendation/gru4rec/tgredial.yaml +++ b/config/recommendation/gru4rec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index 3131e0a..f1b03c2 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index bf77cd6..768d369 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/inspired.yaml b/config/recommendation/popularity/inspired.yaml index 4c9a821..cea0dce 100644 --- a/config/recommendation/popularity/inspired.yaml +++ b/config/recommendation/popularity/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index ebaf2c9..c88d0c1 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index b0cbec9..2afc85e 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index 66c9ef7..c8e6230 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ 
b/config/recommendation/popularity/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index 15ba15e..bcf5e8b 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ b/config/recommendation/sasrec/durecdial.yaml @@ -2,6 +2,8 @@ dataset: DuRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 243a646..3ec5786 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -2,6 +2,8 @@ dataset: GoRecDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/inspired.yaml b/config/recommendation/sasrec/inspired.yaml index d79ff24..51f5e6c 100644 --- a/config/recommendation/sasrec/inspired.yaml +++ b/config/recommendation/sasrec/inspired.yaml @@ -2,6 +2,8 @@ dataset: Inspired tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index ba4c02d..42a8edf 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -2,6 +2,8 @@ dataset: OpenDialKG tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index add69ec..7df885a 100644 --- a/config/recommendation/sasrec/redial.yaml +++ b/config/recommendation/sasrec/redial.yaml @@ -2,6 +2,8 @@ dataset: ReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/en' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/tgredial.yaml b/config/recommendation/sasrec/tgredial.yaml index 9888002..c8c3353 100644 --- a/config/recommendation/sasrec/tgredial.yaml +++ b/config/recommendation/sasrec/tgredial.yaml @@ -2,6 +2,8 @@ dataset: TGReDial tokenize: rec: bert +# tokenize path +rec_tokenize_path: 'data/model/pretrain/bert/zh' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/textcnn/tgredial.yaml b/config/recommendation/textcnn/tgredial.yaml index 0d5c708..0de66df 100644 --- a/config/recommendation/textcnn/tgredial.yaml +++ b/config/recommendation/textcnn/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: - rec: sougou + rec: jieba # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 33bea19..7a4ad30 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -70,7 +70,7 @@ } -def get_dataset(opt, tokenize, restore, save) -> BaseDataset: +def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: """get and process dataset Args: @@ -85,7 +85,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return 
dataset_register_table[dataset](opt, tokenize, restore, save) + return dataset_register_table[dataset](opt, tokenize, restore, save, task) else: raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index ded2da6..60ff520 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" DuRecDial ========= @@ -21,14 +26,20 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.tokenizer.nltk import nltk_tokenize +from crslab.tokenizer.bert import bert_tokenize +from crslab.tokenizer.gpt2 import gpt2_tokenize +from crslab.tokenizer.jieba import jieba_tokenize +from crslab.tokenizer.pkuseg import pkuseg_tokenize class DuRecDialDataset(BaseDataset): """ @@ -55,7 +66,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ Args: @@ -65,10 +76,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.tokenize_class = globals()[tokenize + '_tokenize'] + self.crstokenizer = self.tokenize_class(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -94,14 +117,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with 
open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -262,3 +306,111 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + cnt = 0 + tok2ind = {} + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + tokenizer = self.tokenize + crstokenize = self.crstokenizer + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word) + match_list += list_word + for item in dialog['item']: + list_word = crstokenize.tokenize(item) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity) + match_list += list_word + match_list = list(set(match_list)) + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] 
+ elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 327ccf8..bd2348b 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -8,63 +8,58 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'jieba': { - 'version': '0.3', + 'resource':{ + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', - 'durecdial_jieba.zip', - 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', + 'durecdial.zip', + '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'jieba': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', - 'durecdial_bert.zip', - '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', - 'durecdial_gpt2.zip', - 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } + }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ce9d76..1ee7bab 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -7,6 +7,11 @@ # 
@Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" GoRecDial ========= @@ -21,14 +26,20 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources - +from crslab.tokenizer.nltk import nltk_tokenize +from crslab.tokenizer.bert import bert_tokenize +from crslab.tokenizer.gpt2 import gpt2_tokenize +from crslab.tokenizer.jieba import jieba_tokenize +from crslab.tokenizer.pkuseg import pkuseg_tokenize class GoRecDialDataset(BaseDataset): """ @@ -55,7 +66,7 @@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +76,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'gorecdial', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.tokenize_class = globals()[tokenize + '_tokenize'] + self.crstokenizer = self.tokenize_class(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +118,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, 
processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +310,129 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'gorecdial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in 
word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index b31e194..5ea42c1 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -8,61 +8,57 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESM_Wc7sbAlOgZWo_6lOx34B6mboskdpNdB7FLuyXUET2A?download=1', - 'gorecdial_nltk.zip', - '7e523f7ca90bb32ee8f2471ac5736717c45b20822c63bd958d0546de0a9cd863', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EYmobnFBox1LnGKGW4TMCk8BW6rnjdAZNVsNo8uJ8ZsJLg?download=1', + 'gorecdial.zip', + '66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'nltk': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcTG05imCYpFiBarVfnsAfkBVsbq1iPw23CYcp9kYE9X4g?download=1', - 'gorecdial_bert.zip', - 'fc7aff18504f750d8974d90f2941a01ff22cc054283124936b778ba91f03554f' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Edg4_nbKA49HnQPcd65gPdoBALPADQd4V5qVqOrUub2m9w?download=1', - 'gorecdial_gpt2.zip', - '7234138dcc27ed00bdac95da4096cd435023c229d227fa494d2bd7a653a492a9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } }, - } + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + }, + } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73930f1..d3e5991 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" Inspired ======== @@ -21,13 +26,20 @@ import json import os from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH 
from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.tokenizer.nltk import nltk_tokenize +from crslab.tokenizer.bert import bert_tokenize +from crslab.tokenizer.gpt2 import gpt2_tokenize +from crslab.tokenizer.jieba import jieba_tokenize +from crslab.tokenizer.pkuseg import pkuseg_tokenize class InspiredDataset(BaseDataset): @@ -55,7 +67,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -65,10 +77,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'inspired', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.tokenize_class = globals()[tokenize + '_tokenize'] + self.crstokenizer = self.tokenize_class(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -95,14 +119,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: @@ -268,3 +313,137 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = 
crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'inspired', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity) + match_list += list_word + + for genre in dialog['genre']: + list_word = crstokenize.tokenize(genre) + match_list += list_word + + for people in dialog['people']: + list_word = crstokenize.tokenize(people) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') + 
np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index afb0cb1..c2d1e75 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', - 'inspired_nltk.zip', - '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXv8zwgCOY1EstHNjjs194cBqMIrdg4yxcyNsHKltTzyig?download=1', + 'inspired.zip', + '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', - 'inspired_bert.zip', - '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', - 'inspired_gpt2.zip', - '23bb4ce3299186630fdf673e17f43ee43e91573ea786c922e3527e4c341a313c' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8582705..b38a60e 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" OpenDialKG ========== @@ -22,13 +27,20 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.tokenizer.nltk import nltk_tokenize +from crslab.tokenizer.bert import bert_tokenize +from crslab.tokenizer.gpt2 import gpt2_tokenize +from crslab.tokenizer.jieba import jieba_tokenize +from 
crslab.tokenizer.pkuseg import pkuseg_tokenize class OpenDialKGDataset(BaseDataset): @@ -56,7 +68,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +78,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.tokenize_class = globals()[tokenize + '_tokenize'] + self.crstokenizer = self.tokenize_class(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +120,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -271,3 +316,130 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or 
self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'opendialkg', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity) + match_list += list_word + + for item in dialog['item']: + list_word = crstokenize.tokenize(item) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'opendialkg', 'word2vec.npy') + np.save(word2vec_path, word2embedding) + diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e00ddfc..e5682fe 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 
'nltk': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', - 'opendialkg_nltk.zip', - '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUknGWqDp15OoI2U7DE6EHkBoZVaK273DJfxCdXuluqQjA?download=1', + 'opendialkg.zip', + '73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', - 'opendialkg_bert.zip', - '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', - 'opendialkg_gpt2.zip', - 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..3a0d603 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" ReDial ====== @@ -22,13 +27,20 @@ import os from collections import defaultdict from copy import copy +import numpy as np +import gensim from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.tokenizer.nltk import nltk_tokenize +from crslab.tokenizer.bert import bert_tokenize +from crslab.tokenizer.gpt2 import gpt2_tokenize +from crslab.tokenizer.jieba import jieba_tokenize +from crslab.tokenizer.pkuseg import pkuseg_tokenize class ReDialDataset(BaseDataset): @@ -56,7 +68,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -66,10 +78,22 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
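# A condensed sketch (hypothetical helper, not part of this commit) of what the
# generate_copy_mask() methods added for each dataset compute (the OpenDialKG
# version appears above, the ReDial version below): a boolean vector over the
# vocabulary that is True for every token which also occurs in the tokenized
# entity/movie/word annotations of a training turn. The result is saved as
# MODEL_PATH/kgsf/<Dataset>/copy_mask.npy, matching the dpath used by the KGSF
# and NTRD models and the 'copy' option checked in the dataset __init__ methods.
# Annotation keys differ slightly per dataset, so they are treated as optional here.
import numpy as np

def build_copy_mask(tokenized_dialogs, tok2ind, tokenizer):
    copy_mask = np.zeros(len(tok2ind), dtype=bool)
    for conv in tokenized_dialogs:
        for turn in conv['dialog']:
            grounded = set()
            for key in ('word', 'entity', 'movies', 'movie', 'item', 'genre', 'people'):
                for mention in turn.get(key, []):
                    grounded.update(tokenizer.tokenize(mention))  # split like the text
            for token in turn['text']:
                if token in grounded and token in tok2ind:
                    copy_mask[tok2ind[token]] = True
    return copy_mask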
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - dpath = os.path.join(DATASET_PATH, "redial", tokenize) + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.tokenize_class = globals()[tokenize + '_tokenize'] + self.crstokenizer = self.tokenize_class(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) def _load_data(self): @@ -96,14 +120,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -266,3 +311,128 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 1 + + for i in tqdm(processed_train_data): + dialog = i['dialog'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + 
tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'redial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['dialog']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word) + match_list += list_word + for movie in dialog['movies']: + list_word = crstokenize.tokenize(movie) + match_list += list_word + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['dialog']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index b347029..170dd3b 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -8,59 +8,54 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'nltk': { - 'version': '0.31', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', - 'redial_nltk.zip', - '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ea4PEMnyyqxAl6tiAC17BcgBW8fZ6eveNKAbAU5sYt8-PQ?download=1', + 'redial.zip', + 
'9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0 + 'nltk':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, }, - }, - 'bert': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', - 'redial_bert.zip', - 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } }, - 'gpt2': { - 'version': '0.31', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', - 'redial_gpt2.zip', - '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', - ), - 'special_token_idx': { - 'pad': -100, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 0f37d97..92506f7 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -8,64 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { - 'pkuseg': { - 'version': '0.3', + 'resource': { + 'version': '1.0', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', - 'tgredial_pkuseg.zip', - '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUmmYbQ6BytMrQjmgRWuElMBZ2yv7v10wLzuwxHe9wxnYg?download=1', + 'tgredial.zip', + '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', ), - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'pkuseg':{ + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, - }, - 'bert': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', - 'tgredial_bert.zip', - 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 + 'bert': { + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'gpt2': { + 'special_token_idx': { + 'pad': 0, 
+ 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } }, - 'gpt2': { - 'version': '0.3', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', - 'tgredial_gpt2.zip', - '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' - ), - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 90e03e3..a231d7e 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, sdzyh002@gmail +# UPDATE +# @Time : 2022/9/26 +# @Author : Xinyu Tang +# @email : txy20010310@163.com + r""" TGReDial ======== @@ -23,12 +28,19 @@ from collections import defaultdict from copy import copy import numpy as np +import gensim + from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH +from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources +from crslab.tokenizer.nltk import nltk_tokenize +from crslab.tokenizer.bert import bert_tokenize +from crslab.tokenizer.gpt2 import gpt2_tokenize +from crslab.tokenizer.jieba import jieba_tokenize +from crslab.tokenizer.pkuseg import pkuseg_tokenize class TGReDialDataset(BaseDataset): @@ -59,7 +71,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False, task=None): """Specify tokenized resource and init base dataset. Args: @@ -69,11 +81,25 @@ def __init__(self, opt, tokenize, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
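# A minimal sketch (hypothetical helper, not part of this commit) of the tokenizer
# wiring used by the refactored __init__ below: the `tokenize` string is resolved
# to one of the imported *_tokenize classes via globals(), and an optional
# pretrained tokenizer path is read from opt under '<task>_tokenize_path'
# (e.g. 'rec_tokenize_path'); opt is assumed to behave like a dict here.
from crslab.tokenizer.nltk import nltk_tokenize
from crslab.tokenizer.bert import bert_tokenize
from crslab.tokenizer.gpt2 import gpt2_tokenize
from crslab.tokenizer.jieba import jieba_tokenize
from crslab.tokenizer.pkuseg import pkuseg_tokenize

def resolve_tokenizer(opt, tokenize, task=None):
    tokenize_path = opt.get(str(task) + '_tokenize_path')  # None when not configured
    tokenizer_cls = globals()[tokenize + '_tokenize']       # e.g. 'bert' -> bert_tokenize
    return tokenizer_cls(tokenize_path)

# usage sketch: tokens = resolve_tokenizer(opt, 'bert', task='rec').tokenize('an utterance')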
""" - resource = resources[tokenize] - self.special_token_idx = resource['special_token_idx'] + if 'copy' in opt: + self.copy = True + else: + self.copy = False + resource = resources['resource'] + token = resource[tokenize] + self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + + self.tokenize = tokenize + task_tokenize_path = str(task) + '_tokenize_path' + self.tokenize_path = None + if task_tokenize_path in opt: + self.tokenize_path = opt[task_tokenize_path] + self.tokenize_class = globals()[tokenize + '_tokenize'] + self.crstokenizer = self.tokenize_class(self.tokenize_path) + dpath = os.path.join(DATASET_PATH, 'tgredial') + self.replace_token = opt.get('replace_token',None) self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) @@ -111,14 +137,35 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + # split token + processing_train_data = self.split_token(train_data) + logger.info("[Finish train data split]") + # generate tok2ind + tok2ind = self.generate_tok2ind(processing_train_data) + logger.info("[Finish generate train tok2ind]") + # generate word2vec + self.generate_word2vec(processing_train_data) + logger.info('[Finish generate word2vec]') + # build copy_mask + if self.copy: + copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + logger.info('[Finish generate copy_mask]') + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + # split_token + processing_valid_data = self.split_token(valid_data) + logger.info("[Finish valid data split]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + # split_token + processing_test_data = self.split_token(test_data) + logger.info("[Finish test data split]") - return train_data, valid_data, test_data + return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) @@ -340,3 +387,132 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } + + def split_token(self, data): + + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + each_dict['conv_id'] = each['conv_id'] + for one in each['messages']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['messages'] = each_data + each_dict['user_id'] = each['user_id'] + all_data.append(each_dict) + + return all_data + + def generate_tok2ind(self, processed_train_data): + + cnt = 0 + tok2ind = {} + + if self.tokenize == 'nltk' or self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + tok2ind['__pad__'] = cnt + cnt += 1 + tok2ind['__start__'] = cnt + cnt += 1 + tok2ind['__end__'] = cnt + cnt += 1 + tok2ind['__unk__'] = cnt + cnt += 1 + elif self.tokenize == 'bert': + tok2ind['[PAD]'] = cnt + cnt += 
1 + + for i in tqdm(processed_train_data): + dialog = i['messages'] + for each_dialog in dialog: + text = each_dialog['text'] + for each_word in text: + if each_word not in tok2ind: + tok2ind[each_word] = cnt + cnt += 1 + + if self.tokenize == 'nltk': + tok2ind['_split_'] = cnt + cnt += 1 + + tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') + with open(tok2ind_path, 'w', encoding='utf-8') as write: + json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + + return tok2ind + + def generate_copy_mask(self, tok2ind, processing_train_data): + + tokenizer = self.tokenize + crstokenize = self.crstokenizer + + copy_mask = np.zeros((len(tok2ind)), dtype=bool) + for each_data in tqdm(processing_train_data): + for dialog in each_data['messages']: + match_list = [] + text = dialog['text'] + for word in dialog['word']: + list_word = crstokenize.tokenize(word) + match_list += list_word + + for movie in dialog['movie']: + list_word = crstokenize.tokenize(movie) + match_list += list_word + + for entity in dialog['entity']: + list_word = crstokenize.tokenize(entity) + match_list += list_word + + match_list = list(set(match_list)) + + for each_word in text: + if each_word in match_list: + token_id = tok2ind[each_word] + copy_mask[token_id] = True + + if not os.path.exists(MODEL_PATH): + os.mkdir(MODEL_PATH) + + if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): + os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) + + copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') + if not os.path.exists(copy_mask_dirpath): + os.mkdir(copy_mask_dirpath) + + path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') + np.save(path, copy_mask) + + + def generate_word2vec(self, processing_train_data): + + corpus = [] + for each_data in processing_train_data: + for dialog in each_data['messages']: + text = dialog['text'] + corpus.append(text) + + model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + + if self.tokenize == 'nltk': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + + elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': + word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + + elif self.tokenize == 'bert': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + + elif self.tokenize == 'gpt2': + word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} + word2embedding = [model.wv[word] for word in word2index] + + word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') + np.save(word2vec_path, word2embedding) diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b7c30fd..b682e42 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -8,11 +8,16 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + from crslab.download import DownloadableFile resources = { 'zh': { - 'version': '0.2', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', @@ -20,7 +25,7 @@ ) }, 'en': { - 'version': 
'0.2', + 'version': '.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..f08d121 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import time from collections import defaultdict @@ -83,9 +88,10 @@ def gen_evaluate(self, hyp, refs): hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] - self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + if len(ref_embs[0]) > 0: + self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index c93badb..5e84a8c 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" GPT2 ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class GPT2Model(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - super(GPT2Model, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(GPT2Model, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 99e7ca9..30fe87e 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -2,15 +2,19 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -import os +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com +import os +import json import torch from transformers import GPT2LMHeadModel from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from .modules import SequenceCrossEntropyLoss @@ -39,10 +43,9 @@ def __init__(self, opt, device, vocab, side_data): self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) - 
super(InspiredConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(InspiredConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model for seeker and recommender separately""" @@ -68,17 +71,24 @@ def converse(self, batch, mode): past = None lm_logits_all = [] + config_json = os.path.join(self.dpath, 'config.json') + + with open(config_json, 'r', encoding='utf-8') as f: + json_config = json.load(f) + + support_up_limits = json_config['n_ctx'] + if mode != 'test': for turn, iter in enumerate(input_ids_iters): if (roles[turn] == 0): # considering that gpt2 only supports up to 1024 tokens - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_sk(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) else: - if past is not None and past[0].shape[3] + iter.shape[1] > 1024: + if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: past = None outputs = self.model_rm(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 67948f5..2b2e94b 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class InspiredRecModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(InspiredRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(InspiredRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 57590f4..1230ec8 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" KGSF ==== @@ -33,7 +38,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG -from .resources import resources class KGSFModel(BaseModel): @@ -116,10 +120,9 @@ def __init__(self, opt, device, vocab, side_data): self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] - 
super(KGSFModel, self).__init__(opt, device, dpath, resource) + self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) + super(KGSFModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/kgsf/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 0f971b4..ef85782 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -3,6 +3,10 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com r""" NTRD @@ -29,7 +33,6 @@ from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection -from .resources import resources class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -87,12 +90,11 @@ def __init__(self, opt, device, vocab, side_data): # self.n_movies_label = opt['n_movies_label'] self.n_movies_label = 64362 # the number of entity2id # copy mask - dataset = opt['dataset'] - dpath = os.path.join(MODEL_PATH, "kgsf", dataset) - resource = resources[dataset] + 
self.dataset = opt['dataset'] + self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) # loss weight self.gen_loss_weight = opt['gen_loss_weight'] - super(NTRDModel, self).__init__(opt, device, dpath, resource) + super(NTRDModel, self).__init__(opt, device, self.dpath) def build_model(self): self._init_embeddings() diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py deleted file mode 100644 index d484a3f..0000000 --- a/crslab/model/crs/ntrd/resources.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/12/13 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2020/12/15 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -resources = { - 'ReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', - 'kgsf_redial.zip', - 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', - ), - }, - 'TGReDial': { - 'version': '0.2', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', - 'kgsf_tgredial.zip', - 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', - ), - }, - 'GoRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUfPcGfLHAJPj-F3Mr79CF4Bc5sZXKk-jysutrjiRcQvCg?download=1', - 'kgsf_gorecdial.zip', - '9794abf12b5d6773d867556685da14d951d42f64a5c4781af7d6fb720e87ec4f', - ) - }, - 'OpenDialKG': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', - 'kgsf_opendialkg.zip', - '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' - ) - }, - 'Inspired': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', - 'kgsf_inspired.zip', - '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' - ) - }, - 'DuRecDial': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', - 'kgsf_durecdial.zip', - 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' - ) - } -} diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index a726524..f202dcb 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +76,7 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True) + packed_utterances = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden states of the last layer (two directions of the GRU) @@ -104,7 +109,7 @@ def forward(self, context, utterance_lengths, 
dialog_lengths): # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True) + packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -144,7 +149,7 @@ def forward(self, request, request_lengths, context_state): sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) sorted_request = request.index_select(0, sorted_idx) embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) - packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True) + packed_request = pack_padded_sequence(embedded_request, sorted_lengths.cpu(), batch_first=True) sorted_context_state = context_state.index_select(0, sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 9e505d5..7a6a81c 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Conv ============= @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGConvModel(BaseModel): @@ -54,10 +58,9 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - language = dataset_language_map[opt['dataset']] - resource = resources['gpt2'][language] - dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language) - super(TGConvModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['conv_pretrained_path'] + super(TGConvModel, self).__init__(opt, device, self.dpath) def build_model(self): """build model""" diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 708b7f9..6986be5 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Policy =============== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TGPolicyModel(BaseModel): @@ -44,10 +48,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGPolicyModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TGPolicyModel, self).__init__(opt, device, self.dpath) def 
build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index a02ac5b..ad185e7 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" TGReDial_Rec ============ @@ -28,7 +33,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources from crslab.model.recommendation.sasrec.modules import SASRec @@ -68,10 +72,9 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(TGRecModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(TGRecModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 76101cc..117d760 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Conv_BERT ========= @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from ...pretrained_models import resources class ConvBERTModel(BaseModel): @@ -48,10 +52,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(ConvBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ConvBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index 65b400f..d7cbced 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Profile_BERT ============ @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class ProfileBERTModel(BaseModel): @@ -52,10 +56,9 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = 
os.path.join(PRETRAIN_PATH, "bert", language) - super(ProfileBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(ProfileBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index 400eaeb..b20d11a 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" Topic_BERT ========== @@ -26,7 +31,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class TopicBERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - language = dataset_language_map[opt['dataset']] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - resource = resources['bert'][language] - super(TopicBERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['policy_pretrained_path'] + super(TopicBERTModel, self).__init__(opt, device, self.dpath) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py deleted file mode 100644 index 33c20d6..0000000 --- a/crslab/model/pretrained_models.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2021/1/6 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -# UPDATE -# @Time : 2021/1/7 -# @Author : Xiaolei Wang -# @email : wxl1999@foxmail.com - -from crslab.download import DownloadableFile - -"""Download links of pretrain models. - -Now we provide the following models: - -- `BERT`_: zh, en -- `GPT2`_: zh, en - -.. _BERT: - https://www.aclweb.org/anthology/N19-1423/ -.. 
_GPT2: - https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf - -""" - -resources = { - 'bert': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', - 'bert_zh.zip', - 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', - 'bert_en.zip', - '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' - ) - }, - }, - 'gpt2': { - 'zh': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', - 'gpt2_zh.zip', - '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' - ) - }, - 'en': { - 'version': '0.1', - 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', - 'gpt2_en.zip', - '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' - ) - } - } -} diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index cb78a7b..a053eea 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + r""" BERT ==== @@ -27,7 +32,6 @@ from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map from crslab.model.base import BaseModel -from crslab.model.pretrained_models import resources class BERTModel(BaseModel): @@ -50,10 +54,9 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - language = dataset_language_map[opt['dataset']] - resource = resources['bert'][language] - dpath = os.path.join(PRETRAIN_PATH, "bert", language) - super(BERTModel, self).__init__(opt, device, dpath, resource) + self.language = dataset_language_map[opt['dataset']] + self.dpath = opt['rec_pretrained_path'] + super(BERTModel, self).__init__(opt, device, self.dpath) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..2199396 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -34,7 +34,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data) + CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data, task=None) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab @@ -53,7 +53,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: - dataset = get_dataset(config, tokenize, restore_data, save_data) + dataset = get_dataset(config, tokenize, restore_data, save_data, task) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = 
dataset.valid_data diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..bb839a5 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -7,6 +7,11 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -154,6 +159,8 @@ def train_recommender(self): def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b..96251c5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -7,6 +7,11 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com +# UPDATE: +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + import os import torch @@ -169,6 +174,8 @@ def train_recommender(self): if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + bert_param = list(self.rec_model.bert.named_parameters()) else: bert_param = list(self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] diff --git a/crslab/tokenizer/base.py b/crslab/tokenizer/base.py new file mode 100644 index 0000000..153a3e3 --- /dev/null +++ b/crslab/tokenizer/base.py @@ -0,0 +1,17 @@ +# @Time : 2022/9/30 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +from transformers import AutoTokenizer + +class BaseCrsTokenize: + + def __init__(self, path=None) -> None: + pass + + def tokenize(self, text): + ''' + split token + ''' + pass \ No newline at end of file diff --git a/crslab/tokenizer/bert.py b/crslab/tokenizer/bert.py new file mode 100644 index 0000000..eb77663 --- /dev/null +++ b/crslab/tokenizer/bert.py @@ -0,0 +1,16 @@ +# @Time : 2022/9/30 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +from transformers import AutoTokenizer + +from crslab.tokenizer.base import BaseCrsTokenize + +class bert_tokenize(BaseCrsTokenize): + + def __init__(self, path=None) -> None: + super().__init__(path) + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + + def tokenize(self, text): + return self.my_tokenizer.tokenize(text) \ No newline at end of file diff --git a/crslab/tokenizer/gpt2.py b/crslab/tokenizer/gpt2.py new file mode 100644 index 0000000..117309a --- /dev/null +++ b/crslab/tokenizer/gpt2.py @@ -0,0 +1,16 @@ +# @Time : 2022/9/28 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +from transformers import AutoTokenizer + +from crslab.tokenizer.base import BaseCrsTokenize + +class gpt2_tokenize(BaseCrsTokenize): + + def __init__(self, path=None) -> None: + super().__init__(path) + self.my_tokenizer = AutoTokenizer.from_pretrained(path) + + def tokenize(self, text): + return self.my_tokenizer.tokenize(text) \ No newline at end of file diff --git a/crslab/tokenizer/jieba.py b/crslab/tokenizer/jieba.py new file mode 100644 index 0000000..e931205 --- /dev/null +++ b/crslab/tokenizer/jieba.py @@ -0,0 +1,17 @@ +# @Time : 2022/9/30 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import jieba + +from crslab.tokenizer.base import BaseCrsTokenize + +class 
jieba_tokenize(BaseCrsTokenize): + + def __init__(self, path=None) -> None: + super().__init__(path) + + def tokenize(self, text): + split_text = jieba.cut(text) + text_list = ' '.join(split_text).split() + return text_list \ No newline at end of file diff --git a/crslab/tokenizer/nltk.py b/crslab/tokenizer/nltk.py new file mode 100644 index 0000000..231d73b --- /dev/null +++ b/crslab/tokenizer/nltk.py @@ -0,0 +1,16 @@ +# @Time : 2022/9/30 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +from nltk import word_tokenize + +from crslab.tokenizer.base import BaseCrsTokenize + +class nltk_tokenize(BaseCrsTokenize): + + def __init__(self, path=None) -> None: + super().__init__(path) + + def tokenize(self, text): + # nltk.download('punkt') + return word_tokenize(text) \ No newline at end of file diff --git a/crslab/tokenizer/pkuseg.py b/crslab/tokenizer/pkuseg.py new file mode 100644 index 0000000..99876fe --- /dev/null +++ b/crslab/tokenizer/pkuseg.py @@ -0,0 +1,16 @@ +# @Time : 2022/9/30 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import pkuseg + +from crslab.tokenizer.base import BaseCrsTokenize + +class pkuseg_tokenize(BaseCrsTokenize): + + def __init__(self, path=None) -> None: + self.pkuseg_tokenizer = pkuseg.pkuseg() + super().__init__(path) + + def tokenize(self, text): + return self.pkuseg_tokenizer.cut(text) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f7fba73..05950a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 +gensim From 3226e5ea44c2b7bf8b52ac599c3c7d3c4c4aa58f Mon Sep 17 00:00:00 2001 From: txy77 Date: Thu, 6 Oct 2022 19:11:41 +0800 Subject: [PATCH 06/35] txy77 --- config/conversation/gpt2/durecdial.yaml | 4 +- config/conversation/gpt2/gorecdial.yaml | 4 +- config/conversation/gpt2/inspired.yaml | 4 +- config/conversation/gpt2/opendialkg.yaml | 4 +- config/conversation/gpt2/redial.yaml | 4 +- config/conversation/gpt2/tgredial.yaml | 4 +- config/crs/inspired/durecdial.yaml | 8 ++-- config/crs/inspired/gorecdial.yaml | 8 ++-- config/crs/inspired/inspired.yaml | 8 ++-- config/crs/inspired/opendialkg.yaml | 8 ++-- config/crs/inspired/redial.yaml | 8 ++-- config/crs/inspired/tgredial.yaml | 8 ++-- config/crs/tgredial/durecdial.yaml | 8 ++-- config/crs/tgredial/gorecdial.yaml | 8 ++-- config/crs/tgredial/inspired.yaml | 8 ++-- config/crs/tgredial/opendialkg.yaml | 8 ++-- config/crs/tgredial/redial.yaml | 8 ++-- config/crs/tgredial/tgredial.yaml | 12 +++--- config/policy/conv_bert/tgredial.yaml | 4 +- config/policy/mgcg/tgredial.yaml | 2 +- config/policy/pmi/tgredial.yaml | 2 +- config/policy/profile_bert/tgredial.yaml | 4 +- config/policy/topic_bert/tgredial.yaml | 4 +- config/recommendation/bert/durecdial.yaml | 4 +- config/recommendation/bert/gorecdial.yaml | 4 +- config/recommendation/bert/inspired.yaml | 4 +- config/recommendation/bert/opendialkg.yaml | 4 +- config/recommendation/bert/redial.yaml | 4 +- config/recommendation/bert/tgredial.yaml | 4 +- config/recommendation/gru4rec/durecdial.yaml | 2 +- config/recommendation/gru4rec/gorecdial.yaml | 2 +- config/recommendation/gru4rec/inspired.yaml | 2 +- config/recommendation/gru4rec/opendialkg.yaml | 2 +- config/recommendation/gru4rec/redial.yaml | 2 +- config/recommendation/gru4rec/tgredial.yaml | 2 +- .../recommendation/popularity/durecdial.yaml | 2 +- .../recommendation/popularity/gorecdial.yaml | 2 +- .../recommendation/popularity/inspired.yaml | 2 +- 
.../recommendation/popularity/opendialkg.yaml | 2 +- config/recommendation/popularity/redial.yaml | 2 +- .../recommendation/popularity/tgredial.yaml | 2 +- config/recommendation/sasrec/durecdial.yaml | 2 +- config/recommendation/sasrec/gorecdial.yaml | 2 +- config/recommendation/sasrec/inspired.yaml | 2 +- config/recommendation/sasrec/opendialkg.yaml | 2 +- config/recommendation/sasrec/redial.yaml | 2 +- config/recommendation/sasrec/tgredial.yaml | 2 +- crslab/data/dataset/durecdial/durecdial.py | 10 ++--- crslab/data/dataset/gorecdial/gorecdial.py | 10 ++--- crslab/data/dataset/inspired/inspired.py | 10 ++--- crslab/data/dataset/opendialkg/opendialkg.py | 10 ++--- crslab/data/dataset/redial/redial.py | 10 ++--- crslab/data/dataset/tgredial/tgredial.py | 10 ++--- crslab/data/dataset/tokenize.py | 42 ------------------- crslab/{ => data/dataset}/tokenizer/base.py | 0 crslab/{ => data/dataset}/tokenizer/bert.py | 2 +- crslab/{ => data/dataset}/tokenizer/gpt2.py | 2 +- crslab/{ => data/dataset}/tokenizer/jieba.py | 2 +- crslab/{ => data/dataset}/tokenizer/nltk.py | 2 +- crslab/{ => data/dataset}/tokenizer/pkuseg.py | 2 +- crslab/model/conversation/gpt2/gpt2.py | 1 - crslab/model/crs/inspired/inspired_conv.py | 10 ++--- crslab/model/crs/inspired/inspired_rec.py | 1 - crslab/model/crs/tgredial/tg_conv.py | 1 - crslab/model/crs/tgredial/tg_policy.py | 1 - crslab/model/crs/tgredial/tg_rec.py | 1 - crslab/model/policy/conv_bert/conv_bert.py | 1 - .../model/policy/profile_bert/profile_bert.py | 1 - crslab/model/policy/topic_bert/topic_bert.py | 1 - crslab/model/recommendation/bert/bert.py | 1 - crslab/system/inspired.py | 1 - crslab/system/redial.py | 1 - 72 files changed, 138 insertions(+), 195 deletions(-) delete mode 100644 crslab/data/dataset/tokenize.py rename crslab/{ => data/dataset}/tokenizer/base.py (100%) rename crslab/{ => data/dataset}/tokenizer/bert.py (85%) rename crslab/{ => data/dataset}/tokenizer/gpt2.py (85%) rename crslab/{ => data/dataset}/tokenizer/jieba.py (84%) rename crslab/{ => data/dataset}/tokenizer/nltk.py (83%) rename crslab/{ => data/dataset}/tokenizer/pkuseg.py (84%) diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index 05f568e..c3aa503 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ b/config/conversation/gpt2/durecdial.yaml @@ -3,7 +3,7 @@ dataset: DuRecDial tokenize: conv: gpt2 # tokenize path -conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model conv_model: GPT2 # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +conv_pretrained_path: 'GPT2-chitchat' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index abedfcb..4623943 100644 --- a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -3,7 +3,7 @@ dataset: GoRecDial tokenize: conv: gpt2 # tokenize path -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model conv_model: GPT2 # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/inspired.yaml b/config/conversation/gpt2/inspired.yaml index 69a2208..f150198 100644 --- a/config/conversation/gpt2/inspired.yaml +++ 
b/config/conversation/gpt2/inspired.yaml @@ -3,7 +3,7 @@ dataset: Inspired tokenize: conv: gpt2 # tokenize path -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model conv_model: GPT2 # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index 20e0020..091fa9f 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -3,7 +3,7 @@ dataset: OpenDialKG tokenize: conv: gpt2 # tokenize path -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model conv_model: GPT2 # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 69756b3..07b4e2b 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -3,7 +3,7 @@ dataset: ReDial tokenize: conv: gpt2 # tokenize path -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model conv_model: GPT2 # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/tgredial.yaml b/config/conversation/gpt2/tgredial.yaml index 1566760..f747e2e 100644 --- a/config/conversation/gpt2/tgredial.yaml +++ b/config/conversation/gpt2/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: conv: gpt2 # tokenize path -conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model conv_model: GPT2 # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +conv_pretrained_path: 'GPT2-chitchat' # optim conv: epoch: 50 diff --git a/config/crs/inspired/durecdial.yaml b/config/crs/inspired/durecdial.yaml index 6068285..cc8fa80 100644 --- a/config/crs/inspired/durecdial.yaml +++ b/config/crs/inspired/durecdial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' -conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +rec_tokenize_path: 'bert-base-chinese' +conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,11 +15,11 @@ scale: 1 # rec rec_model: InspiredRec # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/zh' +rec_pretrained_path: 'bert-base-chinese' # conv conv_model: InspiredConv # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +conv_pretrained_path: 'GPT2-chitchat' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/gorecdial.yaml b/config/crs/inspired/gorecdial.yaml index 77647e1..250edc2 100644 --- a/config/crs/inspired/gorecdial.yaml +++ b/config/crs/inspired/gorecdial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader 
context_truncate: 256 response_truncate: 30 @@ -15,11 +15,11 @@ scale: 1 # rec rec_model: InspiredRec # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' +rec_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/inspired.yaml b/config/crs/inspired/inspired.yaml index 3b22889..0c1887a 100644 --- a/config/crs/inspired/inspired.yaml +++ b/config/crs/inspired/inspired.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,11 +15,11 @@ scale: 1 # rec rec_model: InspiredRec # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' +rec_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # optim rec: epoch: 1 diff --git a/config/crs/inspired/opendialkg.yaml b/config/crs/inspired/opendialkg.yaml index 8e4b879..eb440d9 100644 --- a/config/crs/inspired/opendialkg.yaml +++ b/config/crs/inspired/opendialkg.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,11 +15,11 @@ scale: 1 # rec rec_model: InspiredRec # pretrained path -conv_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/redial.yaml b/config/crs/inspired/redial.yaml index 8e6d4ff..48c3112 100644 --- a/config/crs/inspired/redial.yaml +++ b/config/crs/inspired/redial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,11 +15,11 @@ scale: 1 # rec rec_model: InspiredRec # pretrained path -conv_pretrained_path: 'data/model/pretrain/bert/en' +conv_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +conv_pretrained_path: 'gpt2' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/tgredial.yaml b/config/crs/inspired/tgredial.yaml index 34684a1..cfca337 100644 --- a/config/crs/inspired/tgredial.yaml +++ b/config/crs/inspired/tgredial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' -conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +rec_tokenize_path: 'bert-base-chinese' +conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,11 +15,11 @@ scale: 1 # rec rec_model: InspiredRec # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/zh' +rec_pretrained_path: 
'bert-base-chinese' # conv conv_model: InspiredConv # pretrained path -conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +conv_pretrained_path: 'GPT2-chitchat' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index cfd5cf9..8085430 100644 --- a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' -conv_tokenize_path: 'data/model/pretrain/gpt2/zh' +rec_tokenize_path: 'bert-base-chinese' +conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,8 +15,8 @@ scale: 0.01 rec_model: TGRec conv_model: TGConv # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/zh' -conv_pretrained_path: 'data/model/pretrain/gpt2/zh' +rec_pretrained_path: 'bert-base-chinese' +conv_pretrained_path: 'GPT2-chitchat' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 67c6411..8e1e982 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,8 +15,8 @@ scale: 0.01 rec_model: TGRec conv_model: TGConv # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +rec_pretrained_path: 'bert-base-uncased' +conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/inspired.yaml b/config/crs/tgredial/inspired.yaml index 87edf15..c7a4a12 100644 --- a/config/crs/tgredial/inspired.yaml +++ b/config/crs/tgredial/inspired.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,8 +15,8 @@ scale: 1 rec_model: TGRec conv_model: TGConv # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +rec_pretrained_path: 'bert-base-uncased' +conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index cba24ed..15ee74f 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,8 +15,8 @@ scale: 0.01 rec_model: TGRec conv_model: TGConv # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +rec_pretrained_path: 'bert-base-uncased' +conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/redial.yaml 
b/config/crs/tgredial/redial.yaml index 31dd6ad..ae5f037 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -4,8 +4,8 @@ tokenize: rec: bert conv: gpt2 # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' -conv_tokenize_path: 'data/model/pretrain/gpt2/en' +rec_tokenize_path: 'bert-base-uncased' +conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -15,8 +15,8 @@ scale: 0.01 rec_model: TGRec conv_model: TGConv # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' -conv_pretrained_path: 'data/model/pretrain/gpt2/en' +rec_pretrained_path: 'bert-base-uncased' +conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/tgredial.yaml b/config/crs/tgredial/tgredial.yaml index ef5d57a..3a8081b 100644 --- a/config/crs/tgredial/tgredial.yaml +++ b/config/crs/tgredial/tgredial.yaml @@ -5,9 +5,9 @@ tokenize: conv: gpt2 policy: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' -conv_tokenize_path: 'data/model/pretrain/gpt2/zh' -policy_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' +conv_tokenize_path: 'GPT2-chitchat' +policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -18,9 +18,9 @@ rec_model: TGRec conv_model: TGConv policy_model: TGPolicy # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/zh' -conv_pretrained_path: 'data/model/pretrain/gpt2/zh' -policy_pretrained_path: 'data/model/pretrain/bert/zh' +rec_pretrained_path: 'bert-base-chinese' +conv_pretrained_path: 'GPT2-chitchat' +policy_pretrained_path: 'bert-base-chinese' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/policy/conv_bert/tgredial.yaml b/config/policy/conv_bert/tgredial.yaml index 284aa86..393c0b6 100644 --- a/config/policy/conv_bert/tgredial.yaml +++ b/config/policy/conv_bert/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: policy: bert # tokenize path -policy_tokenize_path: 'data/model/pretrain/bert/zh' +policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model policy_model: ConvBERT # pretrained path -policy_pretrained_path: 'data/model/pretrain/bert/zh' +policy_pretrained_path: 'bert-base-chinese' # optim policy: epoch: 50 diff --git a/config/policy/mgcg/tgredial.yaml b/config/policy/mgcg/tgredial.yaml index 8726cec..5bb42f0 100644 --- a/config/policy/mgcg/tgredial.yaml +++ b/config/policy/mgcg/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: policy: bert # tokenize path -policy_tokenize_path: 'data/model/pretrain/bert/zh' +policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 8e8b50b..7658e86 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: policy: bert # tokenize path -policy_tokenize_path: 'data/model/pretrain/bert/zh' +policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/profile_bert/tgredial.yaml b/config/policy/profile_bert/tgredial.yaml index 08068a9..9a35942 100644 --- a/config/policy/profile_bert/tgredial.yaml +++ b/config/policy/profile_bert/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: policy: bert # tokenize path 
-policy_tokenize_path: 'data/model/pretrain/bert/zh' +policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model policy_model: ProfileBERT # pretrained path -policy_pretrained_path: 'data/model/pretrain/bert/zh' +policy_pretrained_path: 'bert-base-chinese' n_sent: 10 # optim policy: diff --git a/config/policy/topic_bert/tgredial.yaml b/config/policy/topic_bert/tgredial.yaml index aed3b69..6884468 100644 --- a/config/policy/topic_bert/tgredial.yaml +++ b/config/policy/topic_bert/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: policy: bert # tokenize path -policy_tokenize_path: 'data/model/pretrain/bert/zh' +policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model policy_model: TopicBERT # pretrained path -policy_pretrained_path: 'data/model/pretrain/bert/zh' +policy_pretrained_path: 'bert-base-chinese' # optim policy: epoch: 50 diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index fcb981c..b0edb13 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -3,7 +3,7 @@ dataset: DuRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model rec_model: BERT # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/zh' +rec_pretrained_path: 'bert-base-chinese' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 864ed06..727f15e 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -3,7 +3,7 @@ dataset: GoRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model rec_model: BERT # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' +rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/inspired.yaml b/config/recommendation/bert/inspired.yaml index 9a854fd..b492f0d 100644 --- a/config/recommendation/bert/inspired.yaml +++ b/config/recommendation/bert/inspired.yaml @@ -3,7 +3,7 @@ dataset: Inspired tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model rec_model: BERT # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' +rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index fcc40f5..7972304 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -3,7 +3,7 @@ dataset: OpenDialKG tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model rec_model: BERT # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' +rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git 
a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index 820d894..696bd96 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -3,7 +3,7 @@ dataset: ReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 0.01 # model rec_model: BERT # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/en' +rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/tgredial.yaml b/config/recommendation/bert/tgredial.yaml index 3ac3319..679667d 100644 --- a/config/recommendation/bert/tgredial.yaml +++ b/config/recommendation/bert/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -12,7 +12,7 @@ scale: 1 # model rec_model: BERT # pretrained path -rec_pretrained_path: 'data/model/pretrain/bert/zh' +rec_pretrained_path: 'bert-base-chinese' # optim rec: epoch: 20 diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index 233f43f..ca29808 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -3,7 +3,7 @@ dataset: DuRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index ca66dd7..814ebd6 100644 --- a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -3,7 +3,7 @@ dataset: GoRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/inspired.yaml b/config/recommendation/gru4rec/inspired.yaml index 5488b5e..f2508db 100644 --- a/config/recommendation/gru4rec/inspired.yaml +++ b/config/recommendation/gru4rec/inspired.yaml @@ -3,7 +3,7 @@ dataset: Inspired tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index 809202b..8cd8865 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -3,7 +3,7 @@ dataset: OpenDialKG tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/redial.yaml b/config/recommendation/gru4rec/redial.yaml index 21fc6ca..1af0f3f 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -3,7 +3,7 @@ dataset: ReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/tgredial.yaml 
b/config/recommendation/gru4rec/tgredial.yaml index 14fa628..8bac5d8 100644 --- a/config/recommendation/gru4rec/tgredial.yaml +++ b/config/recommendation/gru4rec/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index f1b03c2..cb05935 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -3,7 +3,7 @@ dataset: DuRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index 768d369..187d552 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -3,7 +3,7 @@ dataset: GoRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/inspired.yaml b/config/recommendation/popularity/inspired.yaml index cea0dce..20714e8 100644 --- a/config/recommendation/popularity/inspired.yaml +++ b/config/recommendation/popularity/inspired.yaml @@ -3,7 +3,7 @@ dataset: Inspired tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index c88d0c1..cbc319f 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -3,7 +3,7 @@ dataset: OpenDialKG tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index 2afc85e..2265d2c 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -3,7 +3,7 @@ dataset: ReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index c8e6230..973d1e4 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ b/config/recommendation/popularity/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index bcf5e8b..ac1fcfe 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ b/config/recommendation/sasrec/durecdial.yaml @@ -3,7 +3,7 @@ dataset: DuRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 
'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 3ec5786..1c88d58 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -3,7 +3,7 @@ dataset: GoRecDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/inspired.yaml b/config/recommendation/sasrec/inspired.yaml index 51f5e6c..7c1ab9f 100644 --- a/config/recommendation/sasrec/inspired.yaml +++ b/config/recommendation/sasrec/inspired.yaml @@ -3,7 +3,7 @@ dataset: Inspired tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index 42a8edf..c380597 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -3,7 +3,7 @@ dataset: OpenDialKG tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index 7df885a..1809617 100644 --- a/config/recommendation/sasrec/redial.yaml +++ b/config/recommendation/sasrec/redial.yaml @@ -3,7 +3,7 @@ dataset: ReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/en' +rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/tgredial.yaml b/config/recommendation/sasrec/tgredial.yaml index c8c3353..0751b7c 100644 --- a/config/recommendation/sasrec/tgredial.yaml +++ b/config/recommendation/sasrec/tgredial.yaml @@ -3,7 +3,7 @@ dataset: TGReDial tokenize: rec: bert # tokenize path -rec_tokenize_path: 'data/model/pretrain/bert/zh' +rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 60ff520..aaeee1d 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -35,11 +35,11 @@ from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources -from crslab.tokenizer.nltk import nltk_tokenize -from crslab.tokenizer.bert import bert_tokenize -from crslab.tokenizer.gpt2 import gpt2_tokenize -from crslab.tokenizer.jieba import jieba_tokenize -from crslab.tokenizer.pkuseg import pkuseg_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize +from crslab.data.dataset.tokenizer.bert import bert_tokenize +from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize +from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize class DuRecDialDataset(BaseDataset): """ diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1ee7bab..e337b10 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -35,11 +35,11 @@ 
from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources -from crslab.tokenizer.nltk import nltk_tokenize -from crslab.tokenizer.bert import bert_tokenize -from crslab.tokenizer.gpt2 import gpt2_tokenize -from crslab.tokenizer.jieba import jieba_tokenize -from crslab.tokenizer.pkuseg import pkuseg_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize +from crslab.data.dataset.tokenizer.bert import bert_tokenize +from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize +from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize class GoRecDialDataset(BaseDataset): """ diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index d3e5991..e7adb19 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -35,11 +35,11 @@ from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources -from crslab.tokenizer.nltk import nltk_tokenize -from crslab.tokenizer.bert import bert_tokenize -from crslab.tokenizer.gpt2 import gpt2_tokenize -from crslab.tokenizer.jieba import jieba_tokenize -from crslab.tokenizer.pkuseg import pkuseg_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize +from crslab.data.dataset.tokenizer.bert import bert_tokenize +from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize +from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize class InspiredDataset(BaseDataset): diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index b38a60e..f0b4e55 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -36,11 +36,11 @@ from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources -from crslab.tokenizer.nltk import nltk_tokenize -from crslab.tokenizer.bert import bert_tokenize -from crslab.tokenizer.gpt2 import gpt2_tokenize -from crslab.tokenizer.jieba import jieba_tokenize -from crslab.tokenizer.pkuseg import pkuseg_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize +from crslab.data.dataset.tokenizer.bert import bert_tokenize +from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize +from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize class OpenDialKGDataset(BaseDataset): diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 3a0d603..861215f 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -36,11 +36,11 @@ from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources -from crslab.tokenizer.nltk import nltk_tokenize -from crslab.tokenizer.bert import bert_tokenize -from crslab.tokenizer.gpt2 import gpt2_tokenize -from crslab.tokenizer.jieba import jieba_tokenize -from crslab.tokenizer.pkuseg import pkuseg_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize +from crslab.data.dataset.tokenizer.bert import bert_tokenize +from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize +from crslab.data.dataset.tokenizer.jieba 
import jieba_tokenize +from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize class ReDialDataset(BaseDataset): diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index a231d7e..83a850e 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -36,11 +36,11 @@ from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset from .resources import resources -from crslab.tokenizer.nltk import nltk_tokenize -from crslab.tokenizer.bert import bert_tokenize -from crslab.tokenizer.gpt2 import gpt2_tokenize -from crslab.tokenizer.jieba import jieba_tokenize -from crslab.tokenizer.pkuseg import pkuseg_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize +from crslab.data.dataset.tokenizer.bert import bert_tokenize +from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize +from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize class TGReDialDataset(BaseDataset): diff --git a/crslab/data/dataset/tokenize.py b/crslab/data/dataset/tokenize.py deleted file mode 100644 index c63352f..0000000 --- a/crslab/data/dataset/tokenize.py +++ /dev/null @@ -1,42 +0,0 @@ -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -import os -from nltk import word_tokenize -from transformers import AutoTokenizer -import pkuseg -import nltk -import jieba - -class CrsTokenize: - - def __init__(self, path=None) -> None: - self.path = path - - if path is not None: - self.my_tokenizer = AutoTokenizer.from_pretrained(path) - - def tokenize(self, text, tokenizer): - tokenize_fun = getattr(self, tokenizer + '_tokenize') - return tokenize_fun(text) - - def nltk_tokenize(self, text): - # nltk.download('punkt') - return word_tokenize(text) - - def bert_tokenize(self, text): - return self.my_tokenizer.tokenize(text) - - def gpt2_tokenize(self, text): - return self.my_tokenizer.tokenize(text) - - def pkuseg_tokenize(self, text): - if not hasattr(self, 'pkuseg_tokenizer'): - self.pkuseg_tokenizer = pkuseg.pkuseg() - return self.pkuseg_tokenizer.cut(text) - - def jieba_tokenize(self, text): - split_text = jieba.cut(text) - text_list = ' '.join(split_text).split() - return text_list \ No newline at end of file diff --git a/crslab/tokenizer/base.py b/crslab/data/dataset/tokenizer/base.py similarity index 100% rename from crslab/tokenizer/base.py rename to crslab/data/dataset/tokenizer/base.py diff --git a/crslab/tokenizer/bert.py b/crslab/data/dataset/tokenizer/bert.py similarity index 85% rename from crslab/tokenizer/bert.py rename to crslab/data/dataset/tokenizer/bert.py index eb77663..0ee5463 100644 --- a/crslab/tokenizer/bert.py +++ b/crslab/data/dataset/tokenizer/bert.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer -from crslab.tokenizer.base import BaseCrsTokenize +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class bert_tokenize(BaseCrsTokenize): diff --git a/crslab/tokenizer/gpt2.py b/crslab/data/dataset/tokenizer/gpt2.py similarity index 85% rename from crslab/tokenizer/gpt2.py rename to crslab/data/dataset/tokenizer/gpt2.py index 117309a..b52747d 100644 --- a/crslab/tokenizer/gpt2.py +++ b/crslab/data/dataset/tokenizer/gpt2.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer -from crslab.tokenizer.base import BaseCrsTokenize +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class gpt2_tokenize(BaseCrsTokenize): diff --git 
a/crslab/tokenizer/jieba.py b/crslab/data/dataset/tokenizer/jieba.py similarity index 84% rename from crslab/tokenizer/jieba.py rename to crslab/data/dataset/tokenizer/jieba.py index e931205..a834995 100644 --- a/crslab/tokenizer/jieba.py +++ b/crslab/data/dataset/tokenizer/jieba.py @@ -4,7 +4,7 @@ import jieba -from crslab.tokenizer.base import BaseCrsTokenize +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class jieba_tokenize(BaseCrsTokenize): diff --git a/crslab/tokenizer/nltk.py b/crslab/data/dataset/tokenizer/nltk.py similarity index 83% rename from crslab/tokenizer/nltk.py rename to crslab/data/dataset/tokenizer/nltk.py index 231d73b..94b318c 100644 --- a/crslab/tokenizer/nltk.py +++ b/crslab/data/dataset/tokenizer/nltk.py @@ -4,7 +4,7 @@ from nltk import word_tokenize -from crslab.tokenizer.base import BaseCrsTokenize +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class nltk_tokenize(BaseCrsTokenize): diff --git a/crslab/tokenizer/pkuseg.py b/crslab/data/dataset/tokenizer/pkuseg.py similarity index 84% rename from crslab/tokenizer/pkuseg.py rename to crslab/data/dataset/tokenizer/pkuseg.py index 99876fe..2a4e8fa 100644 --- a/crslab/tokenizer/pkuseg.py +++ b/crslab/data/dataset/tokenizer/pkuseg.py @@ -4,7 +4,7 @@ import pkuseg -from crslab.tokenizer.base import BaseCrsTokenize +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class pkuseg_tokenize(BaseCrsTokenize): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index 5e84a8c..e394220 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -58,7 +58,6 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['conv_pretrained_path'] super(GPT2Model, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 30fe87e..71eee2e 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -10,7 +10,7 @@ import os import json import torch -from transformers import GPT2LMHeadModel +from transformers import GPT2LMHeadModel, GPT2Config from crslab.config import PRETRAIN_PATH from crslab.data import dataset_language_map @@ -43,7 +43,6 @@ def __init__(self, opt, device, vocab, side_data): self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['conv_pretrained_path'] super(InspiredConvModel, self).__init__(opt, device, self.dpath) @@ -71,12 +70,9 @@ def converse(self, batch, mode): past = None lm_logits_all = [] - config_json = os.path.join(self.dpath, 'config.json') - - with open(config_json, 'r', encoding='utf-8') as f: - json_config = json.load(f) + GPT2_Config = GPT2Config.from_pretrained(self.dpath) - support_up_limits = json_config['n_ctx'] + support_up_limits = GPT2_Config.n_positions if mode != 'test': for turn, iter in enumerate(input_ids_iters): diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 2b2e94b..5efd93d 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -54,7 +54,6 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - self.language 
= dataset_language_map[opt['dataset']] self.dpath = opt['rec_pretrained_path'] super(InspiredRecModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 7a6a81c..a98975c 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -58,7 +58,6 @@ def __init__(self, opt, device, vocab, side_data): self.response_truncate = opt['response_truncate'] self.pad_id = vocab['pad'] - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['conv_pretrained_path'] super(TGConvModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 6986be5..92fb0a0 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -48,7 +48,6 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['policy_pretrained_path'] super(TGPolicyModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index ad185e7..4f4d592 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -72,7 +72,6 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['rec_pretrained_path'] super(TGRecModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 117d760..663656b 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -52,7 +52,6 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['policy_pretrained_path'] super(ConvBERTModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index d7cbced..8b4c69a 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -56,7 +56,6 @@ def __init__(self, opt, device, vocab, side_data): self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['policy_pretrained_path'] super(ProfileBERTModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index b20d11a..7a9181b 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -54,7 +54,6 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['policy_pretrained_path'] super(TopicBERTModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index a053eea..7e7eb80 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -54,7 +54,6 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size 
= vocab['n_entity'] - self.language = dataset_language_map[opt['dataset']] self.dpath = opt['rec_pretrained_path'] super(BERTModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/system/inspired.py b/crslab/system/inspired.py index b9219d1..3fb77ed 100644 --- a/crslab/system/inspired.py +++ b/crslab/system/inspired.py @@ -61,7 +61,6 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps - self.language = dataset_language_map[self.opt['dataset']] def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() diff --git a/crslab/system/redial.py b/crslab/system/redial.py index 0276e90..a7fa758 100644 --- a/crslab/system/redial.py +++ b/crslab/system/redial.py @@ -51,7 +51,6 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.rec_batch_size = self.rec_optim_opt['batch_size'] self.conv_batch_size = self.conv_optim_opt['batch_size'] - self.language = dataset_language_map[self.opt['dataset']] def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() From 99796d0cbd66283e5710ceabfebefc233cc2bf31 Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 7 Oct 2022 11:03:00 +0800 Subject: [PATCH 07/35] txy77 --- .vscode/settings.json | 4 + crslab/data/__init__.py | 6 +- crslab/data/dataset/durecdial/durecdial.py | 108 +++++----- crslab/data/dataset/durecdial/resources.py | 6 +- crslab/data/dataset/gorecdial/gorecdial.py | 153 ++++++------- crslab/data/dataset/gorecdial/resources.py | 4 +- crslab/data/dataset/inspired/inspired.py | 116 +++++----- crslab/data/dataset/inspired/resources.py | 2 +- crslab/data/dataset/opendialkg/opendialkg.py | 133 ++++++------ crslab/data/dataset/opendialkg/resources.py | 4 +- crslab/data/dataset/redial/redial.py | 134 ++++++------ crslab/data/dataset/redial/resources.py | 4 +- crslab/data/dataset/tgredial/tgredial.py | 180 +++++++++------- crslab/data/dataset/tokenizer/base.py | 5 +- crslab/data/dataset/tokenizer/bert.py | 4 +- crslab/data/dataset/tokenizer/gpt2.py | 4 +- crslab/data/dataset/tokenizer/jieba.py | 5 +- crslab/data/dataset/tokenizer/nltk.py | 8 +- crslab/data/dataset/tokenizer/pkuseg.py | 7 +- crslab/evaluator/embeddings.py | 2 +- crslab/evaluator/standard.py | 43 ++-- crslab/model/conversation/gpt2/gpt2.py | 20 +- crslab/model/crs/inspired/inspired_conv.py | 19 +- crslab/model/crs/inspired/inspired_rec.py | 6 +- crslab/model/crs/kbrd/kbrd.py | 78 ++++--- crslab/model/crs/kgsf/kgsf.py | 142 ++++++++----- crslab/model/crs/ntrd/ntrd.py | 201 +++++++++++------- crslab/model/crs/redial/modules.py | 48 +++-- crslab/model/crs/tgredial/tg_conv.py | 20 +- crslab/model/crs/tgredial/tg_policy.py | 14 +- crslab/model/crs/tgredial/tg_rec.py | 13 +- crslab/model/policy/conv_bert/conv_bert.py | 8 +- .../model/policy/profile_bert/profile_bert.py | 8 +- crslab/model/policy/topic_bert/topic_bert.py | 8 +- crslab/model/recommendation/bert/bert.py | 6 +- crslab/quick_start/quick_start.py | 24 ++- crslab/system/inspired.py | 14 +- crslab/system/kgsf.py | 21 +- crslab/system/redial.py | 17 +- crslab/system/tgredial.py | 57 +++-- 40 files changed, 931 insertions(+), 725 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..fe95ac3 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + 
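# Illustrative sketch (hypothetical values, not from the patch): the model and system
# classes above now take their checkpoint locations from explicit options instead of
# deriving them from the dataset's language. An opt mapping of roughly this shape is
# assumed by the code that reads conv/rec/policy_pretrained_path:
opt_example = {
    'conv_pretrained_path': 'data/model/pretrain/gpt2/zh',    # hypothetical path
    'rec_pretrained_path': 'data/model/pretrain/bert/zh',     # hypothetical path
    'policy_pretrained_path': 'data/model/pretrain/bert/zh',  # hypothetical path
}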
"editor.formatOnPaste": true, + "editor.formatOnSave": true +} \ No newline at end of file diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 7a4ad30..cca2126 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -87,7 +87,8 @@ def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: if dataset in dataset_register_table: return dataset_register_table[dataset](opt, tokenize, restore, save, task) else: - raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') + raise NotImplementedError( + f'The dataloader [{dataset}] has not been implemented') def get_dataloader(opt, dataset, vocab) -> BaseDataLoader: @@ -106,4 +107,5 @@ def get_dataloader(opt, dataset, vocab) -> BaseDataLoader: if model_name in dataloader_register_table: return dataloader_register_table[model_name](opt, dataset, vocab) else: - raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented') + raise NotImplementedError( + f'The dataloader [{model_name}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index aaeee1d..e94b8d9 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -26,20 +26,21 @@ import json import os from copy import copy -import numpy as np -import gensim - -from loguru import logger -from tqdm import tqdm +import gensim +import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from .resources import resources -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.bert import bert_tokenize from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize +from loguru import logger +from tqdm import tqdm + +from .resources import resources + class DuRecDialDataset(BaseDataset): """ @@ -77,7 +78,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ if 'copy' in opt: - self.copy = True + self.copy = True else: self.copy = False resource = resources['resource'] @@ -116,7 +117,8 @@ def _load_data(self): def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + logger.debug( + f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") # split token processing_train_data = self.split_token(train_data) logger.info("[Finish train data split]") @@ -130,17 +132,19 @@ def _load_raw_data(self): if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) logger.info('[Finish generate copy_mask]') - + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + logger.debug( + f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") # split_token processing_valid_data = self.split_token(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + 
logger.debug( + f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") # split_token processing_test_data = self.split_token(test_data) logger.info("[Finish test data split]") @@ -148,21 +152,27 @@ def _load_raw_data(self): return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): - self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.tok2ind = json.load( + open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # entity kg with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f: self.entity2id = json.load(f) # {entity: entity_id} - self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, + idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join( + self.dpath, 'entity_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") @@ -172,7 +182,8 @@ def _load_other_data(self): self.word2id = json.load(f) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join( + self.dpath, 'hownet_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]") @@ -188,7 +199,8 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id( + conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -200,10 +212,14 @@ def _convert_to_id(self, conversation): for utt in conversation['dialog']: assert utt['role'] != last_role, print(utt) - text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] - item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get( + word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] + for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] 
+ for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] + for word in utt['word'] if word in self.word2id] augmented_convs.append({ "role": utt["role"], @@ -307,7 +323,7 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_token(self, data): all_data = [] for each in tqdm(data): each_dict = {} @@ -321,7 +337,7 @@ def split_token(self, data): each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) - + return all_data def generate_tok2ind(self, processed_train_data): @@ -354,11 +370,12 @@ def generate_tok2ind(self, processed_train_data): tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 'token2id.json') with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + json.dump(tok2ind, write, ensure_ascii=False, + indent=4, separators=(',', ':')) return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processing_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) @@ -374,21 +391,17 @@ def generate_copy_mask(self, tok2ind, processing_train_data): match_list += list_word for entity in dialog['entity']: list_word = crstokenize.tokenize(entity) - match_list += list_word - match_list = list(set(match_list)) + match_list += list_word + match_list = list(set(match_list)) for each_word in text: if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True - if not os.path.exists(MODEL_PATH): - os.mkdir(MODEL_PATH) - if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): - os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) - copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') - if not os.path.exists(copy_mask_dirpath): - os.mkdir(copy_mask_dirpath) path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + if not os.path.exists(path): + os.makedirs(path) + np.save(path, copy_mask) def generate_word2vec(self, processing_train_data): @@ -398,19 +411,18 @@ def generate_word2vec(self, processing_train_data): for dialog in each_data['dialog']: text = dialog['text'] corpus.append(text) - model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + model = gensim.models.word2vec.Word2Vec( + corpus, vector_size=300, min_count=1) if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] + [[0] * 300] elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - elif self.tokenize == 'bert': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] - elif self.tokenize == 'gpt2': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [model.wv[word] for word in word2index] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 
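# Illustrative sketch (assumption, condensed from the generate_copy_mask pattern above):
# the copy mask is a per-token boolean over the vocabulary, set to True for tokens that
# also occur inside the entity-like mentions of the training dialogs. Note that np.save
# writes a file, so the directory to create beforehand is the parent folder; calling
# os.makedirs on the full copy_mask.npy path would leave a directory with that name and
# the subsequent np.save would then fail.
import os
import numpy as np

def build_and_save_copy_mask(tok2ind, conversations, tokenizer, out_dir):
    # conversations: [{'dialog': [{'text': [tokens...], 'entity': [...]}, ...]}, ...]
    copy_mask = np.zeros(len(tok2ind), dtype=bool)
    for conv in conversations:
        for turn in conv['dialog']:
            entity_tokens = set()
            for entity in turn.get('entity', []):
                entity_tokens.update(tokenizer.tokenize(entity))
            for word in turn['text']:
                if word in entity_tokens and word in tok2ind:
                    copy_mask[tok2ind[word]] = True
    os.makedirs(out_dir, exist_ok=True)            # create the parent directory only
    np.save(os.path.join(out_dir, 'copy_mask.npy'), copy_mask)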
'word2vec.npy') np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index bd2348b..9cf4f27 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -16,7 +16,7 @@ from crslab.download import DownloadableFile resources = { - 'resource':{ + 'resource': { 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', @@ -60,6 +60,6 @@ 'pad_word': 0, 'pad_topic': 0, }, - } - }, + } + }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index e337b10..6c84031 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -26,20 +26,21 @@ import json import os from copy import copy -import numpy as np -import gensim - -from loguru import logger -from tqdm import tqdm +import gensim +import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from .resources import resources -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.bert import bert_tokenize from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize +from loguru import logger +from tqdm import tqdm + +from .resources import resources + class GoRecDialDataset(BaseDataset): """ @@ -77,7 +78,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ if 'copy' in opt: - self.copy = True + self.copy = True else: self.copy = False resource = resources['resource'] @@ -117,7 +118,8 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + logger.debug( + f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") # split token processing_train_data = self.split_token(train_data) logger.info("[Finish train data split]") @@ -131,17 +133,19 @@ def _load_raw_data(self): if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) logger.info('[Finish generate copy_mask]') - + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + logger.debug( + f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") # split_token processing_valid_data = self.split_token(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + logger.debug( + f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") # split_token processing_test_data = self.split_token(test_data) logger.info("[Finish test data split]") @@ -149,30 +153,38 @@ def _load_raw_data(self): return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): - self.tok2ind = json.load(open(os.path.join(self.dpath, 
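# Illustrative sketch (assumption, mirroring the nltk branch of generate_word2vec above):
# the word2vec table reserves the first four rows for the special tokens (pad/start/end/
# unk, all zeros) and, for the nltk vocabulary, one trailing zero row, so row i + 4 holds
# the vector of the i-th word learned by gensim.
import gensim
import numpy as np

def build_word2vec_table(corpus):
    # corpus: list of token lists taken from the training dialogs
    model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1)
    word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)}
    table = [[0.0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0.0] * 300]
    return word2index, np.asarray(table, dtype=np.float32)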
'token2id.json'), 'r', encoding='utf-8')) + self.tok2ind = json.load( + open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # dbpedia self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, + idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join( + self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") # conceptnet # {concept: concept_id} - self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load( + open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join( + self.dpath, 'conceptnet_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") @@ -188,7 +200,8 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id( + conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -200,10 +213,14 @@ def _convert_to_id(self, conversation): for utt in conversation['dialog']: assert utt['role'] != last_role - text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get( + word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] + for movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] + for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] + for word in utt['word'] if word in self.word2id] policy = utt['decide'] augmented_convs.append({ @@ -224,7 +241,7 @@ def _augment_and_add(self, 
raw_conv_dict): entity_set, word_set = set(), set() for i, conv in enumerate(raw_conv_dict): text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \ - conv['policy'] + conv['policy'] if len(context_tokens) > 0 and len(text_tokens) > 0: conv_dict = { 'role': conv['role'], @@ -257,7 +274,8 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + movie_entity_ids = json.load( + open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -311,22 +329,22 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - for one in each['dialog']: - str_text = one['text'] - tokenizer = self.tokenize - crstokenize = self.crstokenizer - list_text = crstokenize.tokenize(str_text) - one['text'] = list_text - each_data.append(one) - each_dict['dialog'] = each_data - all_data.append(each_dict) - - return all_data + def split_token(self, data): + all_data = [] + for each in tqdm(data): + each_dict = {} + each_data = [] + for one in each['dialog']: + str_text = one['text'] + tokenizer = self.tokenize + crstokenize = self.crstokenizer + list_text = crstokenize.tokenize(str_text) + one['text'] = list_text + each_data.append(one) + each_dict['dialog'] = each_data + all_data.append(each_dict) + + return all_data def generate_tok2ind(self, processed_train_data): @@ -354,19 +372,20 @@ def generate_tok2ind(self, processed_train_data): if each_word not in tok2ind: tok2ind[each_word] = cnt cnt += 1 - + if self.tokenize == 'nltk': tok2ind['_split_'] = cnt cnt += 1 tok2ind_path = os.path.join(DATASET_PATH, 'gorecdial', 'token2id.json') with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + json.dump(tok2ind, write, ensure_ascii=False, + indent=4, separators=(',', ':')) return tok2ind def generate_copy_mask(self, tok2ind, processing_train_data): - + tokenizer = self.tokenize crstokenize = self.crstokenizer @@ -385,27 +404,19 @@ def generate_copy_mask(self, tok2ind, processing_train_data): for entity in dialog['entity']: list_word = crstokenize.tokenize(entity) match_list += list_word - + match_list = list(set(match_list)) - + for each_word in text: if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True - if not os.path.exists(MODEL_PATH): - os.mkdir(MODEL_PATH) - - if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): - os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) - - copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') - if not os.path.exists(copy_mask_dirpath): - os.mkdir(copy_mask_dirpath) - path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') - np.save(path, copy_mask) + if not os.path.exists(path): + os.makedirs(path) + np.save(path, copy_mask) def generate_word2vec(self, processing_train_data): @@ -415,24 +426,20 @@ def generate_word2vec(self, processing_train_data): text = dialog['text'] corpus.append(text) - model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + model = gensim.models.word2vec.Word2Vec( + corpus, vector_size=300, min_count=1) if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, word in 
enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] + [[0] * 300] elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - - - elif self.tokenize == 'bert': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] - - elif self.tokenize == 'gpt2': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [model.wv[word] for word in word2index] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index 5ea42c1..e286ba2 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -58,7 +58,7 @@ 'pad_entity': 0, 'pad_word': 0 }, - }, + }, }, - + } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index e7adb19..860fc29 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -26,20 +26,20 @@ import json import os from copy import copy -import numpy as np -import gensim - -from loguru import logger -from tqdm import tqdm +import gensim +import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from .resources import resources -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.bert import bert_tokenize from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize +from loguru import logger +from tqdm import tqdm + +from .resources import resources class InspiredDataset(BaseDataset): @@ -78,7 +78,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ if 'copy' in opt: - self.copy = True + self.copy = True else: self.copy = False resource = resources['resource'] @@ -118,7 +118,8 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + logger.debug( + f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") # split token processing_train_data = self.split_token(train_data) logger.info("[Finish train data split]") @@ -132,17 +133,19 @@ def _load_raw_data(self): if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) logger.info('[Finish generate copy_mask]') - + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + logger.debug( + f"[Load valid data from 
{os.path.join(self.dpath, 'valid_data.json')}]") # split_token processing_valid_data = self.split_token(valid_data) logger.info("[Finish valid data split]") - + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + logger.debug( + f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") # split_token processing_test_data = self.split_token(test_data) logger.info("[Finish test data split]") @@ -154,18 +157,23 @@ def _load_vocab(self): self.tok2ind = json.load(f) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # dbpedia with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f: self.entity2id = json.load(f) # {entity: entity_id} - self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, + idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join( + self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") @@ -175,7 +183,8 @@ def _load_other_data(self): self.word2id = json.load(f) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join( + self.dpath, 'concept_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") @@ -191,7 +200,8 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id( + conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -201,10 +211,14 @@ def _convert_to_id(self, conversation): augmented_convs = [] last_role = None for utt in conversation['dialog']: - text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get( + word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] + for 
movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] + for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] + for word in utt['word'] if word in self.word2id] if utt["role"] == last_role: augmented_convs[-1]["text"] += text_token_ids @@ -315,7 +329,7 @@ def _word_kg_process(self): } def split_token(self, data): - + all_data = [] for each in tqdm(data): each_dict = {} @@ -329,7 +343,7 @@ def split_token(self, data): each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) - + return all_data def generate_tok2ind(self, processed_train_data): @@ -358,19 +372,20 @@ def generate_tok2ind(self, processed_train_data): if each_word not in tok2ind: tok2ind[each_word] = cnt cnt += 1 - + if self.tokenize == 'nltk': tok2ind['_split_'] = cnt cnt += 1 tok2ind_path = os.path.join(DATASET_PATH, 'inspired', 'token2id.json') with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + json.dump(tok2ind, write, ensure_ascii=False, + indent=4, separators=(',', ':')) return tok2ind def generate_copy_mask(self, tok2ind, processing_train_data): - + tokenizer = self.tokenize crstokenize = self.crstokenizer @@ -397,27 +412,19 @@ def generate_copy_mask(self, tok2ind, processing_train_data): for people in dialog['people']: list_word = crstokenize.tokenize(people) match_list += list_word - + match_list = list(set(match_list)) - + for each_word in text: if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True - if not os.path.exists(MODEL_PATH): - os.mkdir(MODEL_PATH) - - if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): - os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) - - copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') - if not os.path.exists(copy_mask_dirpath): - os.mkdir(copy_mask_dirpath) - path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') - np.save(path, copy_mask) + if not os.path.exists(path): + os.makedirs(path) + np.save(path, copy_mask) def generate_word2vec(self, processing_train_data): @@ -427,23 +434,20 @@ def generate_word2vec(self, processing_train_data): text = dialog['text'] corpus.append(text) - model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + model = gensim.models.word2vec.Word2Vec( + corpus, vector_size=300, min_count=1) if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] - - elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - - elif self.tokenize == 'bert': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] + [[0] * 300] - elif self.tokenize == 'gpt2': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [model.wv[word] for word in word2index] + elif self.tokenize == 'jieba': + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] word2vec_path = 
os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index c2d1e75..77f915e 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -23,7 +23,7 @@ 'inspired.zip', '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'nltk':{ + 'nltk': { 'special_token_idx': { 'pad': 0, 'start': 1, diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index f0b4e55..482cc89 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -27,20 +27,20 @@ import os from collections import defaultdict from copy import copy -import numpy as np -import gensim - -from loguru import logger -from tqdm import tqdm +import gensim +import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from .resources import resources -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.bert import bert_tokenize from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize +from loguru import logger +from tqdm import tqdm + +from .resources import resources class OpenDialKGDataset(BaseDataset): @@ -79,7 +79,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ if 'copy' in opt: - self.copy = True + self.copy = True else: self.copy = False resource = resources['resource'] @@ -119,7 +119,8 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + logger.debug( + f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") # split token processing_train_data = self.split_token(train_data) logger.info("[Finish train data split]") @@ -136,14 +137,16 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + logger.debug( + f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") # split_token processing_valid_data = self.split_token(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + logger.debug( + f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") # split_token processing_test_data = self.split_token(test_data) logger.info("[Finish test data split]") @@ -151,30 +154,38 @@ def _load_raw_data(self): return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): - self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.tok2ind = json.load( + open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 
'token2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # opendialkg self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, + idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join(self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join( + self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'opendialkg_subkg.json')} and {os.path.join(self.dpath, 'opendialkg_triples.txt')}]") # conceptnet # {concept: concept_id} - self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load( + open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join( + self.dpath, 'concept_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") @@ -190,7 +201,8 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id( + conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -200,10 +212,14 @@ def _convert_to_id(self, conversation): augmented_convs = [] last_role = None for utt in conversation['dialog']: - text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] - item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get( + word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] + for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] + for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] + for word in utt['word'] if word in self.word2id] if utt["role"] == last_role: augmented_convs[-1]["text"] += text_token_ids @@ -258,7 +274,8 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - item_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', 
encoding='utf-8')) + item_entity_ids = json.load( + open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load item entity ids]') side_data = { @@ -283,7 +300,8 @@ def _entity_kg_process(self): if e1 != e0: edge_list.append((e1, e1, 'SELF_LOOP')) - relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() + relation_cnt, relation2id, edges, entities = defaultdict( + int), dict(), set(), set() for h, t, r in edge_list: relation_cnt[r] += 1 for h, t, r in edge_list: @@ -316,9 +334,9 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - + def split_token(self, data): - + all_data = [] for each in tqdm(data): each_dict = {} @@ -332,7 +350,7 @@ def split_token(self, data): each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) - + return all_data def generate_tok2ind(self, processed_train_data): @@ -361,19 +379,21 @@ def generate_tok2ind(self, processed_train_data): if each_word not in tok2ind: tok2ind[each_word] = cnt cnt += 1 - + if self.tokenize == 'nltk': tok2ind['_split_'] = cnt cnt += 1 - tok2ind_path = os.path.join(DATASET_PATH, 'opendialkg', 'token2id.json') + tok2ind_path = os.path.join( + DATASET_PATH, 'opendialkg', 'token2id.json') with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + json.dump(tok2ind, write, ensure_ascii=False, + indent=4, separators=(',', ':')) return tok2ind def generate_copy_mask(self, tok2ind, processing_train_data): - + tokenizer = self.tokenize crstokenize = self.crstokenizer @@ -392,27 +412,19 @@ def generate_copy_mask(self, tok2ind, processing_train_data): for item in dialog['item']: list_word = crstokenize.tokenize(item) match_list += list_word - + match_list = list(set(match_list)) - + for each_word in text: if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True - - if not os.path.exists(MODEL_PATH): - os.mkdir(MODEL_PATH) - - if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): - os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) - - copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') - if not os.path.exists(copy_mask_dirpath): - os.mkdir(copy_mask_dirpath) path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') - np.save(path, copy_mask) + if not os.path.exists(path): + os.makedirs(path) + np.save(path, copy_mask) def generate_word2vec(self, processing_train_data): @@ -422,24 +434,21 @@ def generate_word2vec(self, processing_train_data): text = dialog['text'] corpus.append(text) - model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) - + model = gensim.models.word2vec.Word2Vec( + corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] + [[0] * 300] elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] - elif self.tokenize == 'bert': - word2index = {word: i + 1 for 
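# Illustrative sketch (assumption, abridged from the _entity_kg_process pattern above):
# the raw triple file is turned into (head, tail, relation_id) edges, every entity gets a
# self-loop edge, and relation ids are assigned from the distinct relation strings; the
# relation counts computed here may be used downstream (that step is elided in the hunk).
from collections import defaultdict

def process_entity_kg(triples, self_loop='SELF_LOOP'):
    # triples: iterable of (head_entity_id, tail_entity_id, relation_string)
    edge_list = []
    for h, t, r in triples:
        edge_list.append((h, t, r))
        edge_list.append((h, h, self_loop))
        if t != h:
            edge_list.append((t, t, self_loop))
    relation_cnt, relation2id = defaultdict(int), {}
    edges, entities = set(), set()
    for h, t, r in edge_list:
        relation_cnt[r] += 1
    for h, t, r in edge_list:
        if r not in relation2id:
            relation2id[r] = len(relation2id)
        edges.add((h, t, relation2id[r]))
        entities.update((h, t))
    return {'edge': list(edges), 'n_relation': len(relation2id), 'entity': list(entities)}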
i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] - - elif self.tokenize == 'gpt2': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [model.wv[word] for word in word2index] - - word2vec_path = os.path.join(DATASET_PATH, 'opendialkg', 'word2vec.npy') + word2vec_path = os.path.join( + DATASET_PATH, 'opendialkg', 'word2vec.npy') np.save(word2vec_path, word2embedding) - diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index e5682fe..9ff0eed 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -23,7 +23,7 @@ 'opendialkg.zip', '73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', ), - 'nltk':{ + 'nltk': { 'special_token_idx': { 'pad': 0, 'start': 1, @@ -56,6 +56,6 @@ 'pad_entity': 0, 'pad_word': 0 }, - } + } }, } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 861215f..5f28d70 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -27,20 +27,20 @@ import os from collections import defaultdict from copy import copy -import numpy as np -import gensim - -from loguru import logger -from tqdm import tqdm +import gensim +import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from .resources import resources -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.bert import bert_tokenize from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize +from loguru import logger +from tqdm import tqdm + +from .resources import resources class ReDialDataset(BaseDataset): @@ -79,7 +79,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ if 'copy' in opt: - self.copy = True + self.copy = True else: self.copy = False resource = resources['resource'] @@ -119,7 +119,8 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + logger.debug( + f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") # split token processing_train_data = self.split_token(train_data) logger.info("[Finish train data split]") @@ -136,14 +137,16 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + logger.debug( + f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") # split_token processing_valid_data = self.split_token(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + logger.debug( + f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") # split_token processing_test_data = self.split_token(test_data) logger.info("[Finish test data split]") @@ -151,30 +154,38 @@ def 
_load_raw_data(self): return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): - self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.tok2ind = json.load( + open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # dbpedia self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, + idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = json.load(open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8')) + self.entity_kg = json.load( + open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8')) logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]") # conceptNet # {concept: concept_id} - self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load( + open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {relation\t concept \t concept} - self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') + self.word_kg = open(os.path.join( + self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]") @@ -190,7 +201,8 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._merge_conv_data(conversation["dialog"]) for conversation in tqdm(raw_data)] + augmented_convs = [self._merge_conv_data( + conversation["dialog"]) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -200,10 +212,14 @@ def _merge_conv_data(self, dialog): augmented_convs = [] last_role = None for utt in dialog: - text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get( + word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] + for movie in utt['movies'] if movie in self.entity2id] + entity_ids = 
[self.entity2id[entity] + for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] + for word in utt['word'] if word in self.word2id] if utt["role"] == last_role: augmented_convs[-1]["text"] += text_token_ids @@ -258,7 +274,8 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + movie_entity_ids = json.load( + open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -276,10 +293,13 @@ def _entity_kg_process(self, SELF_LOOP_ID=185): edge_list.append((entity, entity, SELF_LOOP_ID)) # add self loop for tail_and_relation in self.entity_kg[str(entity)]: if entity != tail_and_relation[1] and tail_and_relation[0] != SELF_LOOP_ID: - edge_list.append((entity, tail_and_relation[1], tail_and_relation[0])) - edge_list.append((tail_and_relation[1], entity, tail_and_relation[0])) + edge_list.append( + (entity, tail_and_relation[1], tail_and_relation[0])) + edge_list.append( + (tail_and_relation[1], entity, tail_and_relation[0])) - relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() + relation_cnt, relation2id, edges, entities = defaultdict( + int), dict(), set(), set() for h, t, r in edge_list: relation_cnt[r] += 1 for h, t, r in edge_list: @@ -311,9 +331,9 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - + def split_token(self, data): - + all_data = [] for each in tqdm(data): each_dict = {} @@ -327,7 +347,7 @@ def split_token(self, data): each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) - + return all_data def generate_tok2ind(self, processed_train_data): @@ -356,19 +376,20 @@ def generate_tok2ind(self, processed_train_data): if each_word not in tok2ind: tok2ind[each_word] = cnt cnt += 1 - + if self.tokenize == 'nltk': tok2ind['_split_'] = cnt cnt += 1 tok2ind_path = os.path.join(DATASET_PATH, 'redial', 'token2id.json') with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + json.dump(tok2ind, write, ensure_ascii=False, + indent=4, separators=(',', ':')) return tok2ind def generate_copy_mask(self, tok2ind, processing_train_data): - + tokenizer = self.tokenize crstokenize = self.crstokenizer @@ -386,27 +407,19 @@ def generate_copy_mask(self, tok2ind, processing_train_data): for entity in dialog['entity']: list_word = crstokenize.tokenize(entity) match_list += list_word - + match_list = list(set(match_list)) - + for each_word in text: if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True - if not os.path.exists(MODEL_PATH): - os.mkdir(MODEL_PATH) - - if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): - os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) - - copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') - if not os.path.exists(copy_mask_dirpath): - os.mkdir(copy_mask_dirpath) - path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') - np.save(path, copy_mask) + if not os.path.exists(path): + os.makedirs(path) + np.save(path, copy_mask) def generate_word2vec(self, processing_train_data): @@ -416,23 +429,20 @@ def generate_word2vec(self, processing_train_data): text = dialog['text'] corpus.append(text) - model = gensim.models.word2vec.Word2Vec(corpus, 
vector_size=300, min_count=1) - + model = gensim.models.word2vec.Word2Vec( + corpus, vector_size=300, min_count=1) + if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] + [[0] * 300] elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - - elif self.tokenize == 'bert': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] - - elif self.tokenize == 'gpt2': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [model.wv[word] for word in word2index] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') - np.save(word2vec_path, word2embedding) \ No newline at end of file + np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index 170dd3b..9cd1e27 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -23,7 +23,7 @@ 'redial.zip', '9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 'nltk':{ + 'nltk': { 'special_token_idx': { 'pad': 0, 'start': 1, @@ -56,6 +56,6 @@ 'pad_entity': 0, 'pad_word': 0 }, - } + } }, } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 83a850e..05d450b 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -27,20 +27,20 @@ import os from collections import defaultdict from copy import copy -import numpy as np -import gensim - -from loguru import logger -from tqdm import tqdm +import gensim +import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from .resources import resources -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.bert import bert_tokenize from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize from crslab.data.dataset.tokenizer.jieba import jieba_tokenize +from crslab.data.dataset.tokenizer.nltk import nltk_tokenize from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize +from loguru import logger +from tqdm import tqdm + +from .resources import resources class TGReDialDataset(BaseDataset): @@ -82,7 +82,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): """ if 'copy' in opt: - self.copy = True + self.copy = True else: self.copy = False resource = resources['resource'] @@ -100,15 +100,15 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.crstokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, 'tgredial') - self.replace_token = opt.get('replace_token',None) - self.replace_token_idx = opt.get('replace_token_idx',None) + self.replace_token = opt.get('replace_token', None) + self.replace_token_idx = opt.get('replace_token_idx', None) super().__init__(opt, dpath, resource, 
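generate_word2vec above trains a 300-dimensional gensim Word2Vec model on the tokenized training dialogs and stacks zero rows in front of the learned vectors so the first indices stay reserved for special tokens (four rows for the nltk/jieba vocabularies, plus one trailing row for '_split_' in the nltk branch). A compact sketch of that layout on a toy corpus, assuming the gensim 4.x API used in the hunk:

import numpy as np
from gensim.models import Word2Vec

corpus = [['hello', 'there'], ['any', 'movie', 'to', 'recommend'], ['hello', 'movie']]
model = Word2Vec(corpus, vector_size=300, min_count=1)

n_special = 4  # pad / start / end / unk rows reserved at the front
word2index = {w: i + n_special for i, w in enumerate(model.wv.index_to_key)}
embedding = np.vstack([np.zeros((n_special, 300))] + [model.wv[w] for w in word2index])
print(embedding.shape)  # (n_special + vocab_size, 300)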
restore, save) if self.replace_token: if self.replace_token_idx: self.side_data["embedding"][self.replace_token_idx] = self.side_data['embedding'][0] else: - self.side_data["embedding"] = np.insert(self.side_data["embedding"],len(self.side_data["embedding"]),self.side_data['embedding'][0],axis=0) - + self.side_data["embedding"] = np.insert(self.side_data["embedding"], len( + self.side_data["embedding"]), self.side_data['embedding'][0], axis=0) def _load_data(self): train_data, valid_data, test_data = self._load_raw_data() @@ -136,7 +136,8 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + logger.debug( + f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") # split token processing_train_data = self.split_token(train_data) logger.info("[Finish train data split]") @@ -153,14 +154,16 @@ def _load_raw_data(self): with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + logger.debug( + f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") # split_token processing_valid_data = self.split_token(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + logger.debug( + f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") # split_token processing_test_data = self.split_token(test_data) logger.info("[Finish test data split]") @@ -168,7 +171,8 @@ def _load_raw_data(self): return processing_train_data, processing_valid_data, processing_test_data def _load_vocab(self): - self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.tok2ind = json.load( + open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} # add special tokens if self.replace_token: @@ -180,46 +184,61 @@ def _load_vocab(self): else: self.ind2tok[len(self.tok2ind)] = self.replace_token self.tok2ind[self.replace_token] = len(self.tok2ind) - self.special_token_idx[self.replace_token] = len(self.tok2ind)-1 - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + self.special_token_idx[self.replace_token] = len( + self.tok2ind)-1 + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2tok)}]") - self.topic2ind = json.load(open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8')) + self.topic2ind = json.load( + open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8')) self.ind2topic = {idx: word for word, idx in self.topic2ind.items()} - logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]") - logger.debug(f"[The size of token2index dictionary is {len(self.topic2ind)}]") - logger.debug(f"[The size of index2token 
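When TGReDial is loaded with a replace_token (the [ITEM] slot used by NTRD) but no explicit replace_token_idx, the constructor above appends one extra row to side_data['embedding'] so the new token has a vector, initialised from row 0. A small numpy sketch of that np.insert call on a toy embedding matrix:

import numpy as np

embedding = np.arange(12, dtype=float).reshape(4, 3)  # toy (vocab_size, dim) matrix
# Append a row for the newly added token, reusing row 0 as its initial value,
# mirroring the np.insert(...) in the hunk above.
embedding = np.insert(embedding, len(embedding), embedding[0], axis=0)
print(embedding.shape)  # (5, 3)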
dictionary is {len(self.ind2topic)}]") + logger.debug( + f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]") + logger.debug( + f"[The size of token2index dictionary is {len(self.topic2ind)}]") + logger.debug( + f"[The size of index2token dictionary is {len(self.ind2topic)}]") def _load_other_data(self): # cn-dbpedia self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, + idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join(self.dpath, 'cn-dbpedia.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join( + self.dpath, 'cn-dbpedia.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'cn-dbpedia.txt')}]") # hownet # {concept: concept_id} - self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load( + open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {relation\t concept \t concept} - self.word_kg = open(os.path.join(self.dpath, 'hownet.txt'), encoding='utf-8') + self.word_kg = open(os.path.join( + self.dpath, 'hownet.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet.txt')}]") # user interaction history dictionary - self.conv2history = json.load(open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8')) - logger.debug(f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]") + self.conv2history = json.load( + open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8')) + logger.debug( + f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]") # user profile - self.user2profile = json.load(open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8')) - logger.debug(f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}") - + self.user2profile = json.load( + open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8')) + logger.debug( + f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}") def _data_preprocess(self, train_data, valid_data, test_data): processed_train_data = self._raw_data_process(train_data) @@ -233,7 +252,8 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id( + conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -247,14 +267,19 @@ def _convert_to_id(self, conversation): # change movies into slots if self.replace_token: if len(utt['movie']) != 0: - while '《' in utt['text'] : + while '《' in utt['text']: begin = utt['text'].index("《") end = utt['text'].index("》") - utt['text'] = utt['text'][:begin] + [self.replace_token] + utt['text'][end+1:] - text_token_ids = [self.tok2ind.get(word,
self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] for movie in utt['movie'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + utt['text'] = utt['text'][:begin] + \ + [self.replace_token] + utt['text'][end+1:] + text_token_ids = [self.tok2ind.get( + word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] + for movie in utt['movie'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] + for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] + for word in utt['word'] if word in self.word2id] policy = [] for action, kw in zip(utt['target'][1::2], utt['target'][2::2]): if kw is None or action == '推荐电影': @@ -263,12 +288,15 @@ def _convert_to_id(self, conversation): kw = [kw] kw = [self.topic2ind.get(k, self.pad_topic_idx) for k in kw] policy.append([action, kw]) - final_kws = [self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]] + final_kws = [ + self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]] final = [utt['final'][0], final_kws] - conv_utt_id = str(conversation['conv_id']) + '/' + str(utt['local_id']) + conv_utt_id = str( + conversation['conv_id']) + '/' + str(utt['local_id']) interaction_history = self.conv2history.get(conv_utt_id, []) user_profile = self.user2profile[conversation['user_id']] - user_profile = [[self.tok2ind.get(token, self.unk_token_idx) for token in sent] for sent in user_profile] + user_profile = [[self.tok2ind.get( + token, self.unk_token_idx) for token in sent] for sent in user_profile] augmented_convs.append({ "role": utt["role"], @@ -291,11 +319,11 @@ def _augment_and_add(self, raw_conv_dict): entity_set, word_set = set(), set() for i, conv in enumerate(raw_conv_dict): text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \ - conv['policy'] - if self.replace_token is not None: + conv['policy'] + if self.replace_token is not None: if text_tokens.count(30000) != len(movies): - continue # the number of slots doesn't equal to the number of movies - + continue # the number of slots doesn't equal to the number of movies + if len(context_tokens) > 0: conv_dict = { 'role': conv['role'], @@ -332,7 +360,8 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + movie_entity_ids = json.load( + open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -355,7 +384,8 @@ def _entity_kg_process(self): if e1 != e0: edge_list.append((e1, e1, 'SELF_LOOP')) - relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() + relation_cnt, relation2id, edges, entities = defaultdict( + int), dict(), set(), set() for h, t, r in edge_list: relation_cnt[r] += 1 for h, t, r in edge_list: @@ -389,7 +419,7 @@ def _word_kg_process(self): } def split_token(self, data): - + all_data = [] for each in tqdm(data): each_dict = {} @@ -405,7 +435,7 @@ def split_token(self, data): each_dict['messages'] = each_data each_dict['user_id'] = each['user_id'] all_data.append(each_dict) - + return all_data def
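In _convert_to_id above, every movie mention wrapped in 《 ... 》 book-title marks is collapsed into a single replace_token slot before tokens are mapped to ids, and _augment_and_add later drops samples whose number of slots does not match the number of gold movies. A self-contained sketch of the span replacement on a token list (toy input, default slot name assumed to be '[ITEM]'):

def replace_title_spans(tokens, replace_token='[ITEM]'):
    # Collapse every 《 ... 》 span into one placeholder token.
    tokens = list(tokens)
    while '《' in tokens:
        begin = tokens.index('《')
        end = tokens.index('》')
        tokens = tokens[:begin] + [replace_token] + tokens[end + 1:]
    return tokens

print(replace_title_spans(['我', '喜欢', '《', '泰坦尼克号', '》', '这部', '电影']))
# -> ['我', '喜欢', '[ITEM]', '这部', '电影']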
generate_tok2ind(self, processed_train_data): @@ -434,19 +464,20 @@ def generate_tok2ind(self, processed_train_data): if each_word not in tok2ind: tok2ind[each_word] = cnt cnt += 1 - + if self.tokenize == 'nltk': tok2ind['_split_'] = cnt cnt += 1 tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, indent=4, separators=(',', ':')) + json.dump(tok2ind, write, ensure_ascii=False, + indent=4, separators=(',', ':')) return tok2ind def generate_copy_mask(self, tok2ind, processing_train_data): - + tokenizer = self.tokenize crstokenize = self.crstokenizer @@ -466,27 +497,19 @@ def generate_copy_mask(self, tok2ind, processing_train_data): for entity in dialog['entity']: list_word = crstokenize.tokenize(entity) match_list += list_word - + match_list = list(set(match_list)) - + for each_word in text: if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True - if not os.path.exists(MODEL_PATH): - os.mkdir(MODEL_PATH) - - if not os.path.exists(os.path.join(MODEL_PATH, 'kgsf')): - os.mkdir(os.path.join(MODEL_PATH, 'kgsf')) - - copy_mask_dirpath = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') - if not os.path.exists(copy_mask_dirpath): - os.mkdir(copy_mask_dirpath) - path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') - np.save(path, copy_mask) + if not os.path.exists(path): + os.makedirs(path) + np.save(path, copy_mask) def generate_word2vec(self, processing_train_data): @@ -496,23 +519,20 @@ def generate_word2vec(self, processing_train_data): text = dialog['text'] corpus.append(text) - model = gensim.models.word2vec.Word2Vec(corpus, vector_size=300, min_count=1) + model = gensim.models.word2vec.Word2Vec( + corpus, vector_size=300, min_count=1) if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] + [[0] * 300] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] + [[0] * 300] elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': - word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - - elif self.tokenize == 'bert': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] + [model.wv[word] for word in word2index] - - elif self.tokenize == 'gpt2': - word2index = {word: i + 1 for i, word in enumerate(model.wv.index_to_key)} - word2embedding = [model.wv[word] for word in word2index] + word2index = {word: i + 4 for i, + word in enumerate(model.wv.index_to_key)} + word2embedding = [[0] * 300] * 4 + [model.wv[word] + for word in word2index] word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') np.save(word2vec_path, word2embedding) diff --git a/crslab/data/dataset/tokenizer/base.py b/crslab/data/dataset/tokenizer/base.py index 153a3e3..d9ccd22 100644 --- a/crslab/data/dataset/tokenizer/base.py +++ b/crslab/data/dataset/tokenizer/base.py @@ -2,9 +2,6 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -import os -from transformers import AutoTokenizer - class BaseCrsTokenize: def __init__(self, path=None) -> None: @@ -14,4 +11,4 @@ def tokenize(self, text): ''' split token ''' - pass \ No newline at end of file + pass diff --git 
a/crslab/data/dataset/tokenizer/bert.py b/crslab/data/dataset/tokenizer/bert.py index 0ee5463..fc74bcb 100644 --- a/crslab/data/dataset/tokenizer/bert.py +++ b/crslab/data/dataset/tokenizer/bert.py @@ -2,9 +2,9 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize from transformers import AutoTokenizer -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class bert_tokenize(BaseCrsTokenize): @@ -13,4 +13,4 @@ def __init__(self, path=None) -> None: self.my_tokenizer = AutoTokenizer.from_pretrained(path) def tokenize(self, text): - return self.my_tokenizer.tokenize(text) \ No newline at end of file + return self.my_tokenizer.tokenize(text) diff --git a/crslab/data/dataset/tokenizer/gpt2.py b/crslab/data/dataset/tokenizer/gpt2.py index b52747d..2ddb0c1 100644 --- a/crslab/data/dataset/tokenizer/gpt2.py +++ b/crslab/data/dataset/tokenizer/gpt2.py @@ -2,9 +2,9 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize from transformers import AutoTokenizer -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class gpt2_tokenize(BaseCrsTokenize): @@ -13,4 +13,4 @@ def __init__(self, path=None) -> None: self.my_tokenizer = AutoTokenizer.from_pretrained(path) def tokenize(self, text): - return self.my_tokenizer.tokenize(text) \ No newline at end of file + return self.my_tokenizer.tokenize(text) diff --git a/crslab/data/dataset/tokenizer/jieba.py b/crslab/data/dataset/tokenizer/jieba.py index a834995..753f0ba 100644 --- a/crslab/data/dataset/tokenizer/jieba.py +++ b/crslab/data/dataset/tokenizer/jieba.py @@ -2,9 +2,10 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize + import jieba -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class jieba_tokenize(BaseCrsTokenize): @@ -14,4 +15,4 @@ def __init__(self, path=None) -> None: def tokenize(self, text): split_text = jieba.cut(text) text_list = ' '.join(split_text).split() - return text_list \ No newline at end of file + return text_list diff --git a/crslab/data/dataset/tokenizer/nltk.py b/crslab/data/dataset/tokenizer/nltk.py index 94b318c..4e016f1 100644 --- a/crslab/data/dataset/tokenizer/nltk.py +++ b/crslab/data/dataset/tokenizer/nltk.py @@ -2,9 +2,11 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize + from nltk import word_tokenize +import nltk -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class nltk_tokenize(BaseCrsTokenize): @@ -12,5 +14,5 @@ def __init__(self, path=None) -> None: super().__init__(path) def tokenize(self, text): - # nltk.download('punkt') - return word_tokenize(text) \ No newline at end of file + nltk.download('punkt') + return word_tokenize(text) diff --git a/crslab/data/dataset/tokenizer/pkuseg.py b/crslab/data/dataset/tokenizer/pkuseg.py index 2a4e8fa..fabdcad 100644 --- a/crslab/data/dataset/tokenizer/pkuseg.py +++ b/crslab/data/dataset/tokenizer/pkuseg.py @@ -2,9 +2,10 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com +from crslab.data.dataset.tokenizer.base import BaseCrsTokenize + import pkuseg -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize class pkuseg_tokenize(BaseCrsTokenize): @@ -12,5 +13,5 @@ def __init__(self, path=None) -> None: self.pkuseg_tokenizer = pkuseg.pkuseg() super().__init__(path) - def tokenize(self, text): - return self.pkuseg_tokenizer.cut(text) \ No newline 
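The new tokenizer package above hides each backend (bert, gpt2, jieba, nltk, pkuseg) behind the same BaseCrsTokenize.tokenize(text) interface so the dataset classes can stay backend-agnostic. A hedged sketch of what adding a further backend would look like; the whitespace tokenizer below is an invented example, not part of the patch:

class BaseCrsTokenize:
    # Minimal restatement of the interface: subclasses implement tokenize().
    def __init__(self, path=None) -> None:
        self.path = path

    def tokenize(self, text):
        raise NotImplementedError

class whitespace_tokenize(BaseCrsTokenize):
    # Hypothetical extra backend, following the bert_tokenize / jieba_tokenize naming style.
    def tokenize(self, text):
        return text.split()

print(whitespace_tokenize().tokenize('a simple whitespace example'))
# -> ['a', 'simple', 'whitespace', 'example']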
at end of file + def tokenize(self, text): + return self.pkuseg_tokenizer.cut(text) diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index b682e42..a37adda 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -25,7 +25,7 @@ ) }, 'en': { - 'version': '.0', + 'version': '1.0', 'file': DownloadableFile( 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index f08d121..67ac256 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -17,21 +17,21 @@ from collections import defaultdict import fasttext +from crslab.evaluator.base import BaseEvaluator +from crslab.evaluator.utils import nice_report from loguru import logger from nltk import ngrams from torch.utils.tensorboard import SummaryWriter -from crslab.evaluator.base import BaseEvaluator -from crslab.evaluator.utils import nice_report -from .embeddings import resources -from .metrics import * from ..config import EMBEDDING_PATH from ..download import build +from .embeddings import resources +from .metrics import * class StandardEvaluator(BaseEvaluator): """The evaluator for all kind of model(recommender, conversation, policy) - + Args: rec_metrics: the metrics to evaluate recommender model, including hit@K, ndcg@K and mrr@K dist_set: the set to record dist n-gram @@ -54,8 +54,10 @@ def __init__(self, language, tensorboard=False): # tensorboard self.tensorboard = tensorboard if self.tensorboard: - self.writer = SummaryWriter(log_dir='runs/' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) - self.reports_name = ['Recommendation Metrics', 'Generation Metrics', 'Optimization Metrics'] + self.writer = SummaryWriter( + log_dir='runs/' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) + self.reports_name = ['Recommendation Metrics', + 'Generation Metrics', 'Optimization Metrics'] def _load_embedding(self, language): resource = resources[language] @@ -72,16 +74,20 @@ def _get_sent_embedding(self, sent): def rec_evaluate(self, ranks, label): for k in [1, 10, 50]: if len(ranks) >= k: - self.rec_metrics.add(f"hit@{k}", HitMetric.compute(ranks, label, k)) - self.rec_metrics.add(f"ndcg@{k}", NDCGMetric.compute(ranks, label, k)) - self.rec_metrics.add(f"mrr@{k}", MRRMetric.compute(ranks, label, k)) + self.rec_metrics.add( + f"hit@{k}", HitMetric.compute(ranks, label, k)) + self.rec_metrics.add( + f"ndcg@{k}", NDCGMetric.compute(ranks, label, k)) + self.rec_metrics.add( + f"mrr@{k}", MRRMetric.compute(ranks, label, k)) def gen_evaluate(self, hyp, refs): if hyp: self.gen_metrics.add("f1", F1Metric.compute(hyp, refs)) for k in range(1, 5): - self.gen_metrics.add(f"bleu@{k}", BleuMetric.compute(hyp, refs, k)) + self.gen_metrics.add( + f"bleu@{k}", BleuMetric.compute(hyp, refs, k)) for token in ngrams(hyp, k): self.dist_set[f"dist@{k}"].add(token) self.dist_cnt += 1 @@ -89,18 +95,23 @@ def gen_evaluate(self, hyp, refs): hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] if len(ref_embs[0]) > 0: - self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + self.gen_metrics.add( + 'greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add( + 'average', 
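gen_evaluate above collects every generated n-gram into dist_set and counts hypotheses in dist_cnt; report() then divides the number of distinct n-grams by that count to obtain dist@k. A standalone toy sketch of that distinct-n computation, without the evaluator classes:

from collections import defaultdict
from nltk import ngrams

hypotheses = [['i', 'like', 'this', 'movie'], ['i', 'like', 'that', 'movie']]
dist_set, dist_cnt = defaultdict(set), 0

for hyp in hypotheses:
    for k in range(1, 5):
        for token in ngrams(hyp, k):
            dist_set[f'dist@{k}'].add(token)
    dist_cnt += 1

report = {k: len(v) / dist_cnt for k, v in dist_set.items()}
print(report['dist@1'])  # 2.5 distinct unigrams per hypothesis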
EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add( + 'extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt)) - reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()] + reports = [self.rec_metrics.report(), self.gen_metrics.report(), + self.optim_metrics.report()] if self.tensorboard and mode != 'test': for idx, task_report in enumerate(reports): for each_metric, value in task_report.items(): - self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) + self.writer.add_scalars( + f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) def reset_metrics(self): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index e394220..2f7a345 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -23,20 +23,16 @@ """ -import os import torch +from crslab.model.base import BaseModel from torch.nn import CrossEntropyLoss from transformers import GPT2LMHeadModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class GPT2Model(BaseModel): """ - + Attributes: context_truncate: A integer indicating the length of dialogue context. response_truncate: A integer indicating the length of dialogue response. @@ -97,7 +93,8 @@ def generate(self, context): context = context[..., -self.response_truncate + 1:] for i in range(self.response_truncate - 1): - outputs = self.model(context, former_hidden_state) # (bs, c_t, v_s), + outputs = self.model( + context, former_hidden_state) # (bs, c_t, v_s), last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values next_token_logits = last_hidden_state[:, -1, :] # (bs, v_s) @@ -141,8 +138,10 @@ def generate_bs(self, context, beam=4): next_token_logits = last_hidden_state[:, -1, :] next_token_probs = torch.nn.functional.softmax(next_token_logits) topk = torch.topk(next_token_probs, beam, dim=-1) - probs = topk.values.reshape([batch_size, -1, beam]) # (bs, candidate, beam) - preds = topk.indices.reshape([batch_size, -1, beam]) # (bs, candidate, beam) + probs = topk.values.reshape( + [batch_size, -1, beam]) # (bs, candidate, beam) + preds = topk.indices.reshape( + [batch_size, -1, beam]) # (bs, candidate, beam) for j in range(batch_size): all_candidates = [] @@ -154,7 +153,8 @@ def generate_bs(self, context, beam=4): seq_tmp.append(preds[j][n][k]) candidate = [seq_tmp, prob * probs[j][n][k]] all_candidates.append(candidate) - ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True) + ordered = sorted( + all_candidates, key=lambda tup: tup[1], reverse=True) sequences[j] = ordered[:beam] res = [] diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 71eee2e..19e58ad 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -7,14 +7,12 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -import os -import json -import torch -from transformers import GPT2LMHeadModel, GPT2Config -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map + +import torch from crslab.model.base import BaseModel +from transformers import GPT2Config, GPT2LMHeadModel 
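generate_bs in the GPT-2 hunk above keeps, for each batch element, a list of (sequence, probability) candidates and expands every candidate with the top-k next tokens before pruning back to the beam width. The core expand-and-prune step, sketched on plain Python lists with made-up next-token probabilities (no model call, beam width of 2 assumed):

beam = 2
# current candidates per example: (token sequence, cumulative probability)
sequences = [[(['<s>'], 1.0)]]
# pretend the model proposed these continuations
next_tokens = [('great', 0.6), ('nice', 0.3)]

for j, candidates in enumerate(sequences):
    all_candidates = []
    for seq, prob in candidates:
        for tok, p in next_tokens:
            all_candidates.append((seq + [tok], prob * p))
    # keep the `beam` most probable extensions, mirroring sorted(...)[:beam] above
    sequences[j] = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)[:beam]

print(sequences[0])  # [(['<s>', 'great'], 0.6), (['<s>', 'nice'], 0.3)]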
+ from .modules import SequenceCrossEntropyLoss @@ -90,7 +88,8 @@ def converse(self, batch, mode): lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) - lm_logits_all = torch.cat(lm_logits_all, dim=0) # (b_s, seq_len, vocab_size) + # (b_s, seq_len, vocab_size) + lm_logits_all = torch.cat(lm_logits_all, dim=0) # index from 1 to self.reponse_truncate is valid response loss = self.calculate_loss( @@ -122,9 +121,11 @@ def generate(self, roles, context): context_iters = context.unsqueeze(1) for turn, iter in enumerate(context_iters): if roles[turn] == 0: - outputs = self.model_sk(iter, former_hidden_state) # (1, s_l, v_s), + outputs = self.model_sk( + iter, former_hidden_state) # (1, s_l, v_s), else: - outputs = self.model_rm(iter, former_hidden_state) # (1, s_l, v_s), + outputs = self.model_rm( + iter, former_hidden_state) # (1, s_l, v_s), last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values last_hidden_state_all.append(last_hidden_state) diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index 5efd93d..bebbd83 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -23,16 +23,12 @@ """ -import os +from crslab.model.base import BaseModel from loguru import logger from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class InspiredRecModel(BaseModel): """ diff --git a/crslab/model/crs/kbrd/kbrd.py b/crslab/model/crs/kbrd/kbrd.py index 987151d..81b6c4c 100644 --- a/crslab/model/crs/kbrd/kbrd.py +++ b/crslab/model/crs/kbrd/kbrd.py @@ -21,14 +21,14 @@ import torch import torch.nn.functional as F -from loguru import logger -from torch import nn -from torch_geometric.nn import RGCNConv - from crslab.model.base import BaseModel from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionBatch -from crslab.model.utils.modules.transformer import TransformerDecoder, TransformerEncoder +from crslab.model.utils.modules.transformer import (TransformerDecoder, + TransformerEncoder) +from loguru import logger +from torch import nn +from torch_geometric.nn import RGCNConv class KBRDModel(BaseModel): @@ -84,7 +84,8 @@ def __init__(self, opt, device, vocab, side_data): self.n_entity = vocab['n_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] - self.edge_idx, self.edge_type = edge_to_pyg_format(entity_kg['edge'], 'RGCN') + self.edge_idx, self.edge_type = edge_to_pyg_format( + entity_kg['edge'], 'RGCN') self.edge_idx = self.edge_idx.to(device) self.edge_type = self.edge_type.to(device) self.num_bases = opt.get('num_bases', 8) @@ -98,7 +99,8 @@ def __init__(self, opt, device, vocab, side_data): self.attention_dropout = opt.get('attention_dropout', 0.0) self.relu_dropout = opt.get('relu_dropout', 0.1) self.embeddings_scale = opt.get('embedding_scale', True) - self.learn_positional_embeddings = opt.get('learn_positional_embeddings', False) + self.learn_positional_embeddings = opt.get( + 'learn_positional_embeddings', False) self.reduction = opt.get('reduction', False) self.n_positions = opt.get('n_positions', 1024) self.longest_label = opt.get('longest_label', 1) @@ -118,13 +120,17 @@ def _build_embedding(self): torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False, padding_idx=self.pad_token_idx) else: - 
self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) - nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) - nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) + self.token_embedding = nn.Embedding( + self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, + mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_( + self.token_embedding.weight[self.pad_token_idx], 0) logger.debug('[Build embedding]') def _build_kg_layer(self): - self.kg_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) + self.kg_encoder = RGCNConv( + self.n_entity, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim) logger.debug('[Build kg layer]') @@ -134,7 +140,8 @@ def _build_recommendation_layer(self): logger.debug('[Build recommendation layer]') def _build_conversation_layer(self): - self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) + self.register_buffer('START', torch.tensor( + [self.start_token_idx], dtype=torch.long)) self.dialog_encoder = TransformerEncoder( self.n_heads, self.n_layers, @@ -175,7 +182,8 @@ def encode_user(self, entity_lists, kg_embedding): user_repr_list = [] for entity_list in entity_lists: if entity_list is None: - user_repr_list.append(torch.zeros(self.user_emb_dim, device=self.device)) + user_repr_list.append(torch.zeros( + self.user_emb_dim, device=self.device)) continue user_repr = kg_embedding[entity_list] user_repr = self.kg_attn(user_repr) @@ -201,7 +209,8 @@ def decode_forced(self, encoder_states, user_embedding, resp): inputs = torch.cat([self._starts(bsz), inputs], 1) latent, _ = self.decoder(inputs, encoder_states) token_logits = F.linear(latent, self.token_embedding.weight) - user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) + user_logits = self.user_proj_2(torch.relu( + self.user_proj_1(user_embedding))).unsqueeze(1) sum_logits = token_logits + user_logits _, preds = sum_logits.max(dim=-1) return sum_logits, preds @@ -213,16 +222,19 @@ def decode_greedy(self, encoder_states, user_embedding): incr_state = None logits = [] for i in range(self.longest_label): - scores, incr_state = self.decoder(xs, encoder_states, incr_state) # incr_state is always None + scores, incr_state = self.decoder( + xs, encoder_states, incr_state) # incr_state is always None scores = scores[:, -1:, :] token_logits = F.linear(scores, self.token_embedding.weight) - user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) + user_logits = self.user_proj_2(torch.relu( + self.user_proj_1(user_embedding))).unsqueeze(1) sum_logits = token_logits + user_logits probs, preds = sum_logits.max(dim=-1) logits.append(scores) xs = torch.cat([xs, preds], dim=1) # check if everyone has generated an end token - all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz + all_finished = ((xs == self.end_token_idx).sum( + dim=1) > 0).sum().item() == bsz if all_finished: break logits = torch.cat(logits, 1) @@ -240,7 +252,8 @@ def decode_beam_search(self, encoder_states, user_embedding, beam=4): for j in range(bsz): text = sequences[j][d][0] xs.append(text) - xs = torch.stack(xs).reshape(beam, bsz, -1) # (beam, batch_size, _) + xs = torch.stack(xs).reshape( + beam, bsz, -1) # (beam, batch_size, _) with torch.no_grad(): if i == 1: @@ -248,15 +261,18 @@ 
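In KBRD's decoders above, the vocabulary distribution at every step is the sum of the usual token logits and a user-specific bias projected from the KG-based user embedding. A minimal shape-level sketch of that fusion with random tensors; the dimensions are invented and the two-layer user projection of the original is collapsed into a single linear layer for brevity:

import torch
import torch.nn.functional as F

bsz, seq_len, dim, vocab = 2, 5, 8, 11
token_embedding = torch.randn(vocab, dim)
decoder_latent = torch.randn(bsz, seq_len, dim)  # stand-in for the transformer decoder output
user_embedding = torch.randn(bsz, dim)           # stand-in for the attended KG user vector
user_proj = torch.nn.Linear(dim, vocab)

token_logits = F.linear(decoder_latent, token_embedding)          # (bsz, seq_len, vocab)
user_logits = user_proj(torch.relu(user_embedding)).unsqueeze(1)  # (bsz, 1, vocab), broadcast over seq
sum_logits = token_logits + user_logits
print(sum_logits.shape)  # torch.Size([2, 5, 11])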
def decode_beam_search(self, encoder_states, user_embedding, beam=4): encoder_states = (encoder_states[0].repeat(beam, 1, 1), encoder_states[1].repeat(beam, 1, 1)) - scores, _ = self.decoder(xs.reshape(len(sequences[0]) * bsz, -1), encoder_states) + scores, _ = self.decoder(xs.reshape( + len(sequences[0]) * bsz, -1), encoder_states) scores = scores[:, -1:, :] token_logits = F.linear(scores, self.token_embedding.weight) - user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) + user_logits = self.user_proj_2(torch.relu( + self.user_proj_1(user_embedding))).unsqueeze(1) sum_logits = token_logits + user_logits logits = sum_logits.reshape(len(sequences[0]), bsz, 1, -1) scores = scores.reshape(len(sequences[0]), bsz, 1, -1) - logits = torch.nn.functional.softmax(logits) # turn into probabilities,in case of negative numbers + # turn into probabilities,in case of negative numbers + logits = torch.nn.functional.softmax(logits) probs, preds = logits.topk(beam, dim=-1) # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam @@ -269,15 +285,20 @@ def decode_beam_search(self, encoder_states, user_embedding, beam=4): if score == []: score_tmp = scores[n][j][0].unsqueeze(0) else: - score_tmp = torch.cat((score, scores[n][j][0].unsqueeze(0)), dim=0) - seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) - candidate = [seq_tmp, score_tmp, prob * probs[n][j][0][k]] + score_tmp = torch.cat( + (score, scores[n][j][0].unsqueeze(0)), dim=0) + seq_tmp = torch.cat( + (xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) + candidate = [seq_tmp, score_tmp, + prob * probs[n][j][0][k]] all_candidates.append(candidate) - ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True) + ordered = sorted( + all_candidates, key=lambda tup: tup[2], reverse=True) sequences[j] = ordered[:beam] # check if everyone has generated an end token - all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz + all_finished = ((xs == self.end_token_idx).sum( + dim=1) > 0).sum().item() == bsz if all_finished: break logits = torch.stack([seq[0][1] for seq in sequences]) @@ -292,7 +313,8 @@ def converse(self, batch, mode): encoder_state = self.dialog_encoder(context_tokens) if mode != 'test': self.longest_label = max(self.longest_label, response.shape[1]) - logits, preds = self.decode_forced(encoder_state, user_embedding, response) + logits, preds = self.decode_forced( + encoder_state, user_embedding, response) logits = logits.view(-1, logits.shape[-1]) labels = response.view(-1) return self.conv_loss(logits, labels), preds @@ -313,4 +335,4 @@ def freeze_parameters(self): freeze_models = [self.kg_encoder, self.kg_attn, self.rec_bias] for model in freeze_models: for p in model.parameters(): - p.requires_grad = False \ No newline at end of file + p.requires_grad = False diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 1230ec8..64b37be 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -28,15 +28,15 @@ import numpy as np import torch import torch.nn.functional as F -from loguru import logger -from torch import nn -from torch_geometric.nn import GCNConv, RGCNConv - from crslab.config import MODEL_PATH from crslab.model.base import BaseModel from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder +from loguru import logger +from 
torch import nn +from torch_geometric.nn import GCNConv, RGCNConv + from .modules import GateLayer, TransformerDecoderKG @@ -98,7 +98,8 @@ def __init__(self, opt, device, vocab, side_data): entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] - self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN') + self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format( + entity_edges, 'RGCN') self.entity_edge_idx = self.entity_edge_idx.to(device) self.entity_edge_type = self.entity_edge_type.to(device) word_edges = side_data['word_kg']['edge'] @@ -137,24 +138,32 @@ def _init_embeddings(self): torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False, padding_idx=self.pad_token_idx) else: - self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) - nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) - nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) - - self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx) - nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + self.token_embedding = nn.Embedding( + self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, + mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_( + self.token_embedding.weight[self.pad_token_idx], 0) + + self.word_kg_embedding = nn.Embedding( + self.n_word, self.kg_emb_dim, self.pad_word_idx) + nn.init.normal_(self.word_kg_embedding.weight, + mean=0, std=self.kg_emb_dim ** -0.5) nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0) logger.debug('[Finish init embeddings]') def _build_kg_layer(self): # db encoder - self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) - self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) + self.entity_encoder = RGCNConv( + self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) + self.entity_self_attn = SelfAttentionSeq( + self.kg_emb_dim, self.kg_emb_dim) # concept encoder self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim) - self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) + self.word_self_attn = SelfAttentionSeq( + self.kg_emb_dim, self.kg_emb_dim) # gate mechanism self.gate_layer = GateLayer(self.kg_emb_dim) @@ -175,7 +184,8 @@ def _build_recommendation_layer(self): logger.debug('[Finish build rec layer]') def _build_conversation_layer(self): - self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) + self.register_buffer('START', torch.tensor( + [self.start_token_idx], dtype=torch.long)) self.conv_encoder = TransformerEncoder( n_heads=self.n_heads, n_layers=self.n_layers, @@ -229,15 +239,19 @@ def pretrain_infomax(self, batch): if loss_mask.item() == 0: return None - entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder( + None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder( + self.word_kg_embedding.weight, self.word_edges) word_representations = word_graph_representations[words] word_padding_mask = words.eq(self.pad_word_idx) # (bs, seq_len) - word_attn_rep = 
self.word_self_attn(word_representations, word_padding_mask) + word_attn_rep = self.word_self_attn( + word_representations, word_padding_mask) word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) - info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) + info_predict = F.linear( + word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) loss = self.infomax_loss(info_predict, entity_labels) / loss_mask return loss @@ -249,20 +263,27 @@ def recommend(self, batch, mode): """ context_entities, context_words, entities, movie = batch - entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder( + None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder( + self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, word_len) + entity_padding_mask = context_entities.eq( + self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq( + self.pad_word_idx) # (bs, word_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn( + entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn( + word_representations, word_padding_mask) user_rep = self.gate_layer(entity_attn_rep, word_attn_rep) - rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) + rec_scores = F.linear( + user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) rec_loss = self.rec_loss(rec_scores, movie) @@ -273,7 +294,8 @@ def recommend(self, batch, mode): word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) - info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask + info_loss = self.infomax_loss( + info_predict, entities) / info_loss_mask return rec_loss, info_loss, rec_scores @@ -303,7 +325,8 @@ def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze( 0) # (bs, seq_len, vocab_size) - gen_logits = F.linear(dialog_latent, self.token_embedding.weight) # (bs, seq_len, vocab_size) + # (bs, seq_len, vocab_size) + gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1) return sum_logits, preds @@ -320,16 +343,19 @@ def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e dialog_latent = dialog_latent[:, -1:, :] # (bs, 1, dim) db_latent = entity_emb_attn.unsqueeze(1) concept_latent = word_emb_attn.unsqueeze(1) - copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) + copy_latent = self.copy_norm( + torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) - copy_logits = self.copy_output(copy_latent) * 
self.copy_mask.unsqueeze(0).unsqueeze(0) + copy_logits = self.copy_output( + copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1).long() logits.append(sum_logits) inputs = torch.cat((inputs, preds), dim=1) - finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size + finished = ((inputs == self.end_token_idx).sum( + dim=-1) > 0).sum().item() == batch_size if finished: break logits = torch.cat(logits, dim=1) @@ -360,7 +386,8 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at for j in range(batch_size): text = sequences[j][d][0] inputs.append(text) - inputs = torch.stack(inputs).reshape(beam, batch_size, -1) # (beam, batch_size, _) + inputs = torch.stack(inputs).reshape( + beam, batch_size, -1) # (beam, batch_size, _) with torch.no_grad(): dialog_latent, incr_state = self.conv_decoder( @@ -371,15 +398,19 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at dialog_latent = dialog_latent[:, -1:, :] # (bs, 1, dim) db_latent = entity_emb_attn.unsqueeze(1) concept_latent = word_emb_attn.unsqueeze(1) - copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) + copy_latent = self.copy_norm( + torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) - copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) - gen_logits = F.linear(dialog_latent, self.token_embedding.weight) + copy_logits = self.copy_output( + copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) + gen_logits = F.linear( + dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits logits = sum_logits.reshape(len(sequences[0]), batch_size, 1, -1) # turn into probabilities,in case of negative numbers - probs, preds = torch.nn.functional.softmax(logits).topk(beam, dim=-1) + probs, preds = torch.nn.functional.softmax( + logits).topk(beam, dim=-1) # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam @@ -392,15 +423,20 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at if logit == []: logit_tmp = logits[n][j][0].unsqueeze(0) else: - logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0) - seq_tmp = torch.cat((inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) - candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]] + logit_tmp = torch.cat( + (logit, logits[n][j][0].unsqueeze(0)), dim=0) + seq_tmp = torch.cat( + (inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) + candidate = [seq_tmp, logit_tmp, + prob * probs[n][j][0][k]] all_candidates.append(candidate) - ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True) + ordered = sorted( + all_candidates, key=lambda tup: tup[2], reverse=True) sequences[j] = ordered[:beam] # check if everyone has generated an end token - all_finished = ((inputs == self.end_token_idx).sum(dim=1) > 0).sum().item() == batch_size + all_finished = ((inputs == self.end_token_idx).sum( + dim=1) > 0).sum().item() == batch_size if all_finished: break logits = torch.stack([seq[0][1] for seq in sequences]) @@ -410,17 +446,23 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at def converse(self, batch, mode): context_tokens, context_entities, context_words, response = batch - entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, 
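KGSF's decoding above fuses two vocabulary distributions: generation logits taken from the token embedding and copy logits that are zeroed outside the precomputed copy mask. A small shape-level sketch with random tensors and invented dimensions:

import torch
import torch.nn.functional as F

bsz, dim, vocab = 2, 8, 11
token_embedding = torch.randn(vocab, dim)
copy_output = torch.nn.Linear(dim, vocab)
copy_mask = torch.zeros(vocab, dtype=torch.bool)
copy_mask[[3, 4, 7]] = True                      # only these tokens may be copied

dialog_latent = torch.randn(bsz, 1, dim)          # last decoder step
copy_latent = torch.randn(bsz, 1, dim)            # fused entity/word/dialog representation

gen_logits = F.linear(dialog_latent, token_embedding)
copy_logits = copy_output(copy_latent) * copy_mask.unsqueeze(0).unsqueeze(0)
preds = (gen_logits + copy_logits).argmax(dim=-1)
print(preds.shape)  # torch.Size([2, 1])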
self.entity_edge_type) - word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder( + None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder( + self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, seq_len) + entity_padding_mask = context_entities.eq( + self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq( + self.pad_word_idx) # (bs, seq_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn( + entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn( + word_representations, word_padding_mask) # encoder-decoder tokens_encoding = self.conv_encoder(context_tokens) @@ -447,8 +489,10 @@ def converse(self, batch, mode): def forward(self, batch, stage, mode): if len(self.gpu) >= 2: # forward function operates on different gpus, the weight of graph network need to be copied to other gpu - self.entity_edge_idx = self.entity_edge_idx.cuda(torch.cuda.current_device()) - self.entity_edge_type = self.entity_edge_type.cuda(torch.cuda.current_device()) + self.entity_edge_idx = self.entity_edge_idx.cuda( + torch.cuda.current_device()) + self.entity_edge_type = self.entity_edge_type.cuda( + torch.cuda.current_device()) self.word_edges = self.word_edges.cuda(torch.cuda.current_device()) self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), ).cuda(torch.cuda.current_device()) diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index ef85782..23c366a 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -23,16 +23,18 @@ import numpy as np import torch import torch.nn.functional as F -from loguru import logger -from torch import nn -from torch_geometric.nn import GCNConv, RGCNConv - from crslab.config import MODEL_PATH from crslab.model.base import BaseModel from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder -from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection +from loguru import logger +from torch import nn +from torch_geometric.nn import GCNConv, RGCNConv + +from .modules import (GateLayer, TransformerDecoderKG, + TransformerDecoderSelection) + class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -54,7 +56,7 @@ def __init__(self, opt, device, vocab, side_data): self.end_token_idx = vocab['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) - self.replace_token = opt.get('replace_token',None) + self.replace_token = opt.get('replace_token', None) self.replace_token_idx = vocab[self.replace_token] # kg self.n_word = vocab['n_word'] @@ -64,7 +66,8 @@ def __init__(self, opt, device, vocab, side_data): entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] - self.entity_edge_idx, 
self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN') + self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format( + entity_edges, 'RGCN') self.entity_edge_idx = self.entity_edge_idx.to(device) self.entity_edge_type = self.entity_edge_type.to(device) word_edges = side_data['word_kg']['edge'] @@ -85,17 +88,17 @@ def __init__(self, opt, device, vocab, side_data): self.reduction = opt['reduction'] self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) - # selector + # selector self.n_movies = opt['n_movies'] # self.n_movies_label = opt['n_movies_label'] - self.n_movies_label = 64362 # the number of entity2id + self.n_movies_label = 64362 # the number of entity2id # copy mask self.dataset = opt['dataset'] self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) # loss weight self.gen_loss_weight = opt['gen_loss_weight'] super(NTRDModel, self).__init__(opt, device, self.dpath) - + def build_model(self): self._init_embeddings() self._build_kg_layer() @@ -103,31 +106,39 @@ def build_model(self): self._build_recommendation_layer() self._build_conversation_layer() self._build_movie_selector() - + def _init_embeddings(self): if self.pretrained_embedding is not None: self.token_embedding = nn.Embedding.from_pretrained( torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False, padding_idx=self.pad_token_idx) else: - self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) - nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) - nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) - - self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx) - nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + self.token_embedding = nn.Embedding( + self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, + mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_( + self.token_embedding.weight[self.pad_token_idx], 0) + + self.word_kg_embedding = nn.Embedding( + self.n_word, self.kg_emb_dim, self.pad_word_idx) + nn.init.normal_(self.word_kg_embedding.weight, + mean=0, std=self.kg_emb_dim ** -0.5) nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0) logger.debug('[Finish init embeddings]') def _build_kg_layer(self): # db encoder - self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) - self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) + self.entity_encoder = RGCNConv( + self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) + self.entity_self_attn = SelfAttentionSeq( + self.kg_emb_dim, self.kg_emb_dim) # concept encoder self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim) - self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) + self.word_self_attn = SelfAttentionSeq( + self.kg_emb_dim, self.kg_emb_dim) # gate mechanism self.gate_layer = GateLayer(self.kg_emb_dim) @@ -148,7 +159,8 @@ def _build_recommendation_layer(self): logger.debug('[Finish build rec layer]') def _build_conversation_layer(self): - self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) + self.register_buffer('START', torch.tensor( + [self.start_token_idx], dtype=torch.long)) self.conv_encoder = TransformerEncoder( n_heads=self.n_heads, n_layers=self.n_layers, @@ -173,14 +185,14 @@ def _build_conversation_layer(self): self.copy_norm = 
nn.Linear(self.ffn_size * 3, self.token_emb_dim) self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size) - copy_mask = np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool) + copy_mask = np.load(os.path.join( + self.dpath, "copy_mask.npy")).astype(bool) if self.replace_token: if self.replace_token_idx < len(copy_mask): copy_mask[self.replace_token_idx] = False else: - copy_mask = np.insert(copy_mask,len(copy_mask),False) + copy_mask = np.insert(copy_mask, len(copy_mask), False) self.copy_mask = torch.as_tensor(copy_mask).to(self.device) - self.conv_decoder = TransformerDecoderKG( self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size, @@ -208,15 +220,19 @@ def pretrain_infomax(self, batch): if loss_mask.item() == 0: return None - entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder( + None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder( + self.word_kg_embedding.weight, self.word_edges) word_representations = word_graph_representations[words] word_padding_mask = words.eq(self.pad_word_idx) # (bs, seq_len) - word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) + word_attn_rep = self.word_self_attn( + word_representations, word_padding_mask) word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) - info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) + info_predict = F.linear( + word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) loss = self.infomax_loss(info_predict, entity_labels) / loss_mask return loss @@ -236,7 +252,8 @@ def _build_movie_selector(self): embeddings_scale=self.embeddings_scale, n_positions=self.n_positions, ) - self.matching_linear = nn.Linear(self.token_emb_dim,self.n_movies_label) + self.matching_linear = nn.Linear( + self.token_emb_dim, self.n_movies_label) self.sel_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx) def recommend(self, batch, mode): @@ -247,20 +264,27 @@ def recommend(self, batch, mode): """ context_entities, context_words, entities, movie = batch - entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder( + None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder( + self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, word_len) + entity_padding_mask = context_entities.eq( + self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq( + self.pad_word_idx) # (bs, word_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn( + entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn( + word_representations, word_padding_mask) user_rep = 
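NTRD's _build_conversation_layer above reuses KGSF's copy_mask.npy but has to make the [ITEM] slot non-copyable: if the slot id already falls inside the mask the entry is set to False, otherwise one False entry is appended. A toy numpy sketch of that adjustment (mask contents and slot id invented):

import numpy as np

copy_mask = np.array([True, False, True])  # stand-in for np.load('copy_mask.npy')
replace_token_idx = 3                       # id assigned to the appended '[ITEM]' token

if replace_token_idx < len(copy_mask):
    copy_mask[replace_token_idx] = False
else:
    copy_mask = np.insert(copy_mask, len(copy_mask), False)
print(copy_mask)  # [ True False  True False]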
self.gate_layer(entity_attn_rep, word_attn_rep) - rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) + rec_scores = F.linear( + user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) rec_loss = self.rec_loss(rec_scores, movie) @@ -271,7 +295,8 @@ def recommend(self, batch, mode): word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) - info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask + info_loss = self.infomax_loss( + info_predict, entities) / info_loss_mask return rec_loss, info_loss, rec_scores @@ -285,21 +310,27 @@ def freeze_parameters(self): def _starts(self, batch_size): """Return bsz start tokens.""" return self.START.detach().expand(batch_size, 1) - + def converse(self, batch, mode): context_tokens, context_entities, context_words, response, all_movies = batch - entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder( + None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder( + self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, seq_len) + entity_padding_mask = context_entities.eq( + self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq( + self.pad_word_idx) # (bs, seq_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn( + entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn( + word_representations, word_padding_mask) # encoder-decoder tokens_encoding = self.conv_encoder(context_tokens) @@ -309,47 +340,55 @@ def converse(self, batch, mode): conv_word_reps = self.conv_word_norm(word_representations) if mode != 'test': - logits, preds,latent = self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, - entity_padding_mask, - conv_word_reps, conv_word_emb, word_padding_mask, - response) + logits, preds, latent = self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, + entity_padding_mask, + conv_word_reps, conv_word_emb, word_padding_mask, + response) logits_ = logits.view(-1, logits.shape[-1]) response_ = response.view(-1) gen_loss = self.conv_loss(logits_, response_) - assert torch.sum(all_movies!=0, dim=(0,1)) == torch.sum((response == 30000), dim=(0,1)) #30000 means the idx of [ITEM] - masked_for_selection_token = (response == self.replace_token_idx) + assert torch.sum(all_movies != 0, dim=(0, 1)) == torch.sum( + (response == 30000), dim=(0, 1)) # 30000 means the idx of [ITEM] + masked_for_selection_token = (response == self.replace_token_idx) - matching_tensor,_ = self.movie_selector(latent,tokens_encoding,conv_word_reps,word_padding_mask) + matching_tensor, _ = self.movie_selector( + latent, tokens_encoding, conv_word_reps, word_padding_mask) matching_logits_ = self.matching_linear(matching_tensor) - 
matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) + matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze( + -1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) - all_movies = torch.masked_select(all_movies,(all_movies != 0)) - matching_logits = matching_logits.view(-1,matching_logits.shape[-1]) + all_movies = torch.masked_select(all_movies, (all_movies != 0)) + matching_logits = matching_logits.view(-1, + matching_logits.shape[-1]) all_movies = all_movies.view(-1) - selection_loss = self.sel_loss(matching_logits,all_movies) - return gen_loss,selection_loss, preds + selection_loss = self.sel_loss(matching_logits, all_movies) + return gen_loss, selection_loss, preds else: - logits, preds,latent = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, - entity_padding_mask, - conv_word_reps, conv_word_emb, word_padding_mask) - - preds_for_selection = preds[:, 1:] # skip the start_ind - masked_for_selection_token = (preds_for_selection == self.replace_token_idx) - - matching_tensor,_ = self.movie_selector(latent,tokens_encoding,conv_word_reps,word_padding_mask) + logits, preds, latent = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, + entity_padding_mask, + conv_word_reps, conv_word_emb, word_padding_mask) + + preds_for_selection = preds[:, 1:] # skip the start_ind + masked_for_selection_token = ( + preds_for_selection == self.replace_token_idx) + + matching_tensor, _ = self.movie_selector( + latent, tokens_encoding, conv_word_reps, word_padding_mask) matching_logits_ = self.matching_linear(matching_tensor) - matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) + matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze( + -1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) if matching_logits.shape[0] is not 0: - #W1: greedy - _, matching_pred = matching_logits.max(dim=-1) # [bsz * dynamic_movie_nums] + #W1: greedy + _, matching_pred = matching_logits.max( + dim=-1) # [bsz * dynamic_movie_nums] else: matching_pred = None - return preds,matching_pred,matching_logits_ - + return preds, matching_pred, matching_logits_ + def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask, word_reps, word_emb_attn, word_mask): batch_size = token_encoding[0].shape[0] @@ -364,16 +403,19 @@ def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e latents.append(dialog_latent) db_latent = entity_emb_attn.unsqueeze(1) concept_latent = word_emb_attn.unsqueeze(1) - copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) + copy_latent = self.copy_norm( + torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) - copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) + copy_logits = self.copy_output( + copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1).long() logits.append(sum_logits) inputs = torch.cat((inputs, preds), dim=1) - finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size + finished = ((inputs == self.end_token_idx).sum( 
+ dim=-1) > 0).sum().item() == batch_size if finished: break logits = torch.cat(logits, dim=1) @@ -388,7 +430,7 @@ def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e dialog_latent, _ = self.conv_decoder(inputs, token_encoding, word_reps, word_mask, entity_reps, entity_mask) # (bs, seq_len, dim) - + entity_latent = entity_emb_attn.unsqueeze(1).expand(-1, seq_len, -1) word_latent = word_emb_attn.unsqueeze(1).expand(-1, seq_len, -1) copy_latent = self.copy_norm( @@ -396,18 +438,19 @@ def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze( 0) # (bs, seq_len, vocab_size) - gen_logits = F.linear(dialog_latent, self.token_embedding.weight) # (bs, seq_len, vocab_size) + # (bs, seq_len, vocab_size) + gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1) return sum_logits, preds, dialog_latent - - def forward(self, batch, stage, mode): if len(self.gpu) >= 2: # forward function operates on different gpus, the weight of graph network need to be copied to other gpu - self.entity_edge_idx = self.entity_edge_idx.cuda(torch.cuda.current_device()) - self.entity_edge_type = self.entity_edge_type.cuda(torch.cuda.current_device()) + self.entity_edge_idx = self.entity_edge_idx.cuda( + torch.cuda.current_device()) + self.entity_edge_type = self.entity_edge_type.cuda( + torch.cuda.current_device()) self.word_edges = self.word_edges.cuda(torch.cuda.current_device()) self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), ).cuda(torch.cuda.current_device()) @@ -417,4 +460,4 @@ def forward(self, batch, stage, mode): loss = self.recommend(batch, mode) elif stage == "conv": loss = self.converse(batch, mode) - return loss \ No newline at end of file + return loss diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index f202dcb..6f436e7 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -15,9 +15,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence - from crslab.model.utils.functions import sort_for_packed_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence class HRNN(nn.Module): @@ -62,10 +61,12 @@ def get_utterance_encoding(self, context, utterance_lengths): """ batch_size, max_conv_length = context.shape[:2] utterance_lengths = utterance_lengths.reshape(-1) # (bs * conv_len) - sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(utterance_lengths) + sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence( + utterance_lengths) # reshape and reorder - sorted_utterances = context.view(batch_size * max_conv_length, -1).index_select(0, sorted_idx) + sorted_utterances = context.view( + batch_size * max_conv_length, -1).index_select(0, sorted_idx) # consider valid sequences only(length > 0) num_positive_lengths = torch.sum(utterance_lengths > 0) @@ -76,18 +77,21 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence(embedded, sorted_lengths.cpu(), batch_first=True) + packed_utterances = pack_padded_sequence( + embedded, sorted_lengths.cpu(), batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden 
states of the last layer (two directions of the GRU) - utterance_encoding = torch.cat((utterance_encoding[-1], utterance_encoding[-2]), 1) + utterance_encoding = torch.cat( + (utterance_encoding[-1], utterance_encoding[-2]), 1) if self.use_dropout: utterance_encoding = self.dropout(utterance_encoding) # complete the missing sequences (of length 0) if num_positive_lengths < batch_size * max_conv_length: pad_tensor = utterance_encoding.new_full( - (batch_size * max_conv_length - num_positive_lengths, 2 * self.utterance_encoder_hidden_size), + (batch_size * max_conv_length - num_positive_lengths, + 2 * self.utterance_encoder_hidden_size), self.pad_token_idx) utterance_encoding = torch.cat((utterance_encoding, pad_tensor), 0) @@ -104,12 +108,15 @@ def forward(self, context, utterance_lengths, dialog_lengths): :param dialog_lengths: (batch_size) :return context_state: (batch_size, context_encoder_hidden_size) """ - utterance_encoding = self.get_utterance_encoding(context, utterance_lengths) # (bs, conv_len, 2 * utt_dim) - sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(dialog_lengths) + utterance_encoding = self.get_utterance_encoding( + context, utterance_lengths) # (bs, conv_len, 2 * utt_dim) + sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence( + dialog_lengths) # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths.cpu(), batch_first=True) + packed_sequences = pack_padded_sequence( + sorted_representations, sorted_lengths.cpu(), batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -146,10 +153,13 @@ def forward(self, request, request_lengths, context_state): batch_size, max_utterance_length = request.shape # sort for pack - sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) + sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence( + request_lengths) sorted_request = request.index_select(0, sorted_idx) - embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) - packed_request = pack_padded_sequence(embedded_request, sorted_lengths.cpu(), batch_first=True) + # (batch_size, max_utterance_length, embed_dim) + embedded_request = self.embedding(sorted_request) + packed_request = pack_padded_sequence( + embedded_request, sorted_lengths.cpu(), batch_first=True) sorted_context_state = context_state.index_select(0, sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( @@ -165,7 +175,8 @@ def forward(self, request, request_lengths, context_state): (batch_size, max_utterance_length - max_request_length, decoder_hidden_size), self.pad_token_idx) sorted_vocab_state = torch.cat((sorted_vocab_state, pad_tensor), dim=1) # (batch_size, max_utterance_length, decoder_hidden_size) - sorted_language_output = self.out(sorted_vocab_state) # (batch_size, max_utterance_length, vocab_size) + # (batch_size, max_utterance_length, vocab_size) + sorted_language_output = self.out(sorted_vocab_state) # expand context to each time step expanded_sorted_context_state = sorted_context_state.unsqueeze(1).expand( @@ -174,12 +185,15 @@ def forward(self, request, request_lengths, context_state): # compute switch switch_input = torch.cat((expanded_sorted_context_state, sorted_vocab_state), dim=2) # (batch_size, max_utterance_length, context_size + decoder_hidden_size) - switch = self.switch(switch_input) # 
(batch_size, max_utterance_length, 1) + # (batch_size, max_utterance_length, 1) + switch = self.switch(switch_input) sorted_output = torch.cat(( - F.logsigmoid(switch) + F.log_softmax(sorted_language_output, dim=2), + F.logsigmoid(switch) + + F.log_softmax(sorted_language_output, dim=2), F.logsigmoid(-switch) # for item ), dim=2) - output = sorted_output.index_select(0, rev_idx) # (batch_size, max_utterance_length, vocab_size + 1) + # (batch_size, max_utterance_length, vocab_size + 1) + output = sorted_output.index_select(0, rev_idx) return output diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index a98975c..3995b52 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -23,20 +23,16 @@ """ -import os import torch +from crslab.model.base import BaseModel from torch.nn import CrossEntropyLoss from transformers import GPT2LMHeadModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class TGConvModel(BaseModel): """ - + Attributes: context_truncate: A integer indicating the length of dialogue context. response_truncate: A integer indicating the length of dialogue response. @@ -98,7 +94,8 @@ def generate(self, context): context = context[..., -self.response_truncate + 1:] for i in range(self.response_truncate - 1): - outputs = self.model(context, former_hidden_state) # (bs, c_t, v_s), + outputs = self.model( + context, former_hidden_state) # (bs, c_t, v_s), last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values next_token_logits = last_hidden_state[:, -1, :] # (bs, v_s) @@ -131,8 +128,10 @@ def generate_bs(self, context, beam=4): next_token_logits = last_hidden_state[:, -1, :] next_token_probs = torch.nn.functional.softmax(next_token_logits) topk = torch.topk(next_token_probs, beam, dim=-1) - probs = topk.values.reshape([batch_size, -1, beam]) # (bs, candidate, beam) - preds = topk.indices.reshape([batch_size, -1, beam]) # (bs, candidate, beam) + probs = topk.values.reshape( + [batch_size, -1, beam]) # (bs, candidate, beam) + preds = topk.indices.reshape( + [batch_size, -1, beam]) # (bs, candidate, beam) for j in range(batch_size): all_candidates = [] @@ -144,7 +143,8 @@ def generate_bs(self, context, beam=4): seq_tmp.append(preds[j][n][k]) candidate = [seq_tmp, prob * probs[j][n][k]] all_candidates.append(candidate) - ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True) + ordered = sorted( + all_candidates, key=lambda tup: tup[1], reverse=True) sequences[j] = ordered[:beam] res = [] diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 92fb0a0..37c21e8 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -23,16 +23,12 @@ """ -import os import torch +from crslab.model.base import BaseModel from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class TGPolicyModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -43,7 +39,7 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. 
- + """ self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) @@ -76,11 +72,13 @@ def forward(self, batch, mode): tp_mask).pooler_output # (bs, hidden_size) bs = user_profile.shape[0] // self.n_sent - profile_rep = self.profile_bert(user_profile, profile_mask).pooler_output # (bs, word_num, hidden) + profile_rep = self.profile_bert( + user_profile, profile_mask).pooler_output # (bs, word_num, hidden) profile_rep = profile_rep.view(bs, self.n_sent, -1) profile_rep = torch.mean(profile_rep, dim=1) # (bs, hidden) - state_rep = torch.cat((context_rep, topic_rep, profile_rep), dim=1) # [bs, hidden_size*3] + # [bs, hidden_size*3] + state_rep = torch.cat((context_rep, topic_rep, profile_rep), dim=1) topic_scores = self.state2topic_id(state_rep) topic_loss = self.loss(topic_scores, y) diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index 4f4d592..36ce1d4 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -23,22 +23,18 @@ """ -import os import torch +from crslab.model.base import BaseModel +from crslab.model.recommendation.sasrec.modules import SASRec from loguru import logger from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel -from crslab.model.recommendation.sasrec.modules import SASRec - class TGRecModel(BaseModel): """ - + Attributes: hidden_dropout_prob: A float indicating the dropout rate to dropout hidden state in SASRec. initializer_range: A float indicating the range of parameters initization in SASRec. @@ -98,7 +94,8 @@ def forward(self, batch, mode): bert_embed = self.bert(context, attention_mask=mask).pooler_output - sequence_output = self.SASREC(input_ids, input_mask) # bs, max_len, hidden_size2 + # bs, max_len, hidden_size2 + sequence_output = self.SASREC(input_ids, input_mask) sas_embed = sequence_output[:, -1, :] # bs, hidden_size2 embed = torch.cat((sas_embed, bert_embed), dim=1) diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 663656b..374014f 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -23,15 +23,11 @@ """ -import os +from crslab.model.base import BaseModel from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class ConvBERTModel(BaseModel): """ @@ -49,7 +45,7 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. 
- + """ self.topic_class_num = vocab['n_topic'] self.dpath = opt['policy_pretrained_path'] diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index 8b4c69a..acce39a 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -23,16 +23,12 @@ """ -import os import torch +from crslab.model.base import BaseModel from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class ProfileBERTModel(BaseModel): """ @@ -51,7 +47,7 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. - + """ self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index 7a9181b..cb47bc0 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -23,15 +23,11 @@ """ -import os +from crslab.model.base import BaseModel from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class TopicBERTModel(BaseModel): """ @@ -50,7 +46,7 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. 
- + """ self.topic_class_num = vocab['n_topic'] diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index 7e7eb80..24b6670 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -23,16 +23,12 @@ """ -import os +from crslab.model.base import BaseModel from loguru import logger from torch import nn from transformers import BertModel -from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map -from crslab.model.base import BaseModel - class BERTModel(BaseModel): """ diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 2199396..8627c46 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -8,8 +8,7 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -from crslab.config import Config -from crslab.data import get_dataset, get_dataloader +from crslab.data import get_dataloader, get_dataset from crslab.system import get_system @@ -34,12 +33,15 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data, task=None) + CRS_dataset = get_dataset( + config, config['tokenize'], restore_data, save_data, task=None) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab - train_dataloader = get_dataloader(config, CRS_dataset.train_data, vocab) - valid_dataloader = get_dataloader(config, CRS_dataset.valid_data, vocab) + train_dataloader = get_dataloader( + config, CRS_dataset.train_data, vocab) + valid_dataloader = get_dataloader( + config, CRS_dataset.valid_data, vocab) test_dataloader = get_dataloader(config, CRS_dataset.test_data, vocab) else: tokenized_dataset = {} @@ -53,7 +55,8 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: - dataset = get_dataset(config, tokenize, restore_data, save_data, task) + dataset = get_dataset( + config, tokenize, restore_data, save_data, task) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = dataset.valid_data @@ -61,9 +64,12 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r side_data[task] = dataset.side_data vocab[task] = dataset.vocab - train_dataloader[task] = get_dataloader(config, train_data, vocab[task]) - valid_dataloader[task] = get_dataloader(config, valid_data, vocab[task]) - test_dataloader[task] = get_dataloader(config, test_data, vocab[task]) + train_dataloader[task] = get_dataloader( + config, train_data, vocab[task]) + valid_dataloader[task] = get_dataloader( + config, valid_data, vocab[task]) + test_dataloader[task] = get_dataloader( + config, test_data, vocab[task]) # system CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) diff --git a/crslab/system/inspired.py b/crslab/system/inspired.py index 3fb77ed..5a5f9f1 100644 --- a/crslab/system/inspired.py +++ b/crslab/system/inspired.py @@ -2,15 +2,14 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -import torch -from loguru import logger from math import floor -from crslab.data import dataset_language_map +import torch from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import 
BaseSystem from crslab.system.utils.functions import ind2txt +from loguru import logger class InspiredSystem(BaseSystem): @@ -54,14 +53,14 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.conv_epoch = self.conv_optim_opt['epoch'] self.conv_batch_size = self.conv_optim_opt['batch_size'] if self.conv_optim_opt.get('lr_scheduler', None) and 'Transformers' in self.conv_optim_opt['lr_scheduler'][ - 'name']: + 'name']: batch_num = 0 for _ in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False): batch_num += 1 - conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) + conv_training_steps = self.conv_epoch * \ + floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps - def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() rec_predict = rec_predict[:, self.item_ids] @@ -158,7 +157,8 @@ def train_recommender(self): self.step(batch, stage='rec', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + \ + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index bb839a5..f5e2cf7 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -15,12 +15,11 @@ import os import torch -from loguru import logger - from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt +from loguru import logger class KGSFSystem(BaseSystem): @@ -85,9 +84,11 @@ def step(self, batch, stage, mode): if info_loss is not None: self.backward(info_loss.sum()) info_loss = info_loss.sum().item() - self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss)) + self.evaluator.optim_metrics.add( + "info_loss", AverageMetric(info_loss)) elif stage == 'rec': - rec_loss, info_loss, rec_predict = self.model.forward(batch, stage, mode) + rec_loss, info_loss, rec_predict = self.model.forward( + batch, stage, mode) if info_loss: loss = rec_loss + 0.025 * info_loss else: @@ -97,10 +98,12 @@ def step(self, batch, stage, mode): else: self.rec_evaluate(rec_predict, batch[-1]) rec_loss = rec_loss.sum().item() - self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss)) + self.evaluator.optim_metrics.add( + "rec_loss", AverageMetric(rec_loss)) if info_loss: info_loss = info_loss.sum().item() - self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss)) + self.evaluator.optim_metrics.add( + "info_loss", AverageMetric(info_loss)) elif stage == "conv": if mode != "test": gen_loss, pred = self.model.forward(batch, stage, mode) @@ -109,7 +112,8 @@ def step(self, batch, stage, mode): else: self.conv_evaluate(pred, batch[-1]) gen_loss = gen_loss.sum().item() - self.evaluator.optim_metrics.add("gen_loss", AverageMetric(gen_loss)) + self.evaluator.optim_metrics.add( + "gen_loss", AverageMetric(gen_loss)) self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss)) else: pred = self.model.forward(batch, stage, mode) @@ -145,7 +149,8 @@ def train_recommender(self): self.step(batch, stage='rec', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + 
self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + \ + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test diff --git a/crslab/system/redial.py b/crslab/system/redial.py index a7fa758..967cf14 100644 --- a/crslab/system/redial.py +++ b/crslab/system/redial.py @@ -8,13 +8,11 @@ # @email : wxl1999@foxmail.com import torch -from loguru import logger - -from crslab.data import dataset_language_map from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt +from loguru import logger class ReDialSystem(BaseSystem): @@ -51,7 +49,6 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.rec_batch_size = self.rec_optim_opt['batch_size'] self.conv_batch_size = self.conv_optim_opt['batch_size'] - def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() rec_predict = rec_predict[:, self.item_ids] @@ -86,7 +83,8 @@ def step(self, batch, stage, mode): else: self.rec_evaluate(rec_scores, batch['item']) rec_loss = rec_loss.item() - self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss)) + self.evaluator.optim_metrics.add( + "rec_loss", AverageMetric(rec_loss)) else: gen_loss, preds = self.conv_model.forward(batch, mode=mode) gen_loss = gen_loss.sum() @@ -95,7 +93,8 @@ def step(self, batch, stage, mode): else: self.conv_evaluate(preds, batch['response']) gen_loss = gen_loss.item() - self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss)) + self.evaluator.optim_metrics.add( + 'gen_loss', AverageMetric(gen_loss)) self.evaluator.gen_metrics.add('ppl', PPLMetric(gen_loss)) def train_recommender(self): @@ -107,14 +106,16 @@ def train_recommender(self): logger.info('[Train]') for batch in self.train_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size): self.step(batch, stage='rec', mode='train') - self.evaluator.report(epoch=epoch, mode='train') # report train loss + # report train loss + self.evaluator.report(epoch=epoch, mode='train') # val logger.info('[Valid]') with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.valid_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size, shuffle=False): self.step(batch, stage='rec', mode='valid') - self.evaluator.report(epoch=epoch, mode='valid') # report valid loss + # report valid loss + self.evaluator.report(epoch=epoch, mode='valid') # early stop metric = self.evaluator.optim_metrics['rec_loss'] if self.early_stop(metric): diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 96251c5..4149b2d 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -13,17 +13,16 @@ # @Email : txy20010310@163.com import os - -import torch -from loguru import logger from math import floor +import torch from crslab.config import PRETRAIN_PATH -from crslab.data import get_dataloader, dataset_language_map +from crslab.data import dataset_language_map, get_dataloader from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt +from loguru import logger class TGReDialSystem(BaseSystem): @@ -67,11 +66,12 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.conv_epoch = self.conv_optim_opt['epoch'] self.conv_batch_size = self.conv_optim_opt['batch_size'] if 
self.conv_optim_opt.get('lr_scheduler', None) and 'Transformers' in self.conv_optim_opt['lr_scheduler'][ - 'name']: + 'name']: batch_num = 0 for _ in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False): batch_num += 1 - conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) + conv_training_steps = self.conv_epoch * \ + floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps if hasattr(self, 'policy_model'): @@ -126,7 +126,8 @@ def step(self, batch, stage, mode): else: self.policy_model.eval() - policy_loss, policy_predict = self.policy_model.forward(batch, mode) + policy_loss, policy_predict = self.policy_model.forward( + batch, mode) if mode == "train" and policy_loss is not None: policy_loss = policy_loss.sum() self.backward(policy_loss) @@ -177,7 +178,8 @@ def train_recommender(self): elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: bert_param = list(self.rec_model.bert.named_parameters()) else: - bert_param = list(self.rec_model.module.bert.named_parameters()) + bert_param = list( + self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] else: bert_param = [] @@ -205,7 +207,8 @@ def train_recommender(self): self.step(batch, stage='rec', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + \ + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test @@ -279,7 +282,8 @@ def train_policy(self): self.step(batch, stage='policy', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + \ + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test @@ -314,7 +318,8 @@ def interact(self): for r in rank.tolist(): item_ids.append(self.item_ids[r]) first_item_id = item_ids[:1] - self.update_context('rec', entity_ids=first_item_id, item_ids=first_item_id) + self.update_context( + 'rec', entity_ids=first_item_id, item_ids=first_item_id) print(f"[Recommend]:") for item_id in item_ids: @@ -323,18 +328,22 @@ def interact(self): # conv if hasattr(self, 'conv_model'): conv_input = self.process_input(input_text, 'conv') - preds = self.conv_model.forward(conv_input, 'infer').tolist()[0] + preds = self.conv_model.forward( + conv_input, 'infer').tolist()[0] p_str = ind2txt(preds, self.ind2tok, self.end_token_idx) - token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(p_str, 'conv') - self.update_context('conv', token_ids, entity_ids, movie_ids, word_ids) + token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id( + p_str, 'conv') + self.update_context('conv', token_ids, + entity_ids, movie_ids, word_ids) print(f"[Response]:\n{p_str}") # input input_text = self.get_input(self.language) def process_input(self, input_text, stage): - token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(input_text, stage) + token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id( + input_text, stage) self.update_context(stage, token_ids, entity_ids, movie_ids, word_ids) data = {'role': 'Seeker', 'context_tokens': self.context[stage]['context_tokens'], @@ -349,7 +358,8 @@ def process_input(self, input_text, stage): elif stage == 'conv': data = 
dataloader.conv_interact(data) - data = [ele.to(self.device) if isinstance(ele, torch.Tensor) else ele for ele in data] + data = [ele.to(self.device) if isinstance( + ele, torch.Tensor) else ele for ele in data] return data def convert_to_id(self, text, stage): @@ -360,18 +370,23 @@ def convert_to_id(self, text, stage): else: raise - entities = self.link(tokens, self.side_data[stage]['entity_kg']['entity']) + entities = self.link( + tokens, self.side_data[stage]['entity_kg']['entity']) words = self.link(tokens, self.side_data[stage]['word_kg']['entity']) if self.opt['tokenize'][stage] in ('gpt2', 'bert'): language = dataset_language_map[self.opt['dataset']] - path = os.path.join(PRETRAIN_PATH, self.opt['tokenize'][stage], language) + path = os.path.join( + PRETRAIN_PATH, self.opt['tokenize'][stage], language) tokens = self.tokenize(text, 'bert', path) - token_ids = [self.vocab[stage]['tok2ind'].get(token, self.vocab[stage]['unk']) for token in tokens] + token_ids = [self.vocab[stage]['tok2ind'].get( + token, self.vocab[stage]['unk']) for token in tokens] entity_ids = [self.vocab[stage]['entity2id'][entity] for entity in entities if entity in self.vocab[stage]['entity2id']] - movie_ids = [entity_id for entity_id in entity_ids if entity_id in self.item_ids] - word_ids = [self.vocab[stage]['word2id'][word] for word in words if word in self.vocab[stage]['word2id']] + movie_ids = [ + entity_id for entity_id in entity_ids if entity_id in self.item_ids] + word_ids = [self.vocab[stage]['word2id'][word] + for word in words if word in self.vocab[stage]['word2id']] return token_ids, entity_ids, movie_ids, word_ids From 1eb81e439aa924fc1409f834954d502bc691b105 Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 7 Oct 2022 11:04:47 +0800 Subject: [PATCH 08/35] txy77 --- .vscode/settings.json | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index fe95ac3..e69de29 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +0,0 @@ -{ - "editor.formatOnPaste": true, - "editor.formatOnSave": true -} \ No newline at end of file From 0b0436e7157b45855259223f24d85e8d32c0630b Mon Sep 17 00:00:00 2001 From: txy <55396195+txy77@users.noreply.github.com> Date: Fri, 7 Oct 2022 11:07:04 +0800 Subject: [PATCH 09/35] Delete settings.json --- .vscode/settings.json | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index e69de29..0000000 From 735c3aa980e45644e54b6e31c894d0e6fe861f3a Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 11 Oct 2022 18:55:39 +0800 Subject: [PATCH 10/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 30 ++++++++++---------- crslab/data/dataset/gorecdial/gorecdial.py | 30 ++++++++++---------- crslab/data/dataset/inspired/inspired.py | 30 ++++++++++---------- crslab/data/dataset/opendialkg/opendialkg.py | 30 ++++++++++---------- crslab/data/dataset/redial/redial.py | 30 ++++++++++---------- crslab/data/dataset/tgredial/tgredial.py | 30 ++++++++++---------- 6 files changed, 90 insertions(+), 90 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index e94b8d9..3a8cbdf 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -119,37 +119,37 @@ def _load_raw_data(self): train_data = json.load(f) logger.debug( f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split token - 
processing_train_data = self.split_token(train_data) + # split text + processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processing_train_data) + tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processing_train_data) + self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug( f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_token - processing_valid_data = self.split_token(valid_data) + # split_text + processed_valid_data = self.split_text(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug( f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_token - processing_test_data = self.split_token(test_data) + # split_text + processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processing_train_data, processing_valid_data, processing_test_data + return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): self.tok2ind = json.load( @@ -323,7 +323,7 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_text(self, data): all_data = [] for each in tqdm(data): each_dict = {} @@ -375,11 +375,11 @@ def generate_tok2ind(self, processed_train_data): return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processing_train_data): + for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] @@ -404,10 +404,10 @@ def generate_copy_mask(self, tok2ind, processing_train_data): np.save(path, copy_mask) - def generate_word2vec(self, processing_train_data): + def generate_word2vec(self, processed_train_data): corpus = [] - for each_data in processing_train_data: + for each_data in processed_train_data: for dialog in each_data['dialog']: text = dialog['text'] corpus.append(text) diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 6c84031..27c1eee 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -120,37 +120,37 @@ def _load_raw_data(self): train_data = json.load(f) logger.debug( f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split token - processing_train_data = self.split_token(train_data) + # split text + processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processing_train_data) + tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processing_train_data) + 
self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug( f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_token - processing_valid_data = self.split_token(valid_data) + # split_text + processed_valid_data = self.split_text(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug( f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_token - processing_test_data = self.split_token(test_data) + # split_text + processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processing_train_data, processing_valid_data, processing_test_data + return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): self.tok2ind = json.load( @@ -329,7 +329,7 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_text(self, data): all_data = [] for each in tqdm(data): each_dict = {} @@ -384,13 +384,13 @@ def generate_tok2ind(self, processed_train_data): return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processing_train_data): + for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] @@ -418,10 +418,10 @@ def generate_copy_mask(self, tok2ind, processing_train_data): np.save(path, copy_mask) - def generate_word2vec(self, processing_train_data): + def generate_word2vec(self, processed_train_data): corpus = [] - for each_data in processing_train_data: + for each_data in processed_train_data: for dialog in each_data['dialog']: text = dialog['text'] corpus.append(text) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 860fc29..b8df35e 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -120,37 +120,37 @@ def _load_raw_data(self): train_data = json.load(f) logger.debug( f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split token - processing_train_data = self.split_token(train_data) + # split text + processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processing_train_data) + tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processing_train_data) + self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug( f"[Load 
valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_token - processing_valid_data = self.split_token(valid_data) + # split_text + processed_valid_data = self.split_text(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug( f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_token - processing_test_data = self.split_token(test_data) + # split_text + processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processing_train_data, processing_valid_data, processing_test_data + return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: @@ -328,7 +328,7 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_text(self, data): all_data = [] for each in tqdm(data): @@ -384,13 +384,13 @@ def generate_tok2ind(self, processed_train_data): return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processing_train_data): + for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] @@ -426,10 +426,10 @@ def generate_copy_mask(self, tok2ind, processing_train_data): np.save(path, copy_mask) - def generate_word2vec(self, processing_train_data): + def generate_word2vec(self, processed_train_data): corpus = [] - for each_data in processing_train_data: + for each_data in processed_train_data: for dialog in each_data['dialog']: text = dialog['text'] corpus.append(text) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 482cc89..c9daf14 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -121,37 +121,37 @@ def _load_raw_data(self): train_data = json.load(f) logger.debug( f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split token - processing_train_data = self.split_token(train_data) + # split text + processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processing_train_data) + tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processing_train_data) + self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug( f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_token - processing_valid_data = self.split_token(valid_data) + # split_text + processed_valid_data = self.split_text(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug( f"[Load 
test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_token - processing_test_data = self.split_token(test_data) + # split_text + processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processing_train_data, processing_valid_data, processing_test_data + return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): self.tok2ind = json.load( @@ -335,7 +335,7 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_text(self, data): all_data = [] for each in tqdm(data): @@ -392,13 +392,13 @@ def generate_tok2ind(self, processed_train_data): return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processing_train_data): + for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] @@ -426,10 +426,10 @@ def generate_copy_mask(self, tok2ind, processing_train_data): np.save(path, copy_mask) - def generate_word2vec(self, processing_train_data): + def generate_word2vec(self, processed_train_data): corpus = [] - for each_data in processing_train_data: + for each_data in processed_train_data: for dialog in each_data['dialog']: text = dialog['text'] corpus.append(text) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 5f28d70..15dbc9a 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -121,37 +121,37 @@ def _load_raw_data(self): train_data = json.load(f) logger.debug( f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split token - processing_train_data = self.split_token(train_data) + # split text + processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processing_train_data) + tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processing_train_data) + self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug( f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_token - processing_valid_data = self.split_token(valid_data) + # split_text + processed_valid_data = self.split_text(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug( f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_token - processing_test_data = self.split_token(test_data) + # split_text + processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processing_train_data, processing_valid_data, processing_test_data + return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): self.tok2ind = json.load( @@ -332,7 +332,7 @@ def 
_word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_text(self, data): all_data = [] for each in tqdm(data): @@ -388,13 +388,13 @@ def generate_tok2ind(self, processed_train_data): return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processing_train_data): + for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] @@ -421,10 +421,10 @@ def generate_copy_mask(self, tok2ind, processing_train_data): np.save(path, copy_mask) - def generate_word2vec(self, processing_train_data): + def generate_word2vec(self, processed_train_data): corpus = [] - for each_data in processing_train_data: + for each_data in processed_train_data: for dialog in each_data['dialog']: text = dialog['text'] corpus.append(text) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 05d450b..42493cc 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -138,37 +138,37 @@ def _load_raw_data(self): train_data = json.load(f) logger.debug( f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split token - processing_train_data = self.split_token(train_data) + # split text + processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processing_train_data) + tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processing_train_data) + self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processing_train_data) + copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) logger.debug( f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_token - processing_valid_data = self.split_token(valid_data) + # split_text + processed_valid_data = self.split_text(valid_data) logger.info("[Finish valid data split]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) logger.debug( f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_token - processing_test_data = self.split_token(test_data) + # split_text + processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processing_train_data, processing_valid_data, processing_test_data + return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): self.tok2ind = json.load( @@ -418,7 +418,7 @@ def _word_kg_process(self): 'entity': list(entities) } - def split_token(self, data): + def split_text(self, data): all_data = [] for each in tqdm(data): @@ -476,13 +476,13 @@ def generate_tok2ind(self, processed_train_data): return tok2ind - def generate_copy_mask(self, tok2ind, processing_train_data): + def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize crstokenize = self.crstokenizer copy_mask = 
np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processing_train_data): + for each_data in tqdm(processed_train_data): for dialog in each_data['messages']: match_list = [] text = dialog['text'] @@ -511,10 +511,10 @@ def generate_copy_mask(self, tok2ind, processing_train_data): np.save(path, copy_mask) - def generate_word2vec(self, processing_train_data): + def generate_word2vec(self, processed_train_data): corpus = [] - for each_data in processing_train_data: + for each_data in processed_train_data: for dialog in each_data['messages']: text = dialog['text'] corpus.append(text) From d775117b81bc34aad381e0399749ab44b12f93f3 Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 11 Oct 2022 19:12:06 +0800 Subject: [PATCH 11/35] txy77 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 05950a5..c133145 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 -gensim +gensim~=4.2.0 From 7f43abe94a1168e528e0c5b2208b37c59e09a047 Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 11 Oct 2022 19:14:58 +0800 Subject: [PATCH 12/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 6 +++--- crslab/data/dataset/gorecdial/gorecdial.py | 6 +++--- crslab/data/dataset/inspired/inspired.py | 6 +++--- crslab/data/dataset/opendialkg/opendialkg.py | 6 +++--- crslab/data/dataset/redial/redial.py | 6 +++--- crslab/data/dataset/tgredial/tgredial.py | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 3a8cbdf..65e9d64 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -91,7 +91,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): if task_tokenize_path in opt: self.tokenize_path = opt[task_tokenize_path] self.tokenize_class = globals()[tokenize + '_tokenize'] - self.crstokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) @@ -331,7 +331,7 @@ def split_text(self, data): for one in each['dialog']: str_text = one['text'] tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text each_data.append(one) @@ -377,7 +377,7 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 27c1eee..bd15016 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -91,7 +91,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): if task_tokenize_path in opt: self.tokenize_path = opt[task_tokenize_path] self.tokenize_class = globals()[tokenize + '_tokenize'] - self.crstokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, 
save) @@ -337,7 +337,7 @@ def split_text(self, data): for one in each['dialog']: str_text = one['text'] tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text each_data.append(one) @@ -387,7 +387,7 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index b8df35e..f8723a6 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -91,7 +91,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): if task_tokenize_path in opt: self.tokenize_path = opt[task_tokenize_path] self.tokenize_class = globals()[tokenize + '_tokenize'] - self.crstokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) @@ -337,7 +337,7 @@ def split_text(self, data): for one in each['dialog']: str_text = one['text'] tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text each_data.append(one) @@ -387,7 +387,7 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index c9daf14..9be5841 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): if task_tokenize_path in opt: self.tokenize_path = opt[task_tokenize_path] self.tokenize_class = globals()[tokenize + '_tokenize'] - self.crstokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) @@ -344,7 +344,7 @@ def split_text(self, data): for one in each['dialog']: str_text = one['text'] tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text each_data.append(one) @@ -395,7 +395,7 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 15dbc9a..a30a485 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): if task_tokenize_path in opt: self.tokenize_path = opt[task_tokenize_path] self.tokenize_class = globals()[tokenize + 
'_tokenize'] - self.crstokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) @@ -341,7 +341,7 @@ def split_text(self, data): for one in each['dialog']: str_text = one['text'] tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text each_data.append(one) @@ -391,7 +391,7 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 42493cc..d4b29de 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -97,7 +97,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): if task_tokenize_path in opt: self.tokenize_path = opt[task_tokenize_path] self.tokenize_class = globals()[tokenize + '_tokenize'] - self.crstokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = self.tokenize_class(self.tokenize_path) dpath = os.path.join(DATASET_PATH, 'tgredial') self.replace_token = opt.get('replace_token', None) @@ -428,7 +428,7 @@ def split_text(self, data): for one in each['messages']: str_text = one['text'] tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text each_data.append(one) @@ -479,7 +479,7 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): tokenizer = self.tokenize - crstokenize = self.crstokenizer + crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): From 36e120d734326afcefa5a59bee39fe104cd00d52 Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 11 Oct 2022 19:29:35 +0800 Subject: [PATCH 13/35] txy77 --- crslab/model/crs/inspired/inspired_conv.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 19e58ad..af91a61 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -67,10 +67,7 @@ def converse(self, batch, mode): past = None lm_logits_all = [] - - GPT2_Config = GPT2Config.from_pretrained(self.dpath) - - support_up_limits = GPT2_Config.n_positions + support_up_limits = self.model_sk.config.n_positions if mode != 'test': for turn, iter in enumerate(input_ids_iters): From d245105e293c9d55b8f527a7e5564ad07266bd00 Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 11 Oct 2022 20:34:25 +0800 Subject: [PATCH 14/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 4 ++-- crslab/data/dataset/gorecdial/gorecdial.py | 4 ++-- crslab/data/dataset/inspired/inspired.py | 4 ++-- crslab/data/dataset/opendialkg/opendialkg.py | 4 ++-- crslab/data/dataset/redial/redial.py | 4 ++-- crslab/data/dataset/tgredial/tgredial.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 65e9d64..b029d5e 100644 --- 
a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -398,11 +398,11 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial', 'copy_mask.npy') + path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') if not os.path.exists(path): os.makedirs(path) - np.save(path, copy_mask) + np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) def generate_word2vec(self, processed_train_data): diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index bd15016..5ebef85 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -412,11 +412,11 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial', 'copy_mask.npy') + path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') if not os.path.exists(path): os.makedirs(path) - np.save(path, copy_mask) + np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) def generate_word2vec(self, processed_train_data): diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index f8723a6..29aaaa1 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -420,11 +420,11 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired', 'copy_mask.npy') + path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') if not os.path.exists(path): os.makedirs(path) - np.save(path, copy_mask) + np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) def generate_word2vec(self, processed_train_data): diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 9be5841..173a10d 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -420,11 +420,11 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG', 'copy_mask.npy') + path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') if not os.path.exists(path): os.makedirs(path) - np.save(path, copy_mask) + np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) def generate_word2vec(self, processed_train_data): diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index a30a485..80a79cf 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -415,11 +415,11 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial', 'copy_mask.npy') + path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') if not os.path.exists(path): os.makedirs(path) - np.save(path, copy_mask) + np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) def generate_word2vec(self, processed_train_data): diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index d4b29de..972a362 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -505,11 +505,11 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = 
tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial', 'copy_mask.npy') + path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') if not os.path.exists(path): os.makedirs(path) - np.save(path, copy_mask) + np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) def generate_word2vec(self, processed_train_data): From 0fcd3d351942a316520f91bd454f2e6958a4b9fe Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 11 Oct 2022 20:47:59 +0800 Subject: [PATCH 15/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 11 +++++++++-- crslab/data/dataset/gorecdial/gorecdial.py | 11 +++++++++-- crslab/data/dataset/inspired/inspired.py | 11 +++++++++-- crslab/data/dataset/opendialkg/opendialkg.py | 11 +++++++++-- crslab/data/dataset/redial/redial.py | 11 +++++++++-- crslab/data/dataset/tgredial/tgredial.py | 11 +++++++++-- 6 files changed, 54 insertions(+), 12 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index b029d5e..99535ae 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -81,6 +81,12 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.copy = True else: self.copy = False + + if 'embedding' in opt: + self.generate_embedding = True + else: + self.generate_embedding = False + resource = resources['resource'] token = resource[tokenize] self.special_token_idx = token['special_token_idx'] @@ -126,8 +132,9 @@ def _load_raw_data(self): tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') + if self.generate_embedding: + self.generate_word2vec(processed_train_data) + logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 5ebef85..8dd8a2c 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -81,6 +81,12 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.copy = True else: self.copy = False + + if 'embedding' in opt: + self.generate_embedding = True + else: + self.generate_embedding = False + resource = resources['resource'] token = resource[tokenize] self.special_token_idx = token['special_token_idx'] @@ -127,8 +133,9 @@ def _load_raw_data(self): tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') + if self.generate_embedding: + self.generate_word2vec(processed_train_data) + logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 29aaaa1..011ee5b 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -81,6 +81,12 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.copy = True else: self.copy = False + + if 'embedding' in opt: + self.generate_embedding = True + else: + self.generate_embedding = False + resource = resources['resource'] token = resource[tokenize] 
self.special_token_idx = token['special_token_idx'] @@ -127,8 +133,9 @@ def _load_raw_data(self): tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') + if self.generate_embedding: + self.generate_word2vec(processed_train_data) + logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 173a10d..8f7ce72 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -82,6 +82,12 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.copy = True else: self.copy = False + + if 'embedding' in opt: + self.generate_embedding = True + else: + self.generate_embedding = False + resource = resources['resource'] token = resource[tokenize] self.special_token_idx = token['special_token_idx'] @@ -128,8 +134,9 @@ def _load_raw_data(self): tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') + if self. generate_embedding: + self.generate_word2vec(processed_train_data) + logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 80a79cf..e62c438 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -82,6 +82,12 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.copy = True else: self.copy = False + + if 'embedding' in opt: + self.generate_embedding = True + else: + self.generate_embedding = False + resource = resources['resource'] token = resource[tokenize] self.special_token_idx = token['special_token_idx'] @@ -128,8 +134,9 @@ def _load_raw_data(self): tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') + if self.generate_embedding: + self.generate_word2vec(processed_train_data) + logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 972a362..3e4bd8d 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -85,6 +85,12 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.copy = True else: self.copy = False + + if 'embedding' in opt: + self.generate_embedding = True + else: + self.generate_embedding = False + resource = resources['resource'] token = resource[tokenize] self.special_token_idx = token['special_token_idx'] @@ -145,8 +151,9 @@ def _load_raw_data(self): tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') + if self.generate_embedding: + self.generate_word2vec(processed_train_data) + logger.info('[Finish 
generate word2vec]') # build copy_mask if self.copy: copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) From 53427eae64e9d9ba8ba9ec5b0790b88ded410a9b Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 00:07:58 +0800 Subject: [PATCH 16/35] txy77 --- crslab/data/__init__.py | 20 ++++++++++++++++++-- crslab/data/dataset/durecdial/durecdial.py | 16 ++-------------- crslab/data/dataset/gorecdial/gorecdial.py | 16 ++-------------- crslab/data/dataset/inspired/inspired.py | 16 ++-------------- crslab/data/dataset/opendialkg/opendialkg.py | 16 ++-------------- crslab/data/dataset/redial/redial.py | 15 ++------------- crslab/data/dataset/tgredial/tgredial.py | 16 ++-------------- crslab/data/dataset/tokenizer/__init__.py | 6 ++++++ crslab/quick_start/quick_start.py | 12 +++++++++--- 9 files changed, 45 insertions(+), 88 deletions(-) create mode 100644 crslab/data/dataset/tokenizer/__init__.py diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index cca2126..31272ab 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -22,6 +22,15 @@ from crslab.data.dataloader import * from crslab.data.dataset import * +from crslab.data.dataset.tokenizer import * + +tokenizer_register_table = { + 'nltk': nltk_tokenize, + 'jieba': jieba_tokenize, + 'gpt2': gpt2_tokenize, + 'bert': bert_tokenize, + 'pkuseg': pkuseg_tokenize +} dataset_register_table = { 'ReDial': ReDialDataset, @@ -70,7 +79,14 @@ } -def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: +def get_tokenizer(tokenize, path=None) -> BaseCrsTokenize: + """ + get tokenizer from opt + """ + return tokenizer_register_table[tokenize](path) + + +def get_dataset(opt, tokenize, CRS_Tokenizer, restore, save) -> BaseDataset: """get and process dataset Args: @@ -85,7 +101,7 @@ def get_dataset(opt, tokenize, restore, save, task=None) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return dataset_register_table[dataset](opt, tokenize, restore, save, task) + return dataset_register_table[dataset](opt, tokenize, CRS_Tokenizer, restore, save) else: raise NotImplementedError( f'The dataloader [{dataset}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 99535ae..037b06f 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -31,11 +31,6 @@ import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from crslab.data.dataset.tokenizer.bert import bert_tokenize -from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize -from crslab.data.dataset.tokenizer.jieba import jieba_tokenize -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize -from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize from loguru import logger from tqdm import tqdm @@ -67,7 +62,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False, task=None): + def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): """ Args: @@ -92,12 +87,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - task_tokenize_path = str(task) + '_tokenize_path' - self.tokenize_path = None - if task_tokenize_path in opt: - self.tokenize_path = opt[task_tokenize_path] - self.tokenize_class = 
globals()[tokenize + '_tokenize'] - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) @@ -337,7 +327,6 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - tokenizer = self.tokenize crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text @@ -383,7 +372,6 @@ def generate_tok2ind(self, processed_train_data): return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): - tokenizer = self.tokenize crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 8dd8a2c..cd53d56 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -31,11 +31,6 @@ import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from crslab.data.dataset.tokenizer.bert import bert_tokenize -from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize -from crslab.data.dataset.tokenizer.jieba import jieba_tokenize -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize -from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize from loguru import logger from tqdm import tqdm @@ -67,7 +62,7 @@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False, task=None): + def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -92,12 +87,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - task_tokenize_path = str(task) + '_tokenize_path' - self.tokenize_path = None - if task_tokenize_path in opt: - self.tokenize_path = opt[task_tokenize_path] - self.tokenize_class = globals()[tokenize + '_tokenize'] - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) @@ -343,7 +333,6 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - tokenizer = self.tokenize crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text @@ -393,7 +382,6 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - tokenizer = self.tokenize crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 011ee5b..8b65b78 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -31,11 +31,6 @@ import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from crslab.data.dataset.tokenizer.bert import bert_tokenize -from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize -from crslab.data.dataset.tokenizer.jieba import jieba_tokenize -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize -from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize from 
loguru import logger from tqdm import tqdm @@ -67,7 +62,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False, task=None): + def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -92,12 +87,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - task_tokenize_path = str(task) + '_tokenize_path' - self.tokenize_path = None - if task_tokenize_path in opt: - self.tokenize_path = opt[task_tokenize_path] - self.tokenize_class = globals()[tokenize + '_tokenize'] - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) @@ -343,7 +333,6 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - tokenizer = self.tokenize crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text @@ -393,7 +382,6 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - tokenizer = self.tokenize crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8f7ce72..5a94fd6 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -32,11 +32,6 @@ import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from crslab.data.dataset.tokenizer.bert import bert_tokenize -from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize -from crslab.data.dataset.tokenizer.jieba import jieba_tokenize -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize -from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize from loguru import logger from tqdm import tqdm @@ -68,7 +63,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False, task=None): + def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. 
Args: @@ -93,12 +88,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - task_tokenize_path = str(task) + '_tokenize_path' - self.tokenize_path = None - if task_tokenize_path in opt: - self.tokenize_path = opt[task_tokenize_path] - self.tokenize_class = globals()[tokenize + '_tokenize'] - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) @@ -350,7 +340,6 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - tokenizer = self.tokenize crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text @@ -401,7 +390,6 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - tokenizer = self.tokenize crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index e62c438..1dcf819 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -32,11 +32,6 @@ import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from crslab.data.dataset.tokenizer.bert import bert_tokenize -from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize -from crslab.data.dataset.tokenizer.jieba import jieba_tokenize -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize -from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize from loguru import logger from tqdm import tqdm @@ -68,7 +63,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False, task=None): + def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. 
Args: @@ -93,11 +88,6 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - task_tokenize_path = str(task) + '_tokenize_path' - self.tokenize_path = None - if task_tokenize_path in opt: - self.tokenize_path = opt[task_tokenize_path] - self.tokenize_class = globals()[tokenize + '_tokenize'] - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) @@ -347,7 +338,6 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - tokenizer = self.tokenize crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text @@ -397,7 +387,6 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - tokenizer = self.tokenize crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 3e4bd8d..5da64d1 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -32,11 +32,6 @@ import numpy as np from crslab.config import DATASET_PATH, MODEL_PATH from crslab.data.dataset.base import BaseDataset -from crslab.data.dataset.tokenizer.bert import bert_tokenize -from crslab.data.dataset.tokenizer.gpt2 import gpt2_tokenize -from crslab.data.dataset.tokenizer.jieba import jieba_tokenize -from crslab.data.dataset.tokenizer.nltk import nltk_tokenize -from crslab.data.dataset.tokenizer.pkuseg import pkuseg_tokenize from loguru import logger from tqdm import tqdm @@ -71,7 +66,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, restore=False, save=False, task=None): + def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. 
Args: @@ -98,12 +93,7 @@ def __init__(self, opt, tokenize, restore=False, save=False, task=None): self.pad_topic_idx = self.special_token_idx['pad_topic'] self.tokenize = tokenize - task_tokenize_path = str(task) + '_tokenize_path' - self.tokenize_path = None - if task_tokenize_path in opt: - self.tokenize_path = opt[task_tokenize_path] - self.tokenize_class = globals()[tokenize + '_tokenize'] - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, 'tgredial') self.replace_token = opt.get('replace_token', None) @@ -434,7 +424,6 @@ def split_text(self, data): each_dict['conv_id'] = each['conv_id'] for one in each['messages']: str_text = one['text'] - tokenizer = self.tokenize crstokenize = self.Tokenizer list_text = crstokenize.tokenize(str_text) one['text'] = list_text @@ -485,7 +474,6 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - tokenizer = self.tokenize crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) diff --git a/crslab/data/dataset/tokenizer/__init__.py b/crslab/data/dataset/tokenizer/__init__.py new file mode 100644 index 0000000..3b7ef7a --- /dev/null +++ b/crslab/data/dataset/tokenizer/__init__.py @@ -0,0 +1,6 @@ +from .base import BaseCrsTokenize +from .bert import bert_tokenize +from .gpt2 import gpt2_tokenize +from .jieba import jieba_tokenize +from .nltk import nltk_tokenize +from .pkuseg import pkuseg_tokenize diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 8627c46..62a1cfc 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -8,7 +8,7 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -from crslab.data import get_dataloader, get_dataset +from crslab.data import get_dataloader, get_dataset, get_tokenizer from crslab.system import get_system @@ -33,8 +33,9 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): + CRS_Tokenizer = get_tokenizer(config['tokenize'], path=None) CRS_dataset = get_dataset( - config, config['tokenize'], restore_data, save_data, task=None) + config, config['tokenize'], CRS_Tokenizer, restore_data, save_data) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab @@ -55,8 +56,13 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: + task_tokenize_path = str(task) + '_tokenize_path' + tokenize_path = None + if task_tokenize_path in config: + tokenize_path = config[task_tokenize_path] + CRS_Tokenizer = get_tokenizer(tokenize, tokenize_path) dataset = get_dataset( - config, tokenize, restore_data, save_data, task) + config, tokenize, CRS_Tokenizer, restore_data, save_data) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = dataset.valid_data From 8ac3b83d73b9f46027aae2aab84d0b6b312615fd Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 12:47:31 +0800 Subject: [PATCH 17/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 13 +++---------- crslab/data/dataset/gorecdial/gorecdial.py | 13 +++---------- crslab/data/dataset/inspired/inspired.py | 13 +++---------- crslab/data/dataset/opendialkg/opendialkg.py | 14 +++----------- crslab/data/dataset/redial/redial.py | 13 +++---------- crslab/data/dataset/tgredial/tgredial.py | 13 +++---------- 6 files changed, 18 
insertions(+), 61 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 037b06f..d2a2218 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -119,7 +119,7 @@ def _load_raw_data(self): processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processed_train_data) + self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec if self.generate_embedding: @@ -127,7 +127,7 @@ def _load_raw_data(self): logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) + copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: @@ -149,12 +149,10 @@ def _load_raw_data(self): return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): - self.tok2ind = json.load( - open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + f"[Load vocab from token2id]") logger.debug( f"[The size of token2index dictionary is {len(self.tok2ind)}]") logger.debug( @@ -364,11 +362,6 @@ def generate_tok2ind(self, processed_train_data): tok2ind['_split_'] = cnt cnt += 1 - tok2ind_path = os.path.join(DATASET_PATH, 'durecdial', 'token2id.json') - with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, - indent=4, separators=(',', ':')) - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index cd53d56..e8e6671 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -120,7 +120,7 @@ def _load_raw_data(self): processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processed_train_data) + self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec if self.generate_embedding: @@ -128,7 +128,7 @@ def _load_raw_data(self): logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) + copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: @@ -150,12 +150,10 @@ def _load_raw_data(self): return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): - self.tok2ind = json.load( - open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + f"[Load vocab from token2id]") logger.debug( f"[The size of token2index dictionary is {len(self.tok2ind)}]") logger.debug( @@ -373,11 +371,6 @@ def generate_tok2ind(self, processed_train_data): tok2ind['_split_'] = cnt cnt += 1 - 
tok2ind_path = os.path.join(DATASET_PATH, 'gorecdial', 'token2id.json') - with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, - indent=4, separators=(',', ':')) - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 8b65b78..098a85b 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -120,7 +120,7 @@ def _load_raw_data(self): processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processed_train_data) + self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec if self.generate_embedding: @@ -128,7 +128,7 @@ def _load_raw_data(self): logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) + copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: @@ -150,12 +150,10 @@ def _load_raw_data(self): return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): - with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: - self.tok2ind = json.load(f) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + f"[Load vocab from token2id]") logger.debug( f"[The size of token2index dictionary is {len(self.tok2ind)}]") logger.debug( @@ -373,11 +371,6 @@ def generate_tok2ind(self, processed_train_data): tok2ind['_split_'] = cnt cnt += 1 - tok2ind_path = os.path.join(DATASET_PATH, 'inspired', 'token2id.json') - with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, - indent=4, separators=(',', ':')) - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 5a94fd6..8078421 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -121,7 +121,7 @@ def _load_raw_data(self): processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processed_train_data) + self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec if self. 
generate_embedding: @@ -129,7 +129,7 @@ def _load_raw_data(self): logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) + copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: @@ -151,12 +151,10 @@ def _load_raw_data(self): return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): - self.tok2ind = json.load( - open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + f"[Load vocab from token2id]") logger.debug( f"[The size of token2index dictionary is {len(self.tok2ind)}]") logger.debug( @@ -380,12 +378,6 @@ def generate_tok2ind(self, processed_train_data): tok2ind['_split_'] = cnt cnt += 1 - tok2ind_path = os.path.join( - DATASET_PATH, 'opendialkg', 'token2id.json') - with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, - indent=4, separators=(',', ':')) - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 1dcf819..84c6e90 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -122,7 +122,7 @@ def _load_raw_data(self): processed_train_data = self.split_text(train_data) logger.info("[Finish train data split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processed_train_data) + self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec if self.generate_embedding: @@ -130,7 +130,7 @@ def _load_raw_data(self): logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) + copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: @@ -152,12 +152,10 @@ def _load_raw_data(self): return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): - self.tok2ind = json.load( - open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + f"[Load vocab from token2id]") logger.debug( f"[The size of token2index dictionary is {len(self.tok2ind)}]") logger.debug( @@ -378,11 +376,6 @@ def generate_tok2ind(self, processed_train_data): tok2ind['_split_'] = cnt cnt += 1 - tok2ind_path = os.path.join(DATASET_PATH, 'redial', 'token2id.json') - with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, - indent=4, separators=(',', ':')) - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 5da64d1..b357732 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -138,7 +138,7 @@ def _load_raw_data(self): processed_train_data = self.split_text(train_data) logger.info("[Finish train data 
split]") # generate tok2ind - tok2ind = self.generate_tok2ind(processed_train_data) + self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec if self.generate_embedding: @@ -146,7 +146,7 @@ def _load_raw_data(self): logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: - copy_mask = self.generate_copy_mask(tok2ind, processed_train_data) + copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: @@ -168,8 +168,6 @@ def _load_raw_data(self): return processed_train_data, processed_valid_data, processed_test_data def _load_vocab(self): - self.tok2ind = json.load( - open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} # add special tokens if self.replace_token: @@ -184,7 +182,7 @@ def _load_vocab(self): self.special_token_idx[self.replace_token] = len( self.tok2ind)-1 logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + f"[Load vocab from token2id]") logger.debug( f"[The size of token2index dictionary is {len(self.tok2ind)}]") logger.debug( @@ -465,11 +463,6 @@ def generate_tok2ind(self, processed_train_data): tok2ind['_split_'] = cnt cnt += 1 - tok2ind_path = os.path.join(DATASET_PATH, 'tgredial', 'token2id.json') - with open(tok2ind_path, 'w', encoding='utf-8') as write: - json.dump(tok2ind, write, ensure_ascii=False, - indent=4, separators=(',', ':')) - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): From 2a529d4687064662c7f2d201c66821e40b8bff2e Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 13:31:55 +0800 Subject: [PATCH 18/35] txy77 --- config/crs/kgsf/durecdial.yaml | 2 +- config/crs/kgsf/gorecdial.yaml | 2 +- config/crs/kgsf/inspired.yaml | 2 +- config/crs/kgsf/opendialkg.yaml | 2 +- config/crs/kgsf/redial.yaml | 2 +- config/crs/kgsf/tgredial.yaml | 2 +- config/crs/ntrd/tgredial.yaml | 2 +- crslab/data/dataset/base.py | 5 ++-- crslab/data/dataset/durecdial/durecdial.py | 21 ++++++++-------- crslab/data/dataset/gorecdial/gorecdial.py | 21 ++++++++-------- crslab/data/dataset/inspired/inspired.py | 20 ++++++++-------- crslab/data/dataset/opendialkg/opendialkg.py | 25 ++++++++++---------- crslab/data/dataset/redial/redial.py | 21 ++++++++-------- crslab/data/dataset/tgredial/tgredial.py | 21 ++++++++-------- crslab/model/crs/kgsf/kgsf.py | 5 +--- crslab/model/crs/ntrd/ntrd.py | 6 ++--- 16 files changed, 78 insertions(+), 81 deletions(-) diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index 9ad0a9d..bd97fe1 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -1,7 +1,7 @@ # dataset dataset: DuRecDial tokenize: jieba -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index ab00260..8c90308 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -1,7 +1,7 @@ # dataset dataset: GoRecDial tokenize: nltk -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index c3608e5..79f9f22 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -1,7 +1,7 @@ # dataset 
dataset: Inspired tokenize: nltk -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index 09d47c3..a3ff91e 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -1,7 +1,7 @@ # dataset dataset: OpenDialKG tokenize: nltk -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index 5d11ca1..26b1b11 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -1,7 +1,7 @@ # dataset dataset: ReDial tokenize: nltk -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index 33b2e1a..6718d19 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: pkuseg -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 49c7940..7193b0b 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: pkuseg -embedding: word2vec.npy +embedding: True # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/data/dataset/base.py b/crslab/data/dataset/base.py index 6befff5..080caf7 100644 --- a/crslab/data/dataset/base.py +++ b/crslab/data/dataset/base.py @@ -12,9 +12,8 @@ from abc import ABC, abstractmethod import numpy as np -from loguru import logger - from crslab.download import build +from loguru import logger class BaseDataset(ABC): @@ -52,7 +51,7 @@ def __init__(self, opt, dpath, resource, restore=False, save=False): test_data) embedding = opt.get('embedding', None) if embedding: - self.side_data["embedding"] = np.load(os.path.join(self.dpath, embedding)) + self.side_data["embedding"] = self.vocab['word2vec'] logger.debug(f'[Load pretrained embedding {embedding}]') logger.info('[Finish data preprocess]') else: diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index d2a2218..0560daf 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data = self._load_raw_data() + train_data, valid_data, test_data, npy_dict = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,6 +105,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, + 'word2vec': npy_dict['word2vec'], + 'copy_mask': npy_dict['copy_mask'], } vocab.update(self.special_token_idx) @@ -122,10 +124,12 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec + wordembedding = None if self.generate_embedding: - self.generate_word2vec(processed_train_data) + wordembedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask + copy_mask = None if self.copy: copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish 
generate copy_mask]') @@ -146,7 +150,9 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processed_train_data, processed_valid_data, processed_test_data + npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + + return processed_train_data, processed_valid_data, processed_test_data, npy_dict def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} @@ -386,11 +392,7 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'DuRecDial') - if not os.path.exists(path): - os.makedirs(path) - - np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) + return copy_mask def generate_word2vec(self, processed_train_data): @@ -412,5 +414,4 @@ def generate_word2vec(self, processed_train_data): word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - word2vec_path = os.path.join(DATASET_PATH, 'durecdial', 'word2vec.npy') - np.save(word2vec_path, word2embedding) + return word2embedding diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index e8e6671..98628ac 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data = self._load_raw_data() + train_data, valid_data, test_data, npy_dict = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,6 +105,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, + 'word2vec': npy_dict['word2vec'], + 'copy_mask': npy_dict['copy_mask'], } vocab.update(self.special_token_idx) @@ -123,10 +125,12 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec + wordembedding = None if self.generate_embedding: - self.generate_word2vec(processed_train_data) + wordembedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask + copy_mask = None if self.copy: copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') @@ -147,7 +151,9 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processed_train_data, processed_valid_data, processed_test_data + npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + + return processed_train_data, processed_valid_data, processed_test_data, npy_dict def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} @@ -400,12 +406,7 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'GoRecDial') - if not os.path.exists(path): - os.makedirs(path) - - np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) + return copy_mask def generate_word2vec(self, processed_train_data): corpus = [] @@ -429,5 +429,4 @@ def generate_word2vec(self, processed_train_data): word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - word2vec_path = os.path.join(DATASET_PATH, 'gorecdial', 'word2vec.npy') - np.save(word2vec_path, 
word2embedding) + return word2embedding diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 098a85b..73f4517 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data = self._load_raw_data() + train_data, valid_data, test_data, npy_dict = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,6 +105,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, + 'word2vec': npy_dict['word2vec'], + 'copy_mask': npy_dict['copy_mask'], } vocab.update(self.special_token_idx) @@ -123,8 +125,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec + wordembedding = None if self.generate_embedding: - self.generate_word2vec(processed_train_data) + wordembedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: @@ -147,7 +150,9 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processed_train_data, processed_valid_data, processed_test_data + npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + + return processed_train_data, processed_valid_data, processed_test_data, npy_dict def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} @@ -408,11 +413,7 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'Inspired') - if not os.path.exists(path): - os.makedirs(path) - - np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) + return copy_mask def generate_word2vec(self, processed_train_data): @@ -437,5 +438,4 @@ def generate_word2vec(self, processed_train_data): word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - word2vec_path = os.path.join(DATASET_PATH, 'inspired', 'word2vec.npy') - np.save(word2vec_path, word2embedding) + return word2embedding diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 8078421..00d75e0 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -27,6 +27,7 @@ import os from collections import defaultdict from copy import copy +from http.client import NotConnected import gensim import numpy as np @@ -93,7 +94,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data = self._load_raw_data() + train_data, valid_data, test_data, npy_dict = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -106,6 +107,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, + 'word2vec': npy_dict['word2vec'], + 'copy_mask': npy_dict['copy_mask'], } vocab.update(self.special_token_idx) @@ -124,10 +127,12 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - if self. 
generate_embedding: - self.generate_word2vec(processed_train_data) + wordembedding = None + if self.generate_embedding: + wordembedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask + copy_mask = None if self.copy: copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') @@ -148,7 +153,9 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processed_train_data, processed_valid_data, processed_test_data + npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + + return processed_train_data, processed_valid_data, processed_test_data, npy_dict def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} @@ -407,11 +414,7 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'OpenDialKG') - if not os.path.exists(path): - os.makedirs(path) - - np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) + return copy_mask def generate_word2vec(self, processed_train_data): @@ -436,6 +439,4 @@ def generate_word2vec(self, processed_train_data): word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - word2vec_path = os.path.join( - DATASET_PATH, 'opendialkg', 'word2vec.npy') - np.save(word2vec_path, word2embedding) + return word2embedding diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 84c6e90..bc8f781 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -94,7 +94,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data = self._load_raw_data() + train_data, valid_data, test_data, npy_dict = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -107,6 +107,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, + 'word2vec': npy_dict['word2vec'], + 'copy_mask': npy_dict['copy_mask'], } vocab.update(self.special_token_idx) @@ -125,10 +127,12 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec + wordembedding = None if self.generate_embedding: - self.generate_word2vec(processed_train_data) + wordembedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask + copy_mask = None if self.copy: copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') @@ -149,7 +153,9 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processed_train_data, processed_valid_data, processed_test_data + npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + + return processed_train_data, processed_valid_data, processed_test_data, npy_dict def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} @@ -404,11 +410,7 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'ReDial') - if not os.path.exists(path): - os.makedirs(path) - - np.save(os.path.join(path, 
'copy_mask.npy'), copy_mask) + return copy_mask def generate_word2vec(self, processed_train_data): @@ -433,5 +435,4 @@ def generate_word2vec(self, processed_train_data): word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - word2vec_path = os.path.join(DATASET_PATH, 'redial', 'word2vec.npy') - np.save(word2vec_path, word2embedding) + return word2embedding diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index b357732..e728980 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -107,7 +107,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.side_data["embedding"]), self.side_data['embedding'][0], axis=0) def _load_data(self): - train_data, valid_data, test_data = self._load_raw_data() + train_data, valid_data, test_data, npy_dict = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -123,6 +123,8 @@ def _load_data(self): 'n_topic': len(self.topic2ind) + 1, 'n_entity': self.n_entity, 'n_word': self.n_word, + 'word2vec': npy_dict['word2vec'], + 'copy_mask': npy_dict['copy_mask'], } vocab.update(self.special_token_idx) @@ -141,10 +143,12 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec + wordembedding = None if self.generate_embedding: - self.generate_word2vec(processed_train_data) + wordembedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask + copy_mask = None if self.copy: copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') @@ -165,7 +169,9 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - return processed_train_data, processed_valid_data, processed_test_data + npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + + return processed_train_data, processed_valid_data, processed_test_data, npy_dict def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} @@ -493,11 +499,7 @@ def generate_copy_mask(self, tok2ind, processed_train_data): token_id = tok2ind[each_word] copy_mask[token_id] = True - path = os.path.join(MODEL_PATH, 'kgsf', 'TGReDial') - if not os.path.exists(path): - os.makedirs(path) - - np.save(os.path.join(path, 'copy_mask.npy'), copy_mask) + return copy_mask def generate_word2vec(self, processed_train_data): @@ -522,5 +524,4 @@ def generate_word2vec(self, processed_train_data): word2embedding = [[0] * 300] * 4 + [model.wv[word] for word in word2index] - word2vec_path = os.path.join(DATASET_PATH, 'tgredial', 'word2vec.npy') - np.save(word2vec_path, word2embedding) + return word2embedding diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 64b37be..8d45456 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -90,6 +90,7 @@ def __init__(self, opt, device, vocab, side_data): self.end_token_idx = vocab['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) + self.copy_mask = vocab['copy_mask'] # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] @@ -210,8 +211,6 @@ def _build_conversation_layer(self): self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim) self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size) - self.copy_mask = 
torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), - ).to(self.device) self.conv_decoder = TransformerDecoderKG( self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size, @@ -494,8 +493,6 @@ def forward(self, batch, stage, mode): self.entity_edge_type = self.entity_edge_type.cuda( torch.cuda.current_device()) self.word_edges = self.word_edges.cuda(torch.cuda.current_device()) - self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), - ).cuda(torch.cuda.current_device()) if stage == "pretrain": loss = self.pretrain_infomax(batch) elif stage == "rec": diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 23c366a..1a16e80 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -58,6 +58,7 @@ def __init__(self, opt, device, vocab, side_data): self.pretrained_embedding = side_data.get('embedding', None) self.replace_token = opt.get('replace_token', None) self.replace_token_idx = vocab[self.replace_token] + self.copy_mask = vocab['copy_mask'] # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] @@ -185,8 +186,6 @@ def _build_conversation_layer(self): self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim) self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size) - copy_mask = np.load(os.path.join( - self.dpath, "copy_mask.npy")).astype(bool) if self.replace_token: if self.replace_token_idx < len(copy_mask): copy_mask[self.replace_token_idx] = False @@ -452,8 +451,7 @@ def forward(self, batch, stage, mode): self.entity_edge_type = self.entity_edge_type.cuda( torch.cuda.current_device()) self.word_edges = self.word_edges.cuda(torch.cuda.current_device()) - self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), - ).cuda(torch.cuda.current_device()) + if stage == "pretrain": loss = self.pretrain_infomax(batch) elif stage == "rec": From 4117166db0034f1a9f4573d2720ae3f34692c04b Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 13:41:11 +0800 Subject: [PATCH 19/35] txy77 --- crslab/data/dataset/redial/redial.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index bc8f781..6b7836e 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -88,8 +88,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.special_token_idx = token['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.tokenize_class = CRS_Tokenizer - self.Tokenizer = self.tokenize_class(self.tokenize_path) + self.Tokenizer = CRS_Tokenizer dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) From 27075526ef25ee4599e353c1b677c261dd4639f8 Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 19:48:55 +0800 Subject: [PATCH 20/35] txy77 --- crslab/model/crs/kgsf/kgsf.py | 2 +- crslab/model/crs/ntrd/ntrd.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 8d45456..f6fdb77 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -90,7 +90,7 @@ def __init__(self, opt, device, vocab, side_data): self.end_token_idx = vocab['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) - 
self.copy_mask = vocab['copy_mask'] + self.copy_mask = torch.as_tensor(vocab['copy_mask'].astype(bool)).to(self.device) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 1a16e80..dc980b7 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -58,7 +58,7 @@ def __init__(self, opt, device, vocab, side_data): self.pretrained_embedding = side_data.get('embedding', None) self.replace_token = opt.get('replace_token', None) self.replace_token_idx = vocab[self.replace_token] - self.copy_mask = vocab['copy_mask'] + self.copy_mask = torch.as_tensor(vocab['copy_mask'].astype(bool)).to(self.device) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] From f92d4859b6a62313629b1582f9e32f0537f216a6 Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 20:31:11 +0800 Subject: [PATCH 21/35] txy77 --- crslab/data/dataloader/inspired.py | 30 +++++++------- crslab/data/dataloader/kbrd.py | 14 +++---- crslab/data/dataloader/kgsf.py | 16 ++++---- crslab/data/dataloader/ntrd.py | 19 ++++----- crslab/data/dataloader/redial.py | 13 +++--- crslab/data/dataloader/tgredial.py | 40 +++++++++---------- crslab/data/dataset/durecdial/durecdial.py | 7 ++-- crslab/data/dataset/durecdial/resources.py | 38 ------------------ crslab/data/dataset/gorecdial/gorecdial.py | 5 +-- crslab/data/dataset/gorecdial/resources.py | 37 ----------------- crslab/data/dataset/inspired/inspired.py | 5 +-- crslab/data/dataset/inspired/resources.py | 34 ---------------- crslab/data/dataset/opendialkg/opendialkg.py | 5 +-- crslab/data/dataset/opendialkg/resources.py | 34 ---------------- crslab/data/dataset/redial/redial.py | 5 +-- crslab/data/dataset/redial/resources.py | 34 ---------------- crslab/data/dataset/tgredial/resources.py | 39 ------------------ crslab/data/dataset/tgredial/tgredial.py | 5 +-- crslab/data/dataset/tokenizer/bert.py | 13 +++++- crslab/data/dataset/tokenizer/gpt2.py | 15 ++++++- crslab/data/dataset/tokenizer/jieba.py | 8 ++++ crslab/data/dataset/tokenizer/nltk.py | 11 ++++- crslab/data/dataset/tokenizer/pkuseg.py | 9 +++++ crslab/model/conversation/gpt2/gpt2.py | 2 +- .../conversation/transformer/transformer.py | 18 ++++----- crslab/model/crs/inspired/inspired_conv.py | 2 +- crslab/model/crs/kbrd/kbrd.py | 6 +-- crslab/model/crs/kgsf/kgsf.py | 10 ++--- crslab/model/crs/ntrd/ntrd.py | 10 ++--- crslab/model/crs/redial/redial_conv.py | 10 ++--- crslab/model/crs/redial/redial_rec.py | 3 +- crslab/model/crs/tgredial/tg_conv.py | 2 +- crslab/model/policy/pmi/pmi.py | 3 +- crslab/system/kbrd.py | 5 +-- crslab/system/kgsf.py | 2 +- crslab/system/ntrd.py | 11 +++-- 36 files changed, 172 insertions(+), 348 deletions(-) diff --git a/crslab/data/dataloader/inspired.py b/crslab/data/dataloader/inspired.py index 3881113..a5983e6 100644 --- a/crslab/data/dataloader/inspired.py +++ b/crslab/data/dataloader/inspired.py @@ -5,10 +5,10 @@ from copy import deepcopy import torch -from tqdm import tqdm - from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt +from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt, + padded_tensor, truncate) +from tqdm import tqdm class InspiredDataLoader(BaseDataLoader): @@ -56,20 +56,20 @@ def __init__(self, opt, dataset, vocab): super().__init__(opt, dataset) self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['pad'] - 
self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.unk_token_idx = vocab['unk'] - self.conv_bos_id = vocab['start'] - self.cls_id = vocab['start'] - self.sep_id = vocab['end'] - if 'sent_split' in vocab: - self.sent_split_idx = vocab['sent_split'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.unk_token_idx = vocab['special_token_idx']['unk'] + self.conv_bos_id = vocab['special_token_idx']['start'] + self.cls_id = vocab['special_token_idx']['start'] + self.sep_id = vocab['special_token_idx']['end'] + if 'sent_split' in vocab['special_token_idx']: + self.sent_split_idx = vocab['special_token_idx']['sent_split'] else: - self.sent_split_idx = vocab['end'] + self.sent_split_idx = vocab['special_token_idx']['end'] - self.pad_entity_idx = vocab['pad_entity'] - self.pad_word_idx = vocab['pad_word'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] self.tok2ind = vocab['tok2ind'] self.ind2tok = vocab['ind2tok'] diff --git a/crslab/data/dataloader/kbrd.py b/crslab/data/dataloader/kbrd.py index 2720c5d..ee90381 100644 --- a/crslab/data/dataloader/kbrd.py +++ b/crslab/data/dataloader/kbrd.py @@ -8,10 +8,10 @@ # @Email : wxl1999@foxmail.com import torch -from tqdm import tqdm - from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt +from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt, + padded_tensor, truncate) +from tqdm import tqdm class KBRDDataLoader(BaseDataLoader): @@ -45,10 +45,10 @@ def __init__(self, opt, dataset, vocab): """ super().__init__(opt, dataset) - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.pad_entity_idx = vocab['pad_entity'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] self.context_truncate = opt.get('context_truncate', None) self.response_truncate = opt.get('response_truncate', None) self.entity_truncate = opt.get('entity_truncate', None) diff --git a/crslab/data/dataloader/kgsf.py b/crslab/data/dataloader/kgsf.py index 6bbcac4..c43e933 100644 --- a/crslab/data/dataloader/kgsf.py +++ b/crslab/data/dataloader/kgsf.py @@ -10,10 +10,10 @@ from copy import deepcopy import torch -from tqdm import tqdm - from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, get_onehot, truncate, merge_utt +from crslab.data.dataloader.utils import (add_start_end_token_idx, get_onehot, + merge_utt, padded_tensor, truncate) +from tqdm import tqdm class KGSFDataLoader(BaseDataLoader): @@ -52,11 +52,11 @@ def __init__(self, opt, dataset, vocab): """ super().__init__(opt, dataset) self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.pad_entity_idx = vocab['pad_entity'] - self.pad_word_idx = vocab['pad_word'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_entity_idx = 
vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] self.context_truncate = opt.get('context_truncate', None) self.response_truncate = opt.get('response_truncate', None) self.entity_truncate = opt.get('entity_truncate', None) diff --git a/crslab/data/dataloader/ntrd.py b/crslab/data/dataloader/ntrd.py index bbf1e80..603be05 100644 --- a/crslab/data/dataloader/ntrd.py +++ b/crslab/data/dataloader/ntrd.py @@ -5,10 +5,11 @@ from copy import deepcopy import torch -from tqdm import tqdm - from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import add_start_end_token_idx, merge_utt_replace, padded_tensor, get_onehot, truncate, merge_utt +from crslab.data.dataloader.utils import (add_start_end_token_idx, get_onehot, + merge_utt, merge_utt_replace, + padded_tensor, truncate) +from tqdm import tqdm class NTRDDataLoader(BaseDataLoader): @@ -23,11 +24,11 @@ def __init__(self, opt, dataset, vocab): """ super().__init__(opt, dataset) self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.pad_entity_idx = vocab['pad_entity'] - self.pad_word_idx = vocab['pad_word'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] self.context_truncate = opt.get('context_truncate', None) self.response_truncate = opt.get('response_truncate', None) self.entity_truncate = opt.get('entity_truncate', None) @@ -113,4 +114,4 @@ def conv_batchify(self, batch): padded_tensor(batch_all_movies, self.pad_entity_idx, pad_tail=False)) def policy_batchify(self, *args, **kwargs): - pass \ No newline at end of file + pass diff --git a/crslab/data/dataloader/redial.py b/crslab/data/dataloader/redial.py index 6cd1289..84c47b6 100644 --- a/crslab/data/dataloader/redial.py +++ b/crslab/data/dataloader/redial.py @@ -11,10 +11,9 @@ from copy import copy import torch -from tqdm import tqdm - from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import padded_tensor, get_onehot, truncate +from crslab.data.dataloader.utils import get_onehot, padded_tensor, truncate +from tqdm import tqdm movie_pattern = re.compile(r'^@\d{5,6}$') @@ -55,10 +54,10 @@ def __init__(self, opt, dataset, vocab): super().__init__(opt, dataset) self.ind2tok = vocab['ind2tok'] self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.unk_token_idx = vocab['unk'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.unk_token_idx = vocab['special_token_idx']['unk'] self.item_token_idx = vocab['vocab_size'] self.conversation_truncate = self.opt.get('conversation_truncate', None) self.utterance_truncate = self.opt.get('utterance_truncate', None) diff --git a/crslab/data/dataloader/tgredial.py b/crslab/data/dataloader/tgredial.py index bef2ca7..e8457ec 100644 --- a/crslab/data/dataloader/tgredial.py +++ b/crslab/data/dataloader/tgredial.py @@ -11,10 +11,10 @@ from copy import deepcopy import torch -from tqdm import tqdm - from crslab.data.dataloader.base import BaseDataLoader -from 
crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt +from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt, + padded_tensor, truncate) +from tqdm import tqdm class TGReDialDataLoader(BaseDataLoader): @@ -65,26 +65,26 @@ def __init__(self, opt, dataset, vocab): self.n_entity = vocab['n_entity'] self.item_size = self.n_entity - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.unk_token_idx = vocab['unk'] - self.conv_bos_id = vocab['start'] - self.cls_id = vocab['start'] - self.sep_id = vocab['end'] - if 'sent_split' in vocab: - self.sent_split_idx = vocab['sent_split'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.unk_token_idx = vocab['special_token_idx']['unk'] + self.conv_bos_id = vocab['special_token_idx']['start'] + self.cls_id = vocab['special_token_idx']['start'] + self.sep_id = vocab['special_token_idx']['end'] + if 'sent_split' in vocab['special_token_idx']: + self.sent_split_idx = vocab['special_token_idx']['sent_split'] else: - self.sent_split_idx = vocab['end'] - if 'word_split' in vocab: - self.word_split_idx = vocab['word_split'] + self.sent_split_idx = vocab['special_token_idx']['end'] + if 'word_split' in vocab['special_token_idx']: + self.word_split_idx = vocab['special_token_idx']['word_split'] else: - self.word_split_idx = vocab['end'] + self.word_split_idx = vocab['special_token_idx']['end'] - self.pad_entity_idx = vocab['pad_entity'] - self.pad_word_idx = vocab['pad_word'] - if 'pad_topic' in vocab: - self.pad_topic_idx = vocab['pad_topic'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] + if 'pad_topic' in vocab['special_token_idx']: + self.pad_topic_idx = vocab['special_token_idx']['pad_topic'] self.tok2ind = vocab['tok2ind'] self.ind2tok = vocab['ind2tok'] diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 0560daf..7d4b98a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -83,8 +83,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - token = resource[tokenize] - self.special_token_idx = token['special_token_idx'] + self.special_token_idx = CRS_Tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize self.Tokenizer = CRS_Tokenizer @@ -107,9 +106,9 @@ def _load_data(self): 'n_word': self.n_word, 'word2vec': npy_dict['word2vec'], 'copy_mask': npy_dict['copy_mask'], + 'special_token_idx': self.special_token_idx, } - vocab.update(self.special_token_idx) - + return train_data, valid_data, test_data, vocab def _load_raw_data(self): diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index 9cf4f27..c226269 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -23,43 +23,5 @@ 'durecdial.zip', '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', ), - 'jieba': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'bert': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 
100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } }, } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 98628ac..2cd7740 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -83,8 +83,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - token = resource[tokenize] - self.special_token_idx = token['special_token_idx'] + self.special_token_idx = CRS_Tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize self.Tokenizer = CRS_Tokenizer @@ -107,8 +106,8 @@ def _load_data(self): 'n_word': self.n_word, 'word2vec': npy_dict['word2vec'], 'copy_mask': npy_dict['copy_mask'], + 'special_token_idx': self.special_token_idx, } - vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index e286ba2..57c8614 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -23,42 +23,5 @@ 'gorecdial.zip', '66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', ), - 'nltk': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'bert': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - }, - 'gpt2': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - }, }, - } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 73f4517..634dbb3 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -83,8 +83,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - token = resource[tokenize] - self.special_token_idx = token['special_token_idx'] + self.special_token_idx = CRS_Tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize self.Tokenizer = CRS_Tokenizer @@ -107,8 +106,8 @@ def _load_data(self): 'n_word': self.n_word, 'word2vec': npy_dict['word2vec'], 'copy_mask': npy_dict['copy_mask'], + 'special_token_idx': self.special_token_idx, } - vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index 77f915e..38fb3be 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -23,39 +23,5 @@ 'inspired.zip', '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', ), - 'nltk': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'bert': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 
100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 00d75e0..2fe385c 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -85,8 +85,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - token = resource[tokenize] - self.special_token_idx = token['special_token_idx'] + self.special_token_idx = CRS_Tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize self.Tokenizer = CRS_Tokenizer @@ -109,8 +108,8 @@ def _load_data(self): 'n_word': self.n_word, 'word2vec': npy_dict['word2vec'], 'copy_mask': npy_dict['copy_mask'], + 'special_token_idx': self.special_token_idx } - vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index 9ff0eed..2fc13db 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -23,39 +23,5 @@ 'opendialkg.zip', '73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', ), - 'nltk': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'bert': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } }, } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 6b7836e..fe8993f 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -84,8 +84,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - token = resource[tokenize] - self.special_token_idx = token['special_token_idx'] + self.special_token_idx = CRS_Tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize self.Tokenizer = CRS_Tokenizer @@ -108,8 +107,8 @@ def _load_data(self): 'n_word': self.n_word, 'word2vec': npy_dict['word2vec'], 'copy_mask': npy_dict['copy_mask'], + 'special_token_idx': self.special_token_idx } - vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index 9cd1e27..9809a6e 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -23,39 +23,5 @@ 'redial.zip', '9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', ), - 'nltk': { - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0 - }, - }, - 'bert': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - }, - }, - 'gpt2': { - 'special_token_idx': { 
- 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0 - }, - } }, } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 92506f7..721afae 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -23,44 +23,5 @@ 'tgredial.zip', '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', ), - 'pkuseg':{ - 'special_token_idx': { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'bert': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - }, - }, - 'gpt2': { - 'special_token_idx': { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'cls': 101, - 'sep': 102, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - }, - } }, } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index e728980..51546e5 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -87,8 +87,7 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - token = resource[tokenize] - self.special_token_idx = token['special_token_idx'] + self.special_token_idx = CRS_Tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] @@ -125,8 +124,8 @@ def _load_data(self): 'n_word': self.n_word, 'word2vec': npy_dict['word2vec'], 'copy_mask': npy_dict['copy_mask'], + 'special_token_idx': self.special_token_idx } - vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab diff --git a/crslab/data/dataset/tokenizer/bert.py b/crslab/data/dataset/tokenizer/bert.py index fc74bcb..cc2a526 100644 --- a/crslab/data/dataset/tokenizer/bert.py +++ b/crslab/data/dataset/tokenizer/bert.py @@ -9,8 +9,19 @@ class bert_tokenize(BaseCrsTokenize): def __init__(self, path=None) -> None: - super().__init__(path) + self.special_token_idx = { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } self.my_tokenizer = AutoTokenizer.from_pretrained(path) + super().__init__(path) def tokenize(self, text): return self.my_tokenizer.tokenize(text) diff --git a/crslab/data/dataset/tokenizer/gpt2.py b/crslab/data/dataset/tokenizer/gpt2.py index 2ddb0c1..142d2c5 100644 --- a/crslab/data/dataset/tokenizer/gpt2.py +++ b/crslab/data/dataset/tokenizer/gpt2.py @@ -9,8 +9,21 @@ class gpt2_tokenize(BaseCrsTokenize): def __init__(self, path=None) -> None: - super().__init__(path) + self.special_token_idx = { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'cls': 101, + 'sep': 102, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + } self.my_tokenizer = AutoTokenizer.from_pretrained(path) + super().__init__(path) def tokenize(self, text): return self.my_tokenizer.tokenize(text) diff --git a/crslab/data/dataset/tokenizer/jieba.py b/crslab/data/dataset/tokenizer/jieba.py index 753f0ba..d76038f 100644 --- a/crslab/data/dataset/tokenizer/jieba.py +++ b/crslab/data/dataset/tokenizer/jieba.py @@ -10,6 +10,14 @@ class jieba_tokenize(BaseCrsTokenize): def 
__init__(self, path=None) -> None: + self.special_token_idx = { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + } super().__init__(path) def tokenize(self, text): diff --git a/crslab/data/dataset/tokenizer/nltk.py b/crslab/data/dataset/tokenizer/nltk.py index 4e016f1..48757d7 100644 --- a/crslab/data/dataset/tokenizer/nltk.py +++ b/crslab/data/dataset/tokenizer/nltk.py @@ -4,13 +4,22 @@ from crslab.data.dataset.tokenizer.base import BaseCrsTokenize -from nltk import word_tokenize import nltk +from nltk import word_tokenize class nltk_tokenize(BaseCrsTokenize): def __init__(self, path=None) -> None: + self.special_token_idx = { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + } super().__init__(path) def tokenize(self, text): diff --git a/crslab/data/dataset/tokenizer/pkuseg.py b/crslab/data/dataset/tokenizer/pkuseg.py index fabdcad..695266f 100644 --- a/crslab/data/dataset/tokenizer/pkuseg.py +++ b/crslab/data/dataset/tokenizer/pkuseg.py @@ -11,6 +11,15 @@ class pkuseg_tokenize(BaseCrsTokenize): def __init__(self, path=None) -> None: self.pkuseg_tokenizer = pkuseg.pkuseg() + self.special_token_idx = { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } super().__init__(path) def tokenize(self, text): diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index 2f7a345..94dc405 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -52,7 +52,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.context_truncate = opt['context_truncate'] self.response_truncate = opt['response_truncate'] - self.pad_id = vocab['pad'] + self.pad_id = vocab['special_token_idx']['pad'] self.dpath = opt['conv_pretrained_path'] super(GPT2Model, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/conversation/transformer/transformer.py b/crslab/model/conversation/transformer/transformer.py index e36260e..e0f7a8b 100644 --- a/crslab/model/conversation/transformer/transformer.py +++ b/crslab/model/conversation/transformer/transformer.py @@ -20,12 +20,12 @@ import torch import torch.nn.functional as F -from loguru import logger -from torch import nn - from crslab.model.base import BaseModel from crslab.model.utils.functions import edge_to_pyg_format -from crslab.model.utils.modules.transformer import TransformerEncoder, TransformerDecoder +from crslab.model.utils.modules.transformer import (TransformerDecoder, + TransformerEncoder) +from loguru import logger +from torch import nn class TransformerModel(BaseModel): @@ -70,16 +70,16 @@ def __init__(self, opt, device, vocab, side_data): """ # vocab self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrain_embedding = side_data.get('embedding', None) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] - self.pad_word_idx = vocab['pad_word'] - self.pad_entity_idx = vocab['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] 
entity_edges = entity_kg['edge'] diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index af91a61..10f79f6 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -38,7 +38,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.context_truncate = opt['context_truncate'] self.response_truncate = opt['response_truncate'] - self.pad_id = vocab['pad'] + self.pad_id = vocab['special_token_idx']['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 self.dpath = opt['conv_pretrained_path'] diff --git a/crslab/model/crs/kbrd/kbrd.py b/crslab/model/crs/kbrd/kbrd.py index 81b6c4c..0e92e18 100644 --- a/crslab/model/crs/kbrd/kbrd.py +++ b/crslab/model/crs/kbrd/kbrd.py @@ -74,9 +74,9 @@ def __init__(self, opt, device, vocab, side_data): self.device = device self.gpu = opt.get("gpu", [-1]) # vocab - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] self.vocab_size = vocab['vocab_size'] self.token_emb_dim = opt.get('token_emb_dim', 300) self.pretrain_embedding = side_data.get('embedding', None) diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index f6fdb77..65e7559 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -85,17 +85,17 @@ def __init__(self, opt, device, vocab, side_data): self.gpu = opt.get("gpu", [-1]) # vocab self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) self.copy_mask = torch.as_tensor(vocab['copy_mask'].astype(bool)).to(self.device) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] - self.pad_word_idx = vocab['pad_word'] - self.pad_entity_idx = vocab['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index dc980b7..f405c48 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -51,9 +51,9 @@ def __init__(self, opt, device, vocab, side_data): self.gpu = opt.get("gpu", [-1]) # vocab self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) self.replace_token = opt.get('replace_token', None) @@ -62,8 +62,8 @@ def __init__(self, opt, device, vocab, side_data): # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] - self.pad_word_idx = vocab['pad_word'] - self.pad_entity_idx = 
vocab['pad_entity'] + self.pad_word_idx = vocab['special_token_idx']['pad_word'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] diff --git a/crslab/model/crs/redial/redial_conv.py b/crslab/model/crs/redial/redial_conv.py index 8e062fb..b7a9529 100644 --- a/crslab/model/crs/redial/redial_conv.py +++ b/crslab/model/crs/redial/redial_conv.py @@ -19,9 +19,9 @@ """ import torch +from crslab.model.base import BaseModel from torch import nn -from crslab.model.base import BaseModel from .modules import HRNN, SwitchingDecoder @@ -59,10 +59,10 @@ def __init__(self, opt, device, vocab, side_data): """ # dataset self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['pad'] - self.start_token_idx = vocab['start'] - self.end_token_idx = vocab['end'] - self.unk_token_idx = vocab['unk'] + self.pad_token_idx = vocab['special_token_idx']['pad'] + self.start_token_idx = vocab['special_token_idx']['start'] + self.end_token_idx = vocab['special_token_idx']['end'] + self.unk_token_idx = vocab['special_token_idx']['unk'] self.pretrained_embedding = side_data.get('embedding', None) self.embedding_dim = opt.get('embedding_dim', None) if opt.get('embedding', None) and self.embedding_dim is None: diff --git a/crslab/model/crs/redial/redial_rec.py b/crslab/model/crs/redial/redial_rec.py index 4bbc289..3a0ad46 100644 --- a/crslab/model/crs/redial/redial_rec.py +++ b/crslab/model/crs/redial/redial_rec.py @@ -19,7 +19,6 @@ """ import torch.nn as nn - from crslab.model.base import BaseModel @@ -45,7 +44,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.n_entity = vocab['n_entity'] self.layer_sizes = opt['autorec_layer_sizes'] - self.pad_entity_idx = vocab['pad_entity'] + self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] super(ReDialRecModel, self).__init__(opt, device) diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 3995b52..0fc85f4 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -52,7 +52,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.context_truncate = opt['context_truncate'] self.response_truncate = opt['response_truncate'] - self.pad_id = vocab['pad'] + self.pad_id = vocab['special_token_idx']['pad'] self.dpath = opt['conv_pretrained_path'] super(TGConvModel, self).__init__(opt, device, self.dpath) diff --git a/crslab/model/policy/pmi/pmi.py b/crslab/model/policy/pmi/pmi.py index a406cb3..61e198d 100644 --- a/crslab/model/policy/pmi/pmi.py +++ b/crslab/model/policy/pmi/pmi.py @@ -15,7 +15,6 @@ from collections import defaultdict import torch - from crslab.model.base import BaseModel @@ -39,7 +38,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - self.pad_topic = vocab['pad_topic'] + self.pad_topic = vocab['special_token_idx']['pad_topic'] super(PMIModel, self).__init__(opt, device) def build_model(self, *args, **kwargs): diff --git a/crslab/system/kbrd.py b/crslab/system/kbrd.py index 085eda4..5579da8 100644 --- a/crslab/system/kbrd.py +++ b/crslab/system/kbrd.py @@ -11,12 +11,11 @@ import os import torch -from loguru import logger - from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt +from loguru import logger class KBRDSystem(BaseSystem): @@ 
-43,7 +42,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc restore_system, interact, debug, tensorboard) self.ind2tok = vocab['ind2tok'] - self.end_token_idx = vocab['end'] + self.end_token_idx = vocab['special_token_idx']['end'] self.item_ids = side_data['item_entity_ids'] self.rec_optim_opt = opt['rec'] diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index f5e2cf7..047ff4f 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -46,7 +46,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc restore_system, interact, debug, tensorboard) self.ind2tok = vocab['ind2tok'] - self.end_token_idx = vocab['end'] + self.end_token_idx = vocab['special_token_idx']['end'] self.item_ids = side_data['item_entity_ids'] self.pretrain_optim_opt = self.opt['pretrain'] diff --git a/crslab/system/ntrd.py b/crslab/system/ntrd.py index ab1e923..ffbc1dc 100644 --- a/crslab/system/ntrd.py +++ b/crslab/system/ntrd.py @@ -3,16 +3,15 @@ # @Email : oran_official@outlook.com import os -from crslab.evaluator.metrics import gen -from numpy.core.numeric import NaN import torch -from loguru import logger - +from crslab.evaluator.metrics import gen from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem -from crslab.system.utils.functions import ind2slot,ind2txt_with_slots +from crslab.system.utils.functions import ind2slot, ind2txt_with_slots +from loguru import logger +from numpy.core.numeric import NaN class NTRDSystem(BaseSystem): @@ -25,7 +24,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.ind2tok = vocab['ind2tok'] self.ind2movie = vocab['id2entity'] - self.end_token_idx = vocab['end'] + self.end_token_idx = vocab['special_token_idx']['end'] self.item_ids = side_data['item_entity_ids'] self.pretrain_optim_opt = self.opt['pretrain'] From 4763e577da6dc51bb9abcbce4c601c38f8188640 Mon Sep 17 00:00:00 2001 From: txy77 Date: Wed, 12 Oct 2022 21:28:39 +0800 Subject: [PATCH 22/35] txy77 --- crslab/system/inspired.py | 2 +- crslab/system/redial.py | 2 +- crslab/system/tgredial.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crslab/system/inspired.py b/crslab/system/inspired.py index 5a5f9f1..e827ee3 100644 --- a/crslab/system/inspired.py +++ b/crslab/system/inspired.py @@ -38,7 +38,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc if hasattr(self, 'conv_model'): self.ind2tok = vocab['conv']['ind2tok'] - self.end_token_idx = vocab['conv']['end'] + self.end_token_idx = vocab['conv']['special_token_idx']['end'] if hasattr(self, 'rec_model'): self.item_ids = side_data['rec']['item_entity_ids'] self.id2entity = vocab['rec']['id2entity'] diff --git a/crslab/system/redial.py b/crslab/system/redial.py index 967cf14..2ce257e 100644 --- a/crslab/system/redial.py +++ b/crslab/system/redial.py @@ -38,7 +38,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc super(ReDialSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) self.ind2tok = vocab['conv']['ind2tok'] - self.end_token_idx = vocab['conv']['end'] + self.end_token_idx = vocab['conv']['special_token_idx']['end'] self.item_ids = side_data['rec']['item_entity_ids'] self.id2entity = vocab['rec']['id2entity'] diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 
4149b2d..9d5d390 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -51,7 +51,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc if hasattr(self, 'conv_model'): self.ind2tok = vocab['conv']['ind2tok'] - self.end_token_idx = vocab['conv']['end'] + self.end_token_idx = vocab['conv']['special_token_idx']['end'] if hasattr(self, 'rec_model'): self.item_ids = side_data['rec']['item_entity_ids'] self.id2entity = vocab['rec']['id2entity'] From b6d4c346d9687ad15d25def7f64477f93873e0cc Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 14 Oct 2022 20:57:15 +0800 Subject: [PATCH 23/35] txy77 --- crslab/data/__init__.py | 4 ++-- crslab/data/dataset/durecdial/durecdial.py | 11 ++++++----- crslab/data/dataset/gorecdial/gorecdial.py | 11 ++++++----- crslab/data/dataset/inspired/inspired.py | 11 ++++++----- crslab/data/dataset/opendialkg/opendialkg.py | 11 ++++++----- crslab/data/dataset/redial/redial.py | 11 ++++++----- crslab/data/dataset/tgredial/tgredial.py | 11 ++++++----- crslab/quick_start/quick_start.py | 8 ++++---- 8 files changed, 42 insertions(+), 36 deletions(-) diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 31272ab..0e299a4 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -86,7 +86,7 @@ def get_tokenizer(tokenize, path=None) -> BaseCrsTokenize: return tokenizer_register_table[tokenize](path) -def get_dataset(opt, tokenize, CRS_Tokenizer, restore, save) -> BaseDataset: +def get_dataset(opt, tokenize, crs_tokenizer, restore, save) -> BaseDataset: """get and process dataset Args: @@ -101,7 +101,7 @@ def get_dataset(opt, tokenize, CRS_Tokenizer, restore, save) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return dataset_register_table[dataset](opt, tokenize, CRS_Tokenizer, restore, save) + return dataset_register_table[dataset](opt, tokenize, crs_tokenizer, restore, save) else: raise NotImplementedError( f'The dataloader [{dataset}] has not been implemented') diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 7d4b98a..866c7b4 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -29,11 +29,12 @@ import gensim import numpy as np -from crslab.config import DATASET_PATH, MODEL_PATH -from crslab.data.dataset.base import BaseDataset from loguru import logger from tqdm import tqdm +from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.data.dataset.base import BaseDataset + from .resources import resources @@ -62,7 +63,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): """ Args: @@ -83,10 +84,10 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - self.special_token_idx = CRS_Tokenizer.special_token_idx + self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = CRS_Tokenizer + self.Tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 2cd7740..a87fbbb 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ 
b/crslab/data/dataset/gorecdial/gorecdial.py @@ -29,11 +29,12 @@ import gensim import numpy as np -from crslab.config import DATASET_PATH, MODEL_PATH -from crslab.data.dataset.base import BaseDataset from loguru import logger from tqdm import tqdm +from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.data.dataset.base import BaseDataset + from .resources import resources @@ -62,7 +63,7 @@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -83,10 +84,10 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - self.special_token_idx = CRS_Tokenizer.special_token_idx + self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = CRS_Tokenizer + self.Tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 634dbb3..746edee 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -29,11 +29,12 @@ import gensim import numpy as np -from crslab.config import DATASET_PATH, MODEL_PATH -from crslab.data.dataset.base import BaseDataset from loguru import logger from tqdm import tqdm +from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.data.dataset.base import BaseDataset + from .resources import resources @@ -62,7 +63,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -83,10 +84,10 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - self.special_token_idx = CRS_Tokenizer.special_token_idx + self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = CRS_Tokenizer + self.Tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 2fe385c..5b2068e 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -31,11 +31,12 @@ import gensim import numpy as np -from crslab.config import DATASET_PATH, MODEL_PATH -from crslab.data.dataset.base import BaseDataset from loguru import logger from tqdm import tqdm +from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.data.dataset.base import BaseDataset + from .resources import resources @@ -64,7 +65,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. 
Args: @@ -85,10 +86,10 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - self.special_token_idx = CRS_Tokenizer.special_token_idx + self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = CRS_Tokenizer + self.Tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index fe8993f..591f37f 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -30,11 +30,12 @@ import gensim import numpy as np -from crslab.config import DATASET_PATH, MODEL_PATH -from crslab.data.dataset.base import BaseDataset from loguru import logger from tqdm import tqdm +from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.data.dataset.base import BaseDataset + from .resources import resources @@ -63,7 +64,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -84,10 +85,10 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - self.special_token_idx = CRS_Tokenizer.special_token_idx + self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = CRS_Tokenizer + self.Tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 51546e5..1dc134d 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -30,11 +30,12 @@ import gensim import numpy as np -from crslab.config import DATASET_PATH, MODEL_PATH -from crslab.data.dataset.base import BaseDataset from loguru import logger from tqdm import tqdm +from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.data.dataset.base import BaseDataset + from .resources import resources @@ -66,7 +67,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): """Specify tokenized resource and init base dataset. 
Args: @@ -87,12 +88,12 @@ def __init__(self, opt, tokenize, CRS_Tokenizer, restore=False, save=False): self.generate_embedding = False resource = resources['resource'] - self.special_token_idx = CRS_Tokenizer.special_token_idx + self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] self.tokenize = tokenize - self.Tokenizer = CRS_Tokenizer + self.Tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'tgredial') self.replace_token = opt.get('replace_token', None) diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 62a1cfc..b37f63b 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -33,9 +33,9 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_Tokenizer = get_tokenizer(config['tokenize'], path=None) + CRS_tokenizer = get_tokenizer(config['tokenize'], path=None) CRS_dataset = get_dataset( - config, config['tokenize'], CRS_Tokenizer, restore_data, save_data) + config, config['tokenize'], CRS_tokenizer, restore_data, save_data) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab @@ -60,9 +60,9 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r tokenize_path = None if task_tokenize_path in config: tokenize_path = config[task_tokenize_path] - CRS_Tokenizer = get_tokenizer(tokenize, tokenize_path) + CRS_tokenizer = get_tokenizer(tokenize, tokenize_path) dataset = get_dataset( - config, tokenize, CRS_Tokenizer, restore_data, save_data) + config, tokenize, CRS_tokenizer, restore_data, save_data) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = dataset.valid_data From b6445df0e5dbe15660a77cfefcdab706d015db44 Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 14 Oct 2022 20:59:13 +0800 Subject: [PATCH 24/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 6 +++--- crslab/data/dataset/gorecdial/gorecdial.py | 6 +++--- crslab/data/dataset/inspired/inspired.py | 6 +++--- crslab/data/dataset/opendialkg/opendialkg.py | 6 +++--- crslab/data/dataset/redial/redial.py | 6 +++--- crslab/data/dataset/tgredial/tgredial.py | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 866c7b4..713ef7a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -124,9 +124,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - wordembedding = None + word_embedding = None if self.generate_embedding: - wordembedding = self.generate_word2vec(processed_train_data) + word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask copy_mask = None @@ -150,7 +150,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} return processed_train_data, processed_valid_data, processed_test_data, npy_dict diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index a87fbbb..7dc1053 100644 --- 
a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -125,9 +125,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - wordembedding = None + word_embedding = None if self.generate_embedding: - wordembedding = self.generate_word2vec(processed_train_data) + word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask copy_mask = None @@ -151,7 +151,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} return processed_train_data, processed_valid_data, processed_test_data, npy_dict diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 746edee..0775391 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -125,9 +125,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - wordembedding = None + word_embedding = None if self.generate_embedding: - wordembedding = self.generate_word2vec(processed_train_data) + word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask if self.copy: @@ -150,7 +150,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} return processed_train_data, processed_valid_data, processed_test_data, npy_dict diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 5b2068e..07179e1 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -127,9 +127,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - wordembedding = None + word_embedding = None if self.generate_embedding: - wordembedding = self.generate_word2vec(processed_train_data) + word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask copy_mask = None @@ -153,7 +153,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} return processed_train_data, processed_valid_data, processed_test_data, npy_dict diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 591f37f..4b19dee 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -126,9 +126,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - wordembedding = None + word_embedding = None if self.generate_embedding: - wordembedding = self.generate_word2vec(processed_train_data) + word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate 
word2vec]') # build copy_mask copy_mask = None @@ -152,7 +152,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} return processed_train_data, processed_valid_data, processed_test_data, npy_dict diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 1dc134d..1c35cc9 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -143,9 +143,9 @@ def _load_raw_data(self): self.tok2ind = self.generate_tok2ind(processed_train_data) logger.info("[Finish generate train tok2ind]") # generate word2vec - wordembedding = None + word_embedding = None if self.generate_embedding: - wordembedding = self.generate_word2vec(processed_train_data) + word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask copy_mask = None @@ -169,7 +169,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': wordembedding, 'copy_mask': copy_mask} + npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} return processed_train_data, processed_valid_data, processed_test_data, npy_dict From d7b9d5b103ea18c5d9d04d3090c5a311bb70b47a Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 14 Oct 2022 21:10:36 +0800 Subject: [PATCH 25/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 10 ++++------ crslab/data/dataset/gorecdial/gorecdial.py | 11 ++++------- crslab/data/dataset/inspired/inspired.py | 15 ++++++--------- crslab/data/dataset/opendialkg/opendialkg.py | 11 ++++------- crslab/data/dataset/redial/redial.py | 11 ++++------- crslab/data/dataset/tgredial/tgredial.py | 11 ++++------- 6 files changed, 26 insertions(+), 43 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 713ef7a..00335b3 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -331,8 +331,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - crstokenize = self.Tokenizer - list_text = crstokenize.tokenize(str_text) + list_text = self.Tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -371,20 +370,19 @@ def generate_tok2ind(self, processed_train_data): return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): - crstokenize = self.Tokenizer copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = crstokenize.tokenize(word) + list_word = self.Tokenizer.tokenize(word) match_list += list_word for item in dialog['item']: - list_word = crstokenize.tokenize(item) + list_word = self.Tokenizer.tokenize(item) match_list += list_word for entity in dialog['entity']: - list_word = crstokenize.tokenize(entity) + list_word = self.Tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) for each_word in text: diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 7dc1053..cac6b97 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ 
b/crslab/data/dataset/gorecdial/gorecdial.py @@ -337,8 +337,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - crstokenize = self.Tokenizer - list_text = crstokenize.tokenize(str_text) + list_text = self.Tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -381,22 +380,20 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - crstokenize = self.Tokenizer - copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = crstokenize.tokenize(word) + list_word = self.Tokenizer.tokenize(word) match_list += list_word for movie in dialog['movies']: - list_word = crstokenize.tokenize(movie) + list_word = self.Tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = crstokenize.tokenize(entity) + list_word = self.Tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 0775391..e61919b 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -336,8 +336,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - crstokenize = self.Tokenizer - list_text = crstokenize.tokenize(str_text) + list_text = self.Tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -380,30 +379,28 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - crstokenize = self.Tokenizer - copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = crstokenize.tokenize(word) + list_word = self.Tokenizer.tokenize(word) match_list += list_word for movie in dialog['movies']: - list_word = crstokenize.tokenize(movie) + list_word = self.Tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = crstokenize.tokenize(entity) + list_word = self.Tokenizer.tokenize(entity) match_list += list_word for genre in dialog['genre']: - list_word = crstokenize.tokenize(genre) + list_word = self.Tokenizer.tokenize(genre) match_list += list_word for people in dialog['people']: - list_word = crstokenize.tokenize(people) + list_word = self.Tokenizer.tokenize(people) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 07179e1..46a04c8 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -345,8 +345,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - crstokenize = self.Tokenizer - list_text = crstokenize.tokenize(str_text) + list_text = self.Tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -389,22 +388,20 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - crstokenize = self.Tokenizer - copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in 
tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = crstokenize.tokenize(word) + list_word = self.Tokenizer.tokenize(word) match_list += list_word for entity in dialog['entity']: - list_word = crstokenize.tokenize(entity) + list_word = self.Tokenizer.tokenize(entity) match_list += list_word for item in dialog['item']: - list_word = crstokenize.tokenize(item) + list_word = self.Tokenizer.tokenize(item) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 4b19dee..aa8b2c0 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -341,8 +341,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - crstokenize = self.Tokenizer - list_text = crstokenize.tokenize(str_text) + list_text = self.Tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -385,21 +384,19 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - crstokenize = self.Tokenizer - copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): for dialog in each_data['dialog']: match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = crstokenize.tokenize(word) + list_word = self.Tokenizer.tokenize(word) match_list += list_word for movie in dialog['movies']: - list_word = crstokenize.tokenize(movie) + list_word = self.Tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = crstokenize.tokenize(entity) + list_word = self.Tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 1c35cc9..c8f2f2f 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -428,8 +428,7 @@ def split_text(self, data): each_dict['conv_id'] = each['conv_id'] for one in each['messages']: str_text = one['text'] - crstokenize = self.Tokenizer - list_text = crstokenize.tokenize(str_text) + list_text = self.Tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['messages'] = each_data @@ -473,23 +472,21 @@ def generate_tok2ind(self, processed_train_data): def generate_copy_mask(self, tok2ind, processed_train_data): - crstokenize = self.Tokenizer - copy_mask = np.zeros((len(tok2ind)), dtype=bool) for each_data in tqdm(processed_train_data): for dialog in each_data['messages']: match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = crstokenize.tokenize(word) + list_word = self.Tokenizer.tokenize(word) match_list += list_word for movie in dialog['movie']: - list_word = crstokenize.tokenize(movie) + list_word = self.Tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = crstokenize.tokenize(entity) + list_word = self.Tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) From 2cab622c665c437575784c31166eab365d5a911e Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 14 Oct 2022 21:13:55 +0800 Subject: [PATCH 26/35] txy77 --- crslab/data/__init__.py | 2 +- crslab/data/dataset/tokenizer/__init__.py | 2 +- crslab/data/dataset/tokenizer/base.py | 2 +- crslab/data/dataset/tokenizer/bert.py | 5 +++-- 
crslab/data/dataset/tokenizer/gpt2.py | 5 +++-- crslab/data/dataset/tokenizer/jieba.py | 6 +++--- crslab/data/dataset/tokenizer/nltk.py | 6 +++--- crslab/data/dataset/tokenizer/pkuseg.py | 6 +++--- 8 files changed, 18 insertions(+), 16 deletions(-) diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 0e299a4..0c4200e 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -79,7 +79,7 @@ } -def get_tokenizer(tokenize, path=None) -> BaseCrsTokenize: +def get_tokenizer(tokenize, path=None) -> BaseTokenizer: """ get tokenizer from opt """ diff --git a/crslab/data/dataset/tokenizer/__init__.py b/crslab/data/dataset/tokenizer/__init__.py index 3b7ef7a..bacd3aa 100644 --- a/crslab/data/dataset/tokenizer/__init__.py +++ b/crslab/data/dataset/tokenizer/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseCrsTokenize +from .base import BaseTokenizer from .bert import bert_tokenize from .gpt2 import gpt2_tokenize from .jieba import jieba_tokenize diff --git a/crslab/data/dataset/tokenizer/base.py b/crslab/data/dataset/tokenizer/base.py index d9ccd22..b5966bb 100644 --- a/crslab/data/dataset/tokenizer/base.py +++ b/crslab/data/dataset/tokenizer/base.py @@ -2,7 +2,7 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -class BaseCrsTokenize: +class BaseTokenizer: def __init__(self, path=None) -> None: pass diff --git a/crslab/data/dataset/tokenizer/bert.py b/crslab/data/dataset/tokenizer/bert.py index cc2a526..97d97ed 100644 --- a/crslab/data/dataset/tokenizer/bert.py +++ b/crslab/data/dataset/tokenizer/bert.py @@ -2,11 +2,12 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize from transformers import AutoTokenizer +from crslab.data.dataset.tokenizer.base import BaseTokenizer -class bert_tokenize(BaseCrsTokenize): + +class bert_tokenize(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/gpt2.py b/crslab/data/dataset/tokenizer/gpt2.py index 142d2c5..081d2ca 100644 --- a/crslab/data/dataset/tokenizer/gpt2.py +++ b/crslab/data/dataset/tokenizer/gpt2.py @@ -2,11 +2,12 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize from transformers import AutoTokenizer +from crslab.data.dataset.tokenizer.base import BaseTokenizer -class gpt2_tokenize(BaseCrsTokenize): + +class gpt2_tokenize(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/jieba.py b/crslab/data/dataset/tokenizer/jieba.py index d76038f..6755e43 100644 --- a/crslab/data/dataset/tokenizer/jieba.py +++ b/crslab/data/dataset/tokenizer/jieba.py @@ -2,12 +2,12 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize - import jieba +from crslab.data.dataset.tokenizer.base import BaseTokenizer + -class jieba_tokenize(BaseCrsTokenize): +class jieba_tokenize(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/nltk.py b/crslab/data/dataset/tokenizer/nltk.py index 48757d7..a5419c1 100644 --- a/crslab/data/dataset/tokenizer/nltk.py +++ b/crslab/data/dataset/tokenizer/nltk.py @@ -2,13 +2,13 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize - import nltk from nltk import word_tokenize +from crslab.data.dataset.tokenizer.base import BaseTokenizer + 
-class nltk_tokenize(BaseCrsTokenize): +class nltk_tokenize(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/pkuseg.py b/crslab/data/dataset/tokenizer/pkuseg.py index 695266f..9ab358f 100644 --- a/crslab/data/dataset/tokenizer/pkuseg.py +++ b/crslab/data/dataset/tokenizer/pkuseg.py @@ -2,12 +2,12 @@ # @Author : Xinyu Tang # @Email : txy20010310@163.com -from crslab.data.dataset.tokenizer.base import BaseCrsTokenize - import pkuseg +from crslab.data.dataset.tokenizer.base import BaseTokenizer + -class pkuseg_tokenize(BaseCrsTokenize): +class pkuseg_tokenize(BaseTokenizer): def __init__(self, path=None) -> None: self.pkuseg_tokenizer = pkuseg.pkuseg() From 97ea317fd6f03b805dbd90511701d3f837e029c1 Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 14 Oct 2022 21:21:03 +0800 Subject: [PATCH 27/35] txy77 --- crslab/model/crs/kgsf/kgsf.py | 9 +++++---- crslab/model/crs/ntrd/ntrd.py | 11 ++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index 65e7559..a428aaf 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -28,14 +28,15 @@ import numpy as np import torch import torch.nn.functional as F +from loguru import logger +from torch import nn +from torch_geometric.nn import GCNConv, RGCNConv + from crslab.config import MODEL_PATH from crslab.model.base import BaseModel from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder -from loguru import logger -from torch import nn -from torch_geometric.nn import GCNConv, RGCNConv from .modules import GateLayer, TransformerDecoderKG @@ -90,7 +91,7 @@ def __init__(self, opt, device, vocab, side_data): self.end_token_idx = vocab['special_token_idx']['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) - self.copy_mask = torch.as_tensor(vocab['copy_mask'].astype(bool)).to(self.device) + self.copy_mask = torch.as_tensor(vocab['copy_mask'], dtype=torch.bool, device=self.device) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index f405c48..010dfd6 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py @@ -23,14 +23,15 @@ import numpy as np import torch import torch.nn.functional as F +from loguru import logger +from torch import nn +from torch_geometric.nn import GCNConv, RGCNConv + from crslab.config import MODEL_PATH from crslab.model.base import BaseModel from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder -from loguru import logger -from torch import nn -from torch_geometric.nn import GCNConv, RGCNConv from .modules import (GateLayer, TransformerDecoderKG, TransformerDecoderSelection) @@ -58,7 +59,7 @@ def __init__(self, opt, device, vocab, side_data): self.pretrained_embedding = side_data.get('embedding', None) self.replace_token = opt.get('replace_token', None) self.replace_token_idx = vocab[self.replace_token] - self.copy_mask = torch.as_tensor(vocab['copy_mask'].astype(bool)).to(self.device) + self.copy_mask = torch.as_tensor(vocab['copy_mask'], dtype=torch.bool, device=self.device) # kg self.n_word = vocab['n_word'] 
self.n_entity = vocab['n_entity'] @@ -191,7 +192,7 @@ def _build_conversation_layer(self): copy_mask[self.replace_token_idx] = False else: copy_mask = np.insert(copy_mask, len(copy_mask), False) - self.copy_mask = torch.as_tensor(copy_mask).to(self.device) + self.copy_mask = torch.as_tensor(copy_mask, device=self.device) self.conv_decoder = TransformerDecoderKG( self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size, From af93741db3c1c0a7d5becb2a6d2a7e10730b7c0d Mon Sep 17 00:00:00 2001 From: txy77 Date: Fri, 14 Oct 2022 21:31:59 +0800 Subject: [PATCH 28/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 10 ++++------ crslab/data/dataset/gorecdial/gorecdial.py | 10 ++++------ crslab/data/dataset/inspired/inspired.py | 10 ++++------ crslab/data/dataset/opendialkg/opendialkg.py | 10 ++++------ crslab/data/dataset/redial/redial.py | 10 ++++------ crslab/data/dataset/tgredial/tgredial.py | 10 ++++------ 6 files changed, 24 insertions(+), 36 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 00335b3..8f78032 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, npy_dict = self._load_raw_data() + train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,8 +105,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': npy_dict['word2vec'], - 'copy_mask': npy_dict['copy_mask'], + 'word2vec': word2vec, + 'copy_mask': copy_mask, 'special_token_idx': self.special_token_idx, } @@ -150,9 +150,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} - - return processed_train_data, processed_valid_data, processed_test_data, npy_dict + return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index cac6b97..e0f9aa4 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, npy_dict = self._load_raw_data() + train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,8 +105,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': npy_dict['word2vec'], - 'copy_mask': npy_dict['copy_mask'], + 'word2vec': word2vec, + 'copy_mask': copy_mask, 'special_token_idx': self.special_token_idx, } @@ -151,9 +151,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} - - return processed_train_data, processed_valid_data, processed_test_data, npy_dict + return 
processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index e61919b..8f52116 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -92,7 +92,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, npy_dict = self._load_raw_data() + train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,8 +105,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': npy_dict['word2vec'], - 'copy_mask': npy_dict['copy_mask'], + 'word2vec': word2vec, + 'copy_mask': copy_mask, 'special_token_idx': self.special_token_idx, } @@ -150,9 +150,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} - - return processed_train_data, processed_valid_data, processed_test_data, npy_dict + return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 46a04c8..0479ca0 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -94,7 +94,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, npy_dict = self._load_raw_data() + train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -107,8 +107,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': npy_dict['word2vec'], - 'copy_mask': npy_dict['copy_mask'], + 'word2vec': word2vec, + 'copy_mask': copy_mask, 'special_token_idx': self.special_token_idx } @@ -153,9 +153,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} - - return processed_train_data, processed_valid_data, processed_test_data, npy_dict + return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index aa8b2c0..017fcd5 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -93,7 +93,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, npy_dict = self._load_raw_data() + train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -106,8 +106,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 
'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': npy_dict['word2vec'], - 'copy_mask': npy_dict['copy_mask'], + 'word2vec': word2vec, + 'copy_mask': copy_mask, 'special_token_idx': self.special_token_idx } @@ -152,9 +152,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} - - return processed_train_data, processed_valid_data, processed_test_data, npy_dict + return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index c8f2f2f..46297c1 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -107,7 +107,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.side_data["embedding"]), self.side_data['embedding'][0], axis=0) def _load_data(self): - train_data, valid_data, test_data, npy_dict = self._load_raw_data() + train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -123,8 +123,8 @@ def _load_data(self): 'n_topic': len(self.topic2ind) + 1, 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': npy_dict['word2vec'], - 'copy_mask': npy_dict['copy_mask'], + 'word2vec': word2vec, + 'copy_mask': copy_mask, 'special_token_idx': self.special_token_idx } @@ -169,9 +169,7 @@ def _load_raw_data(self): processed_test_data = self.split_text(test_data) logger.info("[Finish test data split]") - npy_dict = {'word2vec': word_embedding, 'copy_mask': copy_mask} - - return processed_train_data, processed_valid_data, processed_test_data, npy_dict + return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask def _load_vocab(self): self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} From 699a537ff6480cb651e0abe4f09fa936f5521a6b Mon Sep 17 00:00:00 2001 From: txy77 Date: Sat, 15 Oct 2022 20:18:20 +0800 Subject: [PATCH 29/35] txy77 --- crslab/data/__init__.py | 10 +++++----- crslab/data/dataset/tokenizer/__init__.py | 10 +++++----- crslab/data/dataset/tokenizer/bert.py | 2 +- crslab/data/dataset/tokenizer/gpt2.py | 2 +- crslab/data/dataset/tokenizer/jieba.py | 2 +- crslab/data/dataset/tokenizer/nltk.py | 2 +- crslab/data/dataset/tokenizer/pkuseg.py | 6 +++--- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index 0c4200e..cb5fad5 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -25,11 +25,11 @@ from crslab.data.dataset.tokenizer import * tokenizer_register_table = { - 'nltk': nltk_tokenize, - 'jieba': jieba_tokenize, - 'gpt2': gpt2_tokenize, - 'bert': bert_tokenize, - 'pkuseg': pkuseg_tokenize + 'nltk': NltkTokenizer, + 'jieba': JiebaTokenizer, + 'gpt2': Gpt2Tokenizer, + 'bert': BertTokenizer, + 'pkuseg': PkusegTokenizer } dataset_register_table = { diff --git a/crslab/data/dataset/tokenizer/__init__.py b/crslab/data/dataset/tokenizer/__init__.py index bacd3aa..d5c67be 100644 --- a/crslab/data/dataset/tokenizer/__init__.py +++ b/crslab/data/dataset/tokenizer/__init__.py @@ -1,6 +1,6 @@ from .base import BaseTokenizer -from .bert import bert_tokenize -from .gpt2 import gpt2_tokenize -from .jieba import jieba_tokenize -from .nltk import nltk_tokenize -from 
.pkuseg import pkuseg_tokenize +from .bert import BertTokenizer +from .gpt2 import Gpt2Tokenizer +from .jieba import JiebaTokenizer +from .nltk import NltkTokenizer +from .pkuseg import PkusegTokenizer diff --git a/crslab/data/dataset/tokenizer/bert.py b/crslab/data/dataset/tokenizer/bert.py index 97d97ed..73f268a 100644 --- a/crslab/data/dataset/tokenizer/bert.py +++ b/crslab/data/dataset/tokenizer/bert.py @@ -7,7 +7,7 @@ from crslab.data.dataset.tokenizer.base import BaseTokenizer -class bert_tokenize(BaseTokenizer): +class BertTokenizer(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/gpt2.py b/crslab/data/dataset/tokenizer/gpt2.py index 081d2ca..196d238 100644 --- a/crslab/data/dataset/tokenizer/gpt2.py +++ b/crslab/data/dataset/tokenizer/gpt2.py @@ -7,7 +7,7 @@ from crslab.data.dataset.tokenizer.base import BaseTokenizer -class gpt2_tokenize(BaseTokenizer): +class Gpt2Tokenizer(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/jieba.py b/crslab/data/dataset/tokenizer/jieba.py index 6755e43..8354098 100644 --- a/crslab/data/dataset/tokenizer/jieba.py +++ b/crslab/data/dataset/tokenizer/jieba.py @@ -7,7 +7,7 @@ from crslab.data.dataset.tokenizer.base import BaseTokenizer -class jieba_tokenize(BaseTokenizer): +class JiebaTokenizer(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/nltk.py b/crslab/data/dataset/tokenizer/nltk.py index a5419c1..01ff902 100644 --- a/crslab/data/dataset/tokenizer/nltk.py +++ b/crslab/data/dataset/tokenizer/nltk.py @@ -8,7 +8,7 @@ from crslab.data.dataset.tokenizer.base import BaseTokenizer -class nltk_tokenize(BaseTokenizer): +class NltkTokenizer(BaseTokenizer): def __init__(self, path=None) -> None: self.special_token_idx = { diff --git a/crslab/data/dataset/tokenizer/pkuseg.py b/crslab/data/dataset/tokenizer/pkuseg.py index 9ab358f..c362f5f 100644 --- a/crslab/data/dataset/tokenizer/pkuseg.py +++ b/crslab/data/dataset/tokenizer/pkuseg.py @@ -7,10 +7,10 @@ from crslab.data.dataset.tokenizer.base import BaseTokenizer -class pkuseg_tokenize(BaseTokenizer): +class PkusegTokenizer(BaseTokenizer): def __init__(self, path=None) -> None: - self.pkuseg_tokenizer = pkuseg.pkuseg() + self.PkusegTokenizerr = pkuseg.pkuseg() self.special_token_idx = { 'pad': 0, 'start': 1, @@ -23,4 +23,4 @@ def __init__(self, path=None) -> None: super().__init__(path) def tokenize(self, text): - return self.pkuseg_tokenizer.cut(text) + return self.PkusegTokenizerr.cut(text) From b58aefa65effbd12f01669a2e0bd6e128d10c2cf Mon Sep 17 00:00:00 2001 From: txy77 Date: Sat, 15 Oct 2022 20:26:39 +0800 Subject: [PATCH 30/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 10 +++++----- crslab/data/dataset/gorecdial/gorecdial.py | 10 +++++----- crslab/data/dataset/inspired/inspired.py | 14 +++++++------- crslab/data/dataset/opendialkg/opendialkg.py | 10 +++++----- crslab/data/dataset/redial/redial.py | 10 +++++----- crslab/data/dataset/tgredial/tgredial.py | 10 +++++----- 6 files changed, 32 insertions(+), 32 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 8f78032..b713444 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -87,7 +87,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.special_token_idx = 
crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = crs_tokenizer + self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) @@ -329,7 +329,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - list_text = self.Tokenizer.tokenize(str_text) + list_text = self.tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -374,13 +374,13 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.Tokenizer.tokenize(word) + list_word = self.tokenizer.tokenize(word) match_list += list_word for item in dialog['item']: - list_word = self.Tokenizer.tokenize(item) + list_word = self.tokenizer.tokenize(item) match_list += list_word for entity in dialog['entity']: - list_word = self.Tokenizer.tokenize(entity) + list_word = self.tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) for each_word in text: diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index e0f9aa4..1410106 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -87,7 +87,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = crs_tokenizer + self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) @@ -335,7 +335,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - list_text = self.Tokenizer.tokenize(str_text) + list_text = self.tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -384,14 +384,14 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.Tokenizer.tokenize(word) + list_word = self.tokenizer.tokenize(word) match_list += list_word for movie in dialog['movies']: - list_word = self.Tokenizer.tokenize(movie) + list_word = self.tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = self.Tokenizer.tokenize(entity) + list_word = self.tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 8f52116..f36a86b 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -87,7 +87,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = crs_tokenizer + self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) @@ -334,7 +334,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - list_text = self.Tokenizer.tokenize(str_text) + list_text = self.tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) 
each_dict['dialog'] = each_data @@ -383,22 +383,22 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.Tokenizer.tokenize(word) + list_word = self.tokenizer.tokenize(word) match_list += list_word for movie in dialog['movies']: - list_word = self.Tokenizer.tokenize(movie) + list_word = self.tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = self.Tokenizer.tokenize(entity) + list_word = self.tokenizer.tokenize(entity) match_list += list_word for genre in dialog['genre']: - list_word = self.Tokenizer.tokenize(genre) + list_word = self.tokenizer.tokenize(genre) match_list += list_word for people in dialog['people']: - list_word = self.Tokenizer.tokenize(people) + list_word = self.tokenizer.tokenize(people) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index 0479ca0..fd15109 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -89,7 +89,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = crs_tokenizer + self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) @@ -343,7 +343,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - list_text = self.Tokenizer.tokenize(str_text) + list_text = self.tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -392,14 +392,14 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.Tokenizer.tokenize(word) + list_word = self.tokenizer.tokenize(word) match_list += list_word for entity in dialog['entity']: - list_word = self.Tokenizer.tokenize(entity) + list_word = self.tokenizer.tokenize(entity) match_list += list_word for item in dialog['item']: - list_word = self.Tokenizer.tokenize(item) + list_word = self.tokenizer.tokenize(item) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index 017fcd5..b137ba4 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -88,7 +88,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.special_token_idx = crs_tokenizer.special_token_idx self.unk_token_idx = self.special_token_idx['unk'] self.tokenize = tokenize - self.Tokenizer = crs_tokenizer + self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, "redial") super().__init__(opt, dpath, resource, restore, save) @@ -339,7 +339,7 @@ def split_text(self, data): each_data = [] for one in each['dialog']: str_text = one['text'] - list_text = self.Tokenizer.tokenize(str_text) + list_text = self.tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['dialog'] = each_data @@ -388,13 +388,13 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.Tokenizer.tokenize(word) + list_word = self.tokenizer.tokenize(word) match_list += 
list_word for movie in dialog['movies']: - list_word = self.Tokenizer.tokenize(movie) + list_word = self.tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = self.Tokenizer.tokenize(entity) + list_word = self.tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 46297c1..9419d09 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -93,7 +93,7 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): self.pad_topic_idx = self.special_token_idx['pad_topic'] self.tokenize = tokenize - self.Tokenizer = crs_tokenizer + self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'tgredial') self.replace_token = opt.get('replace_token', None) @@ -426,7 +426,7 @@ def split_text(self, data): each_dict['conv_id'] = each['conv_id'] for one in each['messages']: str_text = one['text'] - list_text = self.Tokenizer.tokenize(str_text) + list_text = self.tokenizer.tokenize(str_text) one['text'] = list_text each_data.append(one) each_dict['messages'] = each_data @@ -476,15 +476,15 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.Tokenizer.tokenize(word) + list_word = self.tokenizer.tokenize(word) match_list += list_word for movie in dialog['movie']: - list_word = self.Tokenizer.tokenize(movie) + list_word = self.tokenizer.tokenize(movie) match_list += list_word for entity in dialog['entity']: - list_word = self.Tokenizer.tokenize(entity) + list_word = self.tokenizer.tokenize(entity) match_list += list_word match_list = list(set(match_list)) From de8e3f3ec058d7e8dbe5e93c1cc6669d5cbc0abc Mon Sep 17 00:00:00 2001 From: txy77 Date: Tue, 18 Oct 2022 15:52:10 +0800 Subject: [PATCH 31/35] txy77 --- crslab/data/dataset/inspired/inspired.py | 1 + 1 file changed, 1 insertion(+) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index f36a86b..52d7965 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -130,6 +130,7 @@ def _load_raw_data(self): word_embedding = self.generate_word2vec(processed_train_data) logger.info('[Finish generate word2vec]') # build copy_mask + copy_mask = None if self.copy: copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) logger.info('[Finish generate copy_mask]') From 86a23d4282888ba59d8ad9f307e4b10276754234 Mon Sep 17 00:00:00 2001 From: txy77 Date: Mon, 14 Nov 2022 21:32:43 +0800 Subject: [PATCH 32/35] txy77 --- crslab/data/dataset/durecdial/durecdial.py | 18 +++++++------- crslab/data/dataset/gorecdial/gorecdial.py | 18 +++++++------- crslab/data/dataset/inspired/inspired.py | 26 ++++++++++---------- crslab/data/dataset/opendialkg/opendialkg.py | 18 +++++++------- crslab/data/dataset/redial/redial.py | 18 +++++++------- crslab/data/dataset/tgredial/tgredial.py | 18 +++++++------- 6 files changed, 58 insertions(+), 58 deletions(-) diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index b713444..16640a6 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -328,9 +328,9 @@ def split_text(self, data): each_dict = {} each_data = [] for one in each['dialog']: - str_text = one['text'] - list_text = 
self.tokenizer.tokenize(str_text) - one['text'] = list_text + text_str = one['text'] + text_list = self.tokenizer.tokenize(text_str) + one['text'] = text_list each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) @@ -374,14 +374,14 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.tokenizer.tokenize(word) - match_list += list_word + word_list = self.tokenizer.tokenize(word) + match_list += word_list for item in dialog['item']: - list_word = self.tokenizer.tokenize(item) - match_list += list_word + word_list = self.tokenizer.tokenize(item) + match_list += word_list for entity in dialog['entity']: - list_word = self.tokenizer.tokenize(entity) - match_list += list_word + word_list = self.tokenizer.tokenize(entity) + match_list += word_list match_list = list(set(match_list)) for each_word in text: if each_word in match_list: diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index 1410106..ae1120f 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -334,9 +334,9 @@ def split_text(self, data): each_dict = {} each_data = [] for one in each['dialog']: - str_text = one['text'] - list_text = self.tokenizer.tokenize(str_text) - one['text'] = list_text + text_str = one['text'] + text_list = self.tokenizer.tokenize(text_str) + one['text'] = text_list each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) @@ -384,15 +384,15 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.tokenizer.tokenize(word) - match_list += list_word + word_list = self.tokenizer.tokenize(word) + match_list += word_list for movie in dialog['movies']: - list_word = self.tokenizer.tokenize(movie) - match_list += list_word + word_list = self.tokenizer.tokenize(movie) + match_list += word_list for entity in dialog['entity']: - list_word = self.tokenizer.tokenize(entity) - match_list += list_word + word_list = self.tokenizer.tokenize(entity) + match_list += word_list match_list = list(set(match_list)) diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 52d7965..0be8da5 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -334,9 +334,9 @@ def split_text(self, data): each_dict = {} each_data = [] for one in each['dialog']: - str_text = one['text'] - list_text = self.tokenizer.tokenize(str_text) - one['text'] = list_text + text_str = one['text'] + text_list = self.tokenizer.tokenize(text_str) + one['text'] = text_list each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) @@ -384,23 +384,23 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.tokenizer.tokenize(word) - match_list += list_word + word_list = self.tokenizer.tokenize(word) + match_list += word_list for movie in dialog['movies']: - list_word = self.tokenizer.tokenize(movie) - match_list += list_word + word_list = self.tokenizer.tokenize(movie) + match_list += word_list for entity in dialog['entity']: - list_word = self.tokenizer.tokenize(entity) - match_list += list_word + word_list = self.tokenizer.tokenize(entity) + match_list += word_list for genre in dialog['genre']: - list_word = 
self.tokenizer.tokenize(genre) - match_list += list_word + word_list = self.tokenizer.tokenize(genre) + match_list += word_list for people in dialog['people']: - list_word = self.tokenizer.tokenize(people) - match_list += list_word + word_list = self.tokenizer.tokenize(people) + match_list += word_list match_list = list(set(match_list)) diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index fd15109..bfe689f 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -342,9 +342,9 @@ def split_text(self, data): each_dict = {} each_data = [] for one in each['dialog']: - str_text = one['text'] - list_text = self.tokenizer.tokenize(str_text) - one['text'] = list_text + text_str = one['text'] + text_list = self.tokenizer.tokenize(text_str) + one['text'] = text_list each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) @@ -392,15 +392,15 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.tokenizer.tokenize(word) - match_list += list_word + word_list = self.tokenizer.tokenize(word) + match_list += word_list for entity in dialog['entity']: - list_word = self.tokenizer.tokenize(entity) - match_list += list_word + word_list = self.tokenizer.tokenize(entity) + match_list += word_list for item in dialog['item']: - list_word = self.tokenizer.tokenize(item) - match_list += list_word + word_list = self.tokenizer.tokenize(item) + match_list += word_list match_list = list(set(match_list)) diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index b137ba4..ec4d807 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -338,9 +338,9 @@ def split_text(self, data): each_dict = {} each_data = [] for one in each['dialog']: - str_text = one['text'] - list_text = self.tokenizer.tokenize(str_text) - one['text'] = list_text + text_str = one['text'] + text_list = self.tokenizer.tokenize(text_str) + one['text'] = text_list each_data.append(one) each_dict['dialog'] = each_data all_data.append(each_dict) @@ -388,14 +388,14 @@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.tokenizer.tokenize(word) - match_list += list_word + word_list = self.tokenizer.tokenize(word) + match_list += word_list for movie in dialog['movies']: - list_word = self.tokenizer.tokenize(movie) - match_list += list_word + word_list = self.tokenizer.tokenize(movie) + match_list += word_list for entity in dialog['entity']: - list_word = self.tokenizer.tokenize(entity) - match_list += list_word + word_list = self.tokenizer.tokenize(entity) + match_list += word_list match_list = list(set(match_list)) diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index 9419d09..c30bbd8 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -425,9 +425,9 @@ def split_text(self, data): each_data = [] each_dict['conv_id'] = each['conv_id'] for one in each['messages']: - str_text = one['text'] - list_text = self.tokenizer.tokenize(str_text) - one['text'] = list_text + text_str = one['text'] + text_list = self.tokenizer.tokenize(text_str) + one['text'] = text_list each_data.append(one) each_dict['messages'] = each_data each_dict['user_id'] = each['user_id'] @@ -476,16 +476,16 
@@ def generate_copy_mask(self, tok2ind, processed_train_data): match_list = [] text = dialog['text'] for word in dialog['word']: - list_word = self.tokenizer.tokenize(word) - match_list += list_word + word_list = self.tokenizer.tokenize(word) + match_list += word_list for movie in dialog['movie']: - list_word = self.tokenizer.tokenize(movie) - match_list += list_word + word_list = self.tokenizer.tokenize(movie) + match_list += word_list for entity in dialog['entity']: - list_word = self.tokenizer.tokenize(entity) - match_list += list_word + word_list = self.tokenizer.tokenize(entity) + match_list += word_list match_list = list(set(match_list)) From f4c073e4ebc57f350cfa31bc409a9a37e22bc85f Mon Sep 17 00:00:00 2001 From: txy77 Date: Mon, 14 Nov 2022 22:26:10 +0800 Subject: [PATCH 33/35] txy77 --- crslab/data/dataset/gorecdial/gorecdial.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index ae1120f..a891596 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -400,6 +400,8 @@ def generate_copy_mask(self, tok2ind, processed_train_data): if each_word in match_list: token_id = tok2ind[each_word] copy_mask[token_id] = True + + return copy_mask def generate_word2vec(self, processed_train_data): From c7acee89a75cda99a80aa4a66c34814386fc5124 Mon Sep 17 00:00:00 2001 From: txy77 Date: Mon, 14 Nov 2022 22:30:02 +0800 Subject: [PATCH 34/35] txy77 --- crslab/data/dataset/gorecdial/gorecdial.py | 1 - 1 file changed, 1 deletion(-) diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index a891596..b9ea152 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -373,7 +373,6 @@ def generate_tok2ind(self, processed_train_data): if self.tokenize == 'nltk': tok2ind['_split_'] = cnt cnt += 1 - return tok2ind def generate_copy_mask(self, tok2ind, processed_train_data): From 4edd74ae03d58e93a76a27817ce16a532d92d2fa Mon Sep 17 00:00:00 2001 From: txy77 Date: Sat, 24 Jun 2023 16:50:42 +0800 Subject: [PATCH 35/35] txy77 --- LICENSE | 21 -- README.md | 289 +-------------- README_CN.md | 290 --------------- chatgpt_ask.sh | 1 + chatgpt_chat.sh | 1 + config/conversation/gpt2/durecdial.yaml | 6 +- config/conversation/gpt2/gorecdial.yaml | 6 +- config/conversation/gpt2/inspired.yaml | 4 - config/conversation/gpt2/opendialkg.yaml | 6 +- config/conversation/gpt2/redial.yaml | 6 +- config/conversation/gpt2/tgredial.yaml | 4 - .../conversation/transformer/durecdial.yaml | 2 +- .../conversation/transformer/gorecdial.yaml | 2 +- config/conversation/transformer/inspired.yaml | 2 +- .../conversation/transformer/opendialkg.yaml | 2 +- config/conversation/transformer/redial.yaml | 2 +- config/crs/inspired/durecdial.yaml | 7 - config/crs/inspired/gorecdial.yaml | 7 - config/crs/inspired/inspired.yaml | 7 - config/crs/inspired/opendialkg.yaml | 7 - config/crs/inspired/redial.yaml | 7 - config/crs/inspired/tgredial.yaml | 7 - config/crs/kbrd/durecdial.yaml | 2 +- config/crs/kbrd/gorecdial.yaml | 2 +- config/crs/kgsf/durecdial.yaml | 5 +- config/crs/kgsf/gorecdial.yaml | 5 +- config/crs/kgsf/inspired.yaml | 3 +- config/crs/kgsf/opendialkg.yaml | 5 +- config/crs/kgsf/redial.yaml | 5 +- config/crs/kgsf/tgredial.yaml | 5 +- config/crs/ntrd/tgredial.yaml | 5 +- config/crs/redial/opendialkg.yaml | 2 +- config/crs/redial/redial.yaml | 2 +- config/crs/redial/tgredial.yaml | 2 +- 
config/crs/tgredial/durecdial.yaml | 8 +- config/crs/tgredial/gorecdial.yaml | 8 +- config/crs/tgredial/inspired.yaml | 6 - config/crs/tgredial/opendialkg.yaml | 8 +- config/crs/tgredial/redial.yaml | 8 +- config/crs/tgredial/tgredial.yaml | 8 - config/iEvaLM/chatgpt/redial.yaml | 11 + config/policy/conv_bert/tgredial.yaml | 4 - config/policy/mgcg/tgredial.yaml | 2 - config/policy/pmi/tgredial.yaml | 4 +- config/policy/profile_bert/tgredial.yaml | 4 - config/policy/topic_bert/tgredial.yaml | 4 - config/recommendation/bert/durecdial.yaml | 6 +- config/recommendation/bert/gorecdial.yaml | 6 +- config/recommendation/bert/inspired.yaml | 4 - config/recommendation/bert/opendialkg.yaml | 6 +- config/recommendation/bert/redial.yaml | 6 +- config/recommendation/bert/tgredial.yaml | 4 - config/recommendation/gru4rec/durecdial.yaml | 4 +- config/recommendation/gru4rec/gorecdial.yaml | 4 +- config/recommendation/gru4rec/inspired.yaml | 2 - config/recommendation/gru4rec/opendialkg.yaml | 4 +- config/recommendation/gru4rec/redial.yaml | 4 +- config/recommendation/gru4rec/tgredial.yaml | 2 - .../recommendation/popularity/durecdial.yaml | 4 +- .../recommendation/popularity/gorecdial.yaml | 4 +- .../recommendation/popularity/inspired.yaml | 2 - .../recommendation/popularity/opendialkg.yaml | 4 +- config/recommendation/popularity/redial.yaml | 4 +- .../recommendation/popularity/tgredial.yaml | 4 +- config/recommendation/sasrec/durecdial.yaml | 4 +- config/recommendation/sasrec/gorecdial.yaml | 4 +- config/recommendation/sasrec/inspired.yaml | 2 - config/recommendation/sasrec/opendialkg.yaml | 4 +- config/recommendation/sasrec/redial.yaml | 4 +- config/recommendation/sasrec/tgredial.yaml | 2 - config/recommendation/textcnn/durecdial.yaml | 2 +- config/recommendation/textcnn/gorecdial.yaml | 2 +- config/recommendation/textcnn/opendialkg.yaml | 2 +- config/recommendation/textcnn/redial.yaml | 2 +- config/recommendation/textcnn/tgredial.yaml | 2 +- crslab/config/config.py | 2 +- crslab/data/__init__.py | 29 +- crslab/data/dataloader/__init__.py | 1 + crslab/data/dataloader/chatgpt.py | 60 +++ crslab/data/dataloader/inspired.py | 30 +- crslab/data/dataloader/kbrd.py | 31 +- crslab/data/dataloader/kgsf.py | 16 +- crslab/data/dataloader/ntrd.py | 19 +- crslab/data/dataloader/redial.py | 13 +- crslab/data/dataloader/tgredial.py | 40 +- crslab/data/dataset/base.py | 15 +- crslab/data/dataset/durecdial/durecdial.py | 195 ++-------- crslab/data/dataset/durecdial/resources.py | 63 +++- crslab/data/dataset/gorecdial/gorecdial.py | 210 ++--------- crslab/data/dataset/gorecdial/resources.py | 61 ++- crslab/data/dataset/inspired/inspired.py | 213 ++--------- crslab/data/dataset/inspired/resources.py | 59 ++- crslab/data/dataset/opendialkg/opendialkg.py | 218 ++--------- crslab/data/dataset/opendialkg/resources.py | 59 ++- crslab/data/dataset/redial/redial.py | 349 +++++++----------- crslab/data/dataset/redial/resources.py | 59 ++- crslab/data/dataset/tgredial/resources.py | 64 +++- crslab/data/dataset/tgredial/tgredial.py | 276 +++----------- crslab/data/dataset/tokenizer/__init__.py | 6 - crslab/data/dataset/tokenizer/base.py | 14 - crslab/data/dataset/tokenizer/bert.py | 28 -- crslab/data/dataset/tokenizer/gpt2.py | 30 -- crslab/data/dataset/tokenizer/jieba.py | 26 -- crslab/data/dataset/tokenizer/nltk.py | 27 -- crslab/data/dataset/tokenizer/pkuseg.py | 26 -- crslab/evaluator/ask.py | 297 +++++++++++++++ crslab/evaluator/chat.py | 275 ++++++++++++++ crslab/evaluator/embeddings.py | 13 +- crslab/evaluator/rec.py | 5 +- 
crslab/evaluator/standard.py | 49 +-- crslab/evaluator/utils.py | 21 ++ crslab/model/__init__.py | 3 +- crslab/model/conversation/gpt2/gpt2.py | 34 +- .../conversation/transformer/transformer.py | 18 +- crslab/model/crs/__init__.py | 1 + crslab/model/crs/chatgpt/__init__.py | 1 + crslab/model/crs/chatgpt/chatgpt.py | 279 ++++++++++++++ crslab/model/crs/inspired/inspired_conv.py | 36 +- crslab/model/crs/inspired/inspired_rec.py | 18 +- crslab/model/crs/kbrd/kbrd.py | 84 ++--- crslab/model/crs/kgsf/kgsf.py | 163 +++----- crslab/model/crs/kgsf/resources.py | 62 ++++ crslab/model/crs/ntrd/ntrd.py | 220 +++++------ crslab/model/crs/ntrd/resources.py | 62 ++++ crslab/model/crs/redial/modules.py | 53 +-- crslab/model/crs/redial/redial_conv.py | 10 +- crslab/model/crs/redial/redial_rec.py | 3 +- crslab/model/crs/tgredial/tg_conv.py | 34 +- crslab/model/crs/tgredial/tg_policy.py | 26 +- crslab/model/crs/tgredial/tg_rec.py | 25 +- crslab/model/policy/conv_bert/conv_bert.py | 20 +- crslab/model/policy/pmi/pmi.py | 3 +- .../model/policy/profile_bert/profile_bert.py | 20 +- crslab/model/policy/topic_bert/topic_bert.py | 20 +- crslab/model/pretrained_models.py | 64 ++++ crslab/model/recommendation/bert/bert.py | 18 +- crslab/quick_start/__init__.py | 2 +- crslab/quick_start/quick_start.py | 34 +- crslab/system/__init__.py | 4 +- crslab/system/chatgpt.py | 199 ++++++++++ crslab/system/inspired.py | 17 +- crslab/system/kbrd.py | 5 +- crslab/system/kgsf.py | 30 +- crslab/system/ntrd.py | 11 +- crslab/system/redial.py | 20 +- crslab/system/tgredial.py | 66 ++-- crslab/system/utils/functions.py | 8 + requirements.txt | 1 - run_crslab.py | 8 +- 149 files changed, 2533 insertions(+), 2956 deletions(-) delete mode 100644 LICENSE delete mode 100644 README_CN.md create mode 100644 chatgpt_ask.sh create mode 100644 chatgpt_chat.sh create mode 100644 config/iEvaLM/chatgpt/redial.yaml create mode 100644 crslab/data/dataloader/chatgpt.py delete mode 100644 crslab/data/dataset/tokenizer/__init__.py delete mode 100644 crslab/data/dataset/tokenizer/base.py delete mode 100644 crslab/data/dataset/tokenizer/bert.py delete mode 100644 crslab/data/dataset/tokenizer/gpt2.py delete mode 100644 crslab/data/dataset/tokenizer/jieba.py delete mode 100644 crslab/data/dataset/tokenizer/nltk.py delete mode 100644 crslab/data/dataset/tokenizer/pkuseg.py create mode 100644 crslab/evaluator/ask.py create mode 100644 crslab/evaluator/chat.py create mode 100644 crslab/model/crs/chatgpt/__init__.py create mode 100644 crslab/model/crs/chatgpt/chatgpt.py create mode 100644 crslab/model/crs/kgsf/resources.py create mode 100644 crslab/model/crs/ntrd/resources.py create mode 100644 crslab/model/pretrained_models.py create mode 100644 crslab/system/chatgpt.py diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 52a5811..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 RUCAIBox - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md index a36f00f..093c818 100644 --- a/README.md +++ b/README.md @@ -1,286 +1,15 @@ # CRSLab -[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab) -[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases) -[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) -[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939) -[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest) +This branch integrates the new evaluation approach from iEvaLM; it currently supports the ChatGPT model and the ReDial dataset. Evaluation for the remaining models and datasets will be added in the future. Stay tuned! -[Paper](https://arxiv.org/pdf/2101.00939.pdf) | [Docs](https://crslab.readthedocs.io/en/latest/?badge=latest) -| [äø­ę–‡ē‰ˆ](./README_CN.md) +# Quick Start šŸš€ -**CRSLab** is an open-source toolkit for building Conversational Recommender System (CRS). It is developed based on -Python and PyTorch. CRSLab has the following highlights: - -- **Comprehensive benchmark models and datasets**: We have integrated commonly-used 6 datasets and 18 models, including graph neural network and pre-training models such as R-GCN, BERT and GPT-2. We have preprocessed these datasets to support these models, and release for downloading. -- **Extensive and standard evaluation protocols**: We support a series of widely-adopted evaluation protocols for testing and comparing different CRS. -- **General and extensible structure**: We design a general and extensible structure to unify various conversational recommendation datasets and models, in which we integrate various built-in interfaces and functions for quickly development. -- **Easy to get started**: We provide simple yet flexible configuration for new researchers to quickly start in our library. -- **Human-machine interaction interfaces**: We provide flexible human-machine interaction interfaces for researchers to conduct qualitative analysis. - -
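For reference, the iEvaLM evaluation described above is configured through a single new file, `config/iEvaLM/chatgpt/redial.yaml`, which this patch series adds further down. Its contents are reproduced below; the comments are added here for explanation, and the meaning of `turn_num` is inferred rather than stated in the patch:

```yaml
model: ChatGPT
tokenize: nltk
dataset: ReDial
api_key: your_api_key   # replace with your own OpenAI API key before running
turn_num: 5             # interaction turns per dialogue in the iEvaLM-style protocol (inferred)
rec:
  batch_size: 1
conv:
  batch_size: 1
cache_item:
  batch_size: 1000
```

Replacing `your_api_key` with a valid OpenAI key is the only edit required before launching the evaluation scripts.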

- [figure: RecBole v0.1 architecture] Figure 1: The overall framework of CRSLab -

- - - - -- [Installation](#Installation) -- [Quick-Start](#Quick-Start) -- [Models](#Models) -- [Datasets](#Datasets) -- [Performance](#Performance) -- [Releases](#Releases) -- [Contributions](#Contributions) -- [Citing](#Citing) -- [Team](#Team) -- [License](#License) - - - -## Installation - -CRSLab works with the following operating systemsļ¼š - -- Linux -- Windows 10 -- macOS X - -CRSLab requires Python version 3.6 or later. - -CRSLab requires torch version 1.4.0 or later. If you want to use CRSLab with GPU, please ensure that CUDA or CUDAToolkit version is 9.2 or later. Please use the combinations shown in this [Link](https://pytorch-geometric.com/whl/) to ensure the normal operation of PyTorch Geometric. - - - -### Install PyTorch - -Use PyTorch [Locally Installation](https://pytorch.org/get-started/locally/) or [Previous Versions Installation](https://pytorch.org/get-started/previous-versions/) commands to install PyTorch. For example, on Linux and Windows 10: - -```bash -# CUDA 10.1 -pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` - -If you want to use CRSLab with GPU, make sure the following command prints `True` after installation: - -```bash -$ python -c "import torch; print(torch.cuda.is_available())" ->>> True -``` - - - -### Install PyTorch Geometric - -Ensure that at least PyTorch 1.4.0 is installed: - -```bash -$ python -c "import torch; print(torch.__version__)" ->>> 1.6.0 -``` - -Find the CUDA version PyTorch was installed with: - -```bash -$ python -c "import torch; print(torch.version.cuda)" ->>> 10.1 -``` - -Install the relevant packages: - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-geometric -``` - -where `${CUDA}` and `${TORCH}` should be replaced by your specific CUDA version (`cpu`, `cu92`, `cu101`, `cu102`, `cu110`) and PyTorch version (`1.4.0`, `1.5.0`, `1.6.0`, `1.7.0`) respectively. For example, for PyTorch 1.6.0 and CUDA 10.1, type: - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-geometric -``` - - - -### Install CRSLab - -You can install from pip: - -```bash -pip install crslab -``` - -OR install from source: - -```bash -git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab -pip install -e . -``` - - - -## Quick-Start - -With the source code, you can use the provided script for initial usage of our library with cpu by default: - -```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml -``` - -The system will complete the data preprocessing, and training, validation, testing of each model in turn. Finally it will get the evaluation results of specified models. 
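The iEvaLM evaluation introduced by this patch reuses the same `run_crslab.py` entry point. The two helper scripts added later in the series, `chatgpt_chat.sh` and `chatgpt_ask.sh`, reduce to the commands below; the `--mode` flag (added to `run_crslab.py` in this patch) selects the interaction type:

```bash
# free-form chit-chat interaction
python run_crslab.py --config config/iEvaLM/chatgpt/redial.yaml --mode chat

# attribute-based question answering interaction
python run_crslab.py --config config/iEvaLM/chatgpt/redial.yaml --mode ask
```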
- -If you want to save pre-processed datasets and training results of models, you can use the following command: +## Modify your OpenAI API key šŸ”‘ +Please open the config/iEvaLM/chatgpt/redial.yaml file and replace the **your_api_key** with your own OpenAI API key. +## Evaluate šŸ¤” +You can use two types of interaction: attribute-based question answering and free-form chit-chat. ```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system -``` - -In summary, there are following arguments in `run_crslab.py`: - -- `--config` or `-c`: relative path for configuration file(yaml). -- `--gpu` or `-g`: specify GPU id(s) to use, we now support multiple GPUs. Defaults to CPU(-1). -- `--save_data` or `-sd`: save pre-processed dataset. -- `--restore_data` or `-rd`: restore pre-processed dataset from file. -- `--save_system` or `-ss`: save trained system. -- `--restore_system` or `-rs`: restore trained system from file. -- `--debug` or `-d`: use validation dataset to debug your system. -- `--interact` or `-i`: interact with your system instead of training. -- `--tensorboard` or `-tb`: enable tensorboard to monitor train performance. - - - -## Models - -In CRSLab, we unify the task description of conversational recommendation into three sub-tasks, namely recommendation (recommend user-preferred items), conversation (generate proper responses) and policy (select proper interactive action). The recommendation and conversation sub-tasks are the core of a CRS and have been studied in most of works. The policy sub-task is needed by recent works, by which the CRS can interact with users through purposeful strategy. -As the first release version, we have implemented 18 models in the four categories of CRS model, Recommendation model, Conversation model and Policy model. - -| Category | Model | Graph Neural Network? | Pre-training Model? | -| :------------------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: | -| CRS Model | [ReDial](https://arxiv.org/abs/1812.07617)
<br/>[KBRD](https://arxiv.org/abs/1908.05391)<br/>[KGSF](https://arxiv.org/abs/2007.04032)<br/>[TG-ReDial](https://arxiv.org/abs/2010.04125)<br/>[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) | Ɨ<br/>āˆš<br/>āˆš<br/>Ɨ<br/>Ɨ | Ɨ<br/>Ɨ<br/>Ɨ<br/>āˆš<br/>āˆš |
-| Recommendation model | Popularity<br/>[GRU4Rec](https://arxiv.org/abs/1511.06939)<br/>[SASRec](https://arxiv.org/abs/1808.09781)<br/>[TextCNN](https://arxiv.org/abs/1408.5882)<br/>[R-GCN](https://arxiv.org/abs/1703.06103)<br/>[BERT](https://arxiv.org/abs/1810.04805) | Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>āˆš<br/>Ɨ | Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>āˆš |
-| Conversation model | [HERD](https://arxiv.org/abs/1507.04808)<br/>[Transformer](https://arxiv.org/abs/1706.03762)<br/>[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) | Ɨ<br/>Ɨ<br/>Ɨ | Ɨ<br/>Ɨ<br/>āˆš |
-| Policy model | PMI<br/>[MGCG](https://arxiv.org/abs/2005.03954)<br/>[Conv-BERT](https://arxiv.org/abs/2010.04125)<br/>[Topic-BERT](https://arxiv.org/abs/2010.04125)<br/>[Profile-BERT](https://arxiv.org/abs/2010.04125) | Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ | Ɨ<br/>Ɨ<br/>āˆš<br/>āˆš<br/>
āˆš | - -Among them, the four CRS models integrate the recommendation model and the conversation model to improve each other, while others only specify an individual task. - -For Recommendation model and Conversation model, we have respectively implemented the following commonly-used automatic evaluation metrics: - -| Category | Metrics | -| :--------------------: | :----------------------------------------------------------: | -| Recommendation Metrics | Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50} | -| Conversation Metrics | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} | -| Policy Metrics | Accuracy, Hit@{1,3,5} | - - - -## Datasets - -We have collected and preprocessed 6 commonly-used human-annotated datasets, and each dataset was matched with proper KGs as shown below: - -| Dataset | Dialogs | Utterances | Domains | Task Definition | Entity KG | Word KG | -| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: | -| [ReDial](https://redialdata.github.io/website/) | 10,006 | 182,150 | Movie | -- | DBpedia | ConceptNet | -| [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial) | 10,000 | 129,392 | Movie | Topic Guide | CN-DBpedia | HowNet | -| [GoRecDial](https://arxiv.org/abs/1909.03922) | 9,125 | 170,904 | Movie | Action Choice | DBpedia | ConceptNet | -| [DuRecDial](https://arxiv.org/abs/2005.03954) | 10,200 | 156,000 | Movie, Music | Goal Plan | CN-DBpedia | HowNet | -| [INSPIRED](https://github.com/sweetpeach/Inspired) | 1,001 | 35,811 | Movie | Social Strategy | DBpedia | ConceptNet | -| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802 | 91,209 | Movie, Book | Path Generate | DBpedia | ConceptNet | - - - -## Performance - -We have trained and test the integrated models on the TG-Redial dataset, which is split into training, validation and test sets using a ratio of 8:1:1. For each conversation, we start from the first utterance, and generate reply utterances or recommendations in turn by our model. We perform the evaluation on the three sub-tasks. 
- -### Recommendation Task - -| Model | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | -| SASRec | 0.000446 | 0.00134 | 0.0160 | 0.000446 | 0.000576 | 0.00114 | 0.000445 | 0.00075 | 0.00380 | -| TextCNN | 0.00267 | 0.0103 | 0.0236 | 0.00267 | 0.00434 | 0.00493 | 0.00267 | 0.00570 | 0.00860 | -| BERT | 0.00722 | 0.00490 | 0.0281 | 0.00722 | 0.0106 | 0.0124 | 0.00490 | 0.0147 | 0.0239 | -| KBRD | 0.00401 | 0.0254 | 0.0588 | 0.00401 | 0.00891 | 0.0103 | 0.00401 | 0.0127 | 0.0198 | -| KGSF | 0.00535 | **0.0285** | **0.0771** | 0.00535 | 0.0114 | **0.0135** | 0.00535 | **0.0154** | **0.0259** | -| TG-ReDial | **0.00793** | 0.0251 | 0.0524 | **0.00793** | **0.0122** | 0.0134 | **0.00793** | 0.0152 | 0.0211 | - - -### Conversation Task - -| Model | BLEU@1 | BLEU@2 | BLEU@3 | BLEU@4 | Dist@1 | Dist@2 | Dist@3 | Dist@4 | Average | Extreme | Greedy | PPL | -| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: | -| HERD | 0.120 | 0.0141 | 0.00136 | 0.000350 | 0.181 | 0.369 | 0.847 | 1.30 | 0.697 | 0.382 | 0.639 | 472 | -| Transformer | 0.266 | 0.0440 | 0.0145 | 0.00651 | 0.324 | 0.837 | 2.02 | 3.06 | 0.879 | 0.438 | 0.680 | 30.9 | -| GPT2 | 0.0858 | 0.0119 | 0.00377 | 0.0110 | **2.35** | **4.62** | **8.84** | **12.5** | 0.763 | 0.297 | 0.583 | 9.26 | -| KBRD | 0.267 | 0.0458 | 0.0134 | 0.00579 | 0.469 | 1.50 | 3.40 | 4.90 | 0.863 | 0.398 | 0.710 | 52.5 | -| KGSF | **0.383** | **0.115** | **0.0444** | **0.0200** | 0.340 | 0.910 | 3.50 | 6.20 | **0.888** | **0.477** | **0.767** | 50.1 | -| TG-ReDial | 0.125 | 0.0204 | 0.00354 | 0.000803 | 0.881 | 1.75 | 7.00 | 12.0 | 0.810 | 0.332 | 0.598 | **7.41** | - - -### Policy Task - -| Model | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | -| MGCG | 0.591 | 0.818 | 0.883 | 0.591 | 0.680 | 0.683 | 0.591 | 0.712 | 0.729 | -| Conv-BERT | 0.597 | 0.814 | 0.881 | 0.597 | 0.684 | 0.687 | 0.597 | 0.716 | 0.731 | -| Topic-BERT | 0.598 | 0.828 | 0.885 | 0.598 | 0.690 | 0.693 | 0.598 | 0.724 | 0.737 | -| TG-ReDial | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** | - -The above results were obtained from our CRSLab in preliminary experiments. However, these algorithms were implemented and tuned based on our understanding and experiences, which may not achieve their optimal performance. If you could yield a better result for some specific algorithm, please kindly let us know. We will update this table after the results are verified. - -## Releases - -| Releases | Date | Features | -| :------: | :-----------: | :----------: | -| v0.1.1 | 1 / 4 / 2021 | Basic CRSLab | -| v0.1.2 | 3 / 28 / 2021 | CRSLab | - - - -## Contributions - -Please let us know if you encounter a bug or have any suggestions by [filing an issue](https://github.com/RUCAIBox/CRSLab/issues). - -We welcome all contributions from bug fixes to new features and extensions. - -We expect all contributions discussed in the issue tracker and going through PRs. - -We thank the nice contributions through PRs from [@shubaoyu](https://github.com/shubaoyu), [@ToheartZhang](https://github.com/ToheartZhang). 
- - - -## Citing - -If you find CRSLab useful for your research or development, please cite our [Paper](https://arxiv.org/pdf/2101.00939.pdf): - -``` -@article{crslab, - title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System}, - author={Kun Zhou, Xiaolei Wang, Yuanhang Zhou, Chenzhan Shang, Yuan Cheng, Wayne Xin Zhao, Yaliang Li, Ji-Rong Wen}, - year={2021}, - journal={arXiv preprint arXiv:2101.00939} -} -``` - - - -## Team - -**CRSLab** was developed and maintained by [AI Box](http://aibox.ruc.edu.cn/) group in RUC. - - - -## License - -**CRSLab** uses [MIT License](./LICENSE). - +bash chatgpt_chat.sh # free-form chit-chat +bash chatgpt_ask.sh # attribute-based question answering +``` \ No newline at end of file diff --git a/README_CN.md b/README_CN.md deleted file mode 100644 index 8290a4a..0000000 --- a/README_CN.md +++ /dev/null @@ -1,290 +0,0 @@ -# CRSLab - -[![Pypi Latest Version](https://img.shields.io/pypi/v/crslab)](https://pypi.org/project/crslab) -[![Release](https://img.shields.io/github/v/release/rucaibox/crslab.svg)](https://github.com/rucaibox/crslab/releases) -[![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) -[![arXiv](https://img.shields.io/badge/arXiv-CRSLab-%23B21B1B)](https://arxiv.org/abs/2101.00939) -[![Documentation Status](https://readthedocs.org/projects/crslab/badge/?version=latest)](https://crslab.readthedocs.io/en/latest/?badge=latest) - -[č®ŗꖇ](https://arxiv.org/pdf/2101.00939.pdf) | [ę–‡ę”£](https://crslab.readthedocs.io/en/latest/?badge=latest) -| [English Version](./README.md) - -**CRSLab** ę˜Æäø€äøŖē”ØäŗŽęž„å»ŗåƹčƝęŽØ荐ē³»ē»Ÿļ¼ˆCRSļ¼‰ēš„å¼€ęŗå·„具包ļ¼Œå…¶åŸŗäŗŽ PyTorch 实ēŽ°ć€äø»č¦é¢å‘ē ”ē©¶č€…ä½æē”Øļ¼Œå¹¶å…·ęœ‰å¦‚äø‹ē‰¹č‰²ļ¼š - -- **å…Ø面ēš„åŸŗ准ęØ”åž‹å’Œę•°ę®é›†**ļ¼šęˆ‘ä»¬é›†ęˆäŗ†åøøē”Øēš„ 6 äøŖę•°ę®é›†å’Œ 18 äøŖęؔ型ļ¼ŒåŒ…ꋬåŸŗäŗŽå›¾ē„žē»ē½‘ē»œå’Œé¢„č®­ē»ƒęؔ型ļ¼ŒęƔ如 GCNļ¼ŒBERT 和 GPT-2ļ¼›ęˆ‘们čæ˜åÆ¹ę•°ę®é›†čæ›č”Œē›ø关处ē†ä»„ę”Æꌁčæ™äŗ›ęؔ型ļ¼Œå¹¶ęä¾›é¢„处ē†åŽēš„ē‰ˆęœ¬ä¾›å¤§å®¶äø‹č½½ć€‚ -- **å¤§č§„ęØ”ēš„ę ‡å‡†čÆ„ęµ‹**ļ¼šęˆ‘们ę”Æꌁäø€ē³»åˆ—č¢«å¹æę³›č®¤åÆēš„čÆ„ä¼°ę–¹å¼ę„ęµ‹čƕ和ęÆ”č¾ƒäøåŒēš„ CRS怂 -- **通ē”Ø和åÆę‰©å±•ēš„ē»“ęž„**ļ¼šęˆ‘ä»¬č®¾č®”äŗ†é€šē”Ø和åÆę‰©å±•ēš„ē»“ęž„ę„ē»Ÿäø€å„ē§åƹčƝęŽØčę•°ę®é›†å’Œęؔ型ļ¼Œå¹¶é›†ęˆäŗ†å¤šē§å†…ē½®ęŽ„å£å’Œå‡½ę•°ä»„ä¾æäŗŽåæ«é€Ÿå¼€å‘怂 -- **ä¾æę·ēš„ä½æē”Øę–¹ę³•**ļ¼šęˆ‘们äøŗę–°ę‰‹ęä¾›äŗ†ē®€å•č€Œēµę“»ēš„配ē½®ļ¼Œę–¹ä¾æ其åæ«é€ŸåÆåŠØ集ꈐåœØ CRSLab äø­ēš„ęØ”åž‹ć€‚ -- **äŗŗę€§åŒ–ēš„äŗŗęœŗäŗ¤äŗ’ꎄ口**ļ¼šęˆ‘ä»¬ęä¾›äŗ†äŗŗę€§åŒ–ēš„äŗŗęœŗäŗ¤äŗ’ē•Œé¢ļ¼Œä»„供ē ”ē©¶č€…åƹęÆ”å’Œęµ‹čƕäøåŒēš„ęؔ型ē³»ē»Ÿć€‚ - -

- [figure: RecBole v0.1 architecture] Figure: The overall framework of CRSLab -

- - - - -- [å®‰č£…](#å®‰č£…) -- [åæ«é€ŸäøŠę‰‹](#åæ«é€ŸäøŠę‰‹) -- [ęؔ型](#ęؔ型) -- [ę•°ę®é›†](#ę•°ę®é›†) -- [čÆ„ęµ‹ē»“ęžœ](#čÆ„ęµ‹ē»“ęžœ) -- [å‘č”Œē‰ˆęœ¬](#å‘č”Œē‰ˆęœ¬) -- [č“”ēŒ®](#č“”ēŒ®) -- [引ē”Ø](#引ē”Ø) -- [锹ē›®å›¢é˜Ÿ](#锹ē›®å›¢é˜Ÿ) -- [å…č“£å£°ę˜Ž](#å…č“£å£°ę˜Ž) - - - -## å®‰č£… - -CRSLab åÆ仄åœØ仄äø‹å‡ ē§ē³»ē»ŸäøŠčæč”Œļ¼š - -- Linux -- Windows 10 -- macOS X - -CRSLab éœ€č¦åœØ Python 3.6 ęˆ–ę›“é«˜ēš„ēŽÆ境äø‹čæč”Œć€‚ - -CRSLab 要걂 torch ē‰ˆęœ¬åœØ 1.4.0 及仄äøŠļ¼Œå¦‚ęžœä½ ęƒ³åœØ GPU äøŠčæč”Œ CRSLabļ¼ŒčÆ·ē”®äæä½ ēš„ CUDA ē‰ˆęœ¬ęˆ–者 CUDAToolkit ē‰ˆęœ¬åœØ 9.2 及仄äøŠć€‚äøŗäæčƁ PyTorch Geometric åŗ“ēš„ę­£åøøčæč”Œļ¼ŒčÆ·ä½æē”Ø[é“¾ęŽ„](https://pytorch-geometric.com/whl/)ꉀē¤ŗēš„å®‰č£…ę–¹å¼ć€‚ - - - -### å®‰č£… PyTorch - -ä½æē”Ø PyTorch [ęœ¬åœ°å®‰č£…](https://pytorch.org/get-started/locally/)å‘½ä»¤ęˆ–č€…[先前ē‰ˆęœ¬å®‰č£…](https://pytorch.org/get-started/previous-versions/)å‘½ä»¤å®‰č£… PyTorchļ¼ŒęƔ如åœØ Linux 和 Windows äø‹ļ¼š - -```bash -# CUDA 10.1 -pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` - -å®‰č£…å®ŒęˆåŽļ¼Œå¦‚ęžœä½ ęƒ³åœØ GPU äøŠčæč”Œ CRSLabļ¼ŒčÆ·ē”®äæå¦‚äø‹å‘½ä»¤č¾“å‡ŗ`True`ļ¼š - -```bash -$ python -c "import torch; print(torch.cuda.is_available())" ->>> True -``` - - - -### å®‰č£… PyTorch Geometric - -ē”®äæå®‰č£…ēš„ PyTorch ē‰ˆęœ¬č‡³å°‘äøŗ 1.4.0ļ¼š - -```bash -$ python -c "import torch; print(torch.__version__)" ->>> 1.6.0 -``` - -ę‰¾åˆ°å®‰č£…å„½ēš„ PyTorch åƹåŗ”ēš„ CUDA ē‰ˆęœ¬ļ¼š - -```bash -$ python -c "import torch; print(torch.version.cuda)" ->>> 10.1 -``` - -å®‰č£…ē›ø关ēš„包ļ¼š - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html -pip install torch-geometric -``` - -其äø­`${CUDA}`和`${TORCH}`åŗ”ä½æē”Øē”®å®šēš„ CUDA ē‰ˆęœ¬ļ¼ˆ`cpu`ļ¼Œ`cu92`ļ¼Œ`cu101`ļ¼Œ`cu102`ļ¼Œ`cu110`ļ¼‰å’Œ PyTorch ē‰ˆęœ¬ļ¼ˆ`1.4.0`ļ¼Œ`1.5.0`ļ¼Œ`1.6.0`ļ¼Œ`1.7.0`ļ¼‰ę„分别ę›æę¢ć€‚ęƔ如ļ¼ŒåƹäŗŽ PyTorch 1.6.0 和 CUDA 10.1ļ¼Œč¾“å…„ļ¼š - -```bash -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html -pip install torch-geometric -``` - - - -### å®‰č£… CRSLab - -ä½ åÆ仄通čæ‡ pip ę„å®‰č£…ļ¼š - -```bash -pip install crslab -``` - -也åÆ仄通čæ‡ęŗę–‡ä»¶čæ›č”Œčæ›č”Œå®‰č£…ļ¼š - -```bash -git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab -pip install -e . 
-``` - - - -## åæ«é€ŸäøŠę‰‹ - -从 GitHub äø‹č½½ CRSLab 后ļ¼ŒåÆ仄ä½æē”Øęä¾›ēš„č„šęœ¬åæ«é€Ÿčæč”Œå’Œęµ‹čƕļ¼Œé»˜č®¤ä½æē”ØCPUļ¼š - -```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml -``` - -ē³»ē»Ÿå°†ä¾ę¬”å®Œęˆę•°ę®ēš„预处ē†ļ¼Œä»„及各ęؔ块ēš„č®­ē»ƒć€éŖŒčÆå’Œęµ‹čƕļ¼Œå¹¶å¾—åˆ°ęŒ‡å®šēš„ęؔ型čÆ„ęµ‹ē»“ęžœć€‚ - -å¦‚ęžœä½ åøŒęœ›äæå­˜ę•°ę®é¢„处ē†ē»“ęžœäøŽęØ”åž‹č®­ē»ƒē»“ęžœļ¼ŒåÆ仄ä½æē”Ø如äø‹å‘½ä»¤ļ¼š - -```bash -python run_crslab.py --config config/crs/kgsf/redial.yaml --save_data --save_system -``` - -ꀻēš„ę„čÆ“ļ¼Œ`run_crslab.py`ęœ‰å¦‚äø‹å‚ę•°åÆä¾›č°ƒē”Øļ¼š - -- `--config` ꈖ `-c`ļ¼šé…ē½®ę–‡ä»¶ēš„ē›øåƹč·Æ径ļ¼Œä»„ęŒ‡å®ščæč”Œēš„ęؔ型äøŽę•°ę®é›†ć€‚ -- `--gpu` or `-g`ļ¼šęŒ‡å®š GPU idļ¼Œę”ÆęŒå¤š GPUļ¼Œé»˜č®¤ä½æē”Ø CPUļ¼ˆ-1ļ¼‰ć€‚ -- `--save_data` ꈖ `-sd`ļ¼šäæå­˜é¢„处ē†ēš„ę•°ę®ć€‚ -- `--restore_data` ꈖ `-rd`ļ¼šä»Žę–‡ä»¶čƻ取预处ē†ēš„ę•°ę®ć€‚ -- `--save_system` ꈖ `-ss`ļ¼šäæå­˜č®­ē»ƒå„½ēš„ CRS ē³»ē»Ÿć€‚ -- `--restore_system` ꈖ `-rs`ļ¼šä»Žę–‡ä»¶č½½å…„ęå‰č®­ē»ƒå„½ēš„ē³»ē»Ÿć€‚ -- `--debug` ꈖ `-d`ļ¼šē”ØéŖŒčƁ集代ę›æč®­ē»ƒé›†ä»„ę–¹ä¾æ调čÆ•ć€‚ -- `--interact` ꈖ `-i`ļ¼šäøŽä½ ēš„ē³»ē»Ÿčæ›č”ŒåƹčƝäŗ¤äŗ’ļ¼Œč€Œéžčæ›č”Œč®­ē»ƒć€‚ -- `--tensorboard` or `-tb`ļ¼šä½æē”Ø tensorboardX ē»„ä»¶ę„ē›‘굋训ē»ƒč”ØēŽ°ć€‚ - - - -## ęؔ型 - -åœØē¬¬äø€äøŖå‘č”Œē‰ˆäø­ļ¼Œęˆ‘们实ēŽ°äŗ† 4 ē±»å…± 18 äøŖęØ”åž‹ć€‚čæ™é‡Œęˆ‘们将åƹčƝęŽØčä»»åŠ”äø»č¦ę‹†åˆ†ęˆäø‰äøŖ任劔ļ¼šęŽØčä»»åŠ”ļ¼ˆē”ŸęˆęŽØ荐ēš„商品ļ¼‰ļ¼ŒåƹčƝ任劔ļ¼ˆē”ŸęˆåƹčƝēš„回复ļ¼‰å’Œē­–ē•„任劔ļ¼ˆč§„划åƹčƝęŽØ荐ēš„ē­–ē•„ļ¼‰ć€‚å…¶äø­ę‰€ęœ‰ēš„åƹčƝęŽØ荐ē³»ē»Ÿéƒ½å…·ęœ‰åƹčƝ和ęŽØčä»»åŠ”ļ¼Œä»–们ę˜ÆåƹčƝęŽØ荐ē³»ē»Ÿēš„ę øåæƒåŠŸčƒ½ć€‚č€Œē­–ē•„任劔ę˜Æäø€äøŖč¾…åŠ©ä»»åŠ”ļ¼Œå…¶č‡“力äŗŽę›“儽ēš„ęŽ§åˆ¶åƹčƝęŽØ荐ē³»ē»Ÿļ¼ŒåœØäøåŒēš„ęؔ型äø­ēš„实ēŽ°ä¹ŸåÆčƒ½äøåŒļ¼ˆå¦‚ TG-ReDial 采ē”Øäø€äøŖäø»é¢˜é¢„굋ęؔ型ļ¼ŒDuRecDial äø­é‡‡ē”Øäø€äøŖåƹčÆč§„åˆ’ęؔ型ē­‰ļ¼‰ļ¼š - - - -| ē±»åˆ« | ęؔ型 | Graph Neural Network? | Pre-training Model? | -| :------: | :----------------------------------------------------------: | :-----------------------------: | :-----------------------------: | -| CRS ęؔ型 | [ReDial](https://arxiv.org/abs/1812.07617)
<br/>[KBRD](https://arxiv.org/abs/1908.05391)<br/>[KGSF](https://arxiv.org/abs/2007.04032)<br/>[TG-ReDial](https://arxiv.org/abs/2010.04125)<br/>[INSPIRED](https://www.aclweb.org/anthology/2020.emnlp-main.654.pdf) | Ɨ<br/>āˆš<br/>āˆš<br/>Ɨ<br/>Ɨ | Ɨ<br/>Ɨ<br/>Ɨ<br/>āˆš<br/>āˆš |
-| ęŽØ荐ęؔ型 | Popularity<br/>[GRU4Rec](https://arxiv.org/abs/1511.06939)<br/>[SASRec](https://arxiv.org/abs/1808.09781)<br/>[TextCNN](https://arxiv.org/abs/1408.5882)<br/>[R-GCN](https://arxiv.org/abs/1703.06103)<br/>[BERT](https://arxiv.org/abs/1810.04805) | Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>āˆš<br/>Ɨ | Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>āˆš |
-| åƹčƝęؔ型 | [HERD](https://arxiv.org/abs/1507.04808)<br/>[Transformer](https://arxiv.org/abs/1706.03762)<br/>[GPT-2](http://www.persagen.com/files/misc/radford2019language.pdf) | Ɨ<br/>Ɨ<br/>Ɨ | Ɨ<br/>Ɨ<br/>āˆš |
-| ē­–ē•„ęؔ型 | PMI<br/>[MGCG](https://arxiv.org/abs/2005.03954)<br/>[Conv-BERT](https://arxiv.org/abs/2010.04125)<br/>[Topic-BERT](https://arxiv.org/abs/2010.04125)<br/>[Profile-BERT](https://arxiv.org/abs/2010.04125) | Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ<br/>Ɨ | Ɨ<br/>Ɨ<br/>āˆš<br/>āˆš<br/>
āˆš | - - -其äø­ļ¼ŒCRS ęؔ型ę˜Æꌇē›“ęŽ„čžåˆęŽØ荐ęؔ型和åƹčƝęؔ型ļ¼Œä»„ē›øäŗ’增å¼ŗå½¼ę­¤ēš„ę•ˆęžœļ¼Œę•…其内éƒØ往往已ē»åŒ…含äŗ†ęŽØ荐态åƹčƝ和ē­–ē•„ęØ”åž‹ć€‚å…¶ä»–å¦‚ęŽØ荐ęØ”åž‹ć€åƹčƝęØ”åž‹ć€ē­–ē•„ęؔ型往往åŖ关ę³Ø仄äøŠä»»åŠ”äø­ēš„ęŸäø€äøŖ怂 - -ęˆ‘ä»¬åƹäŗŽčæ™å‡ ē±»ęؔ型ļ¼Œęˆ‘们čæ˜åˆ†åˆ«å®žēŽ°äŗ†å¦‚äø‹ēš„č‡ŖåŠØčÆ„ęµ‹ęŒ‡ę ‡ęؔ块ļ¼š - -| ē±»åˆ« | ꌇꠇ | -| :------: | :----------------------------------------------------------: | -| ęŽØ荐ꌇꠇ | Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50} | -| åƹčÆęŒ‡ę ‡ | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} | -| ē­–ē•„ꌇꠇ | Accuracy, Hit@{1,3,5} | - - - - - -## ę•°ę®é›† - -ęˆ‘ä»¬ę”¶é›†äŗ† 6 äøŖåøøē”Øēš„äŗŗå·„ę ‡ę³Øę•°ę®é›†ļ¼Œå¹¶åƹ它们čæ›č”Œäŗ†é¢„处ē†ļ¼ˆåŒ…ę‹¬å¼•å…„外éƒØēŸ„čÆ†å›¾č°±ļ¼‰ļ¼Œä»„čžå…„ē»Ÿäø€ēš„ CRS 任劔äø­ć€‚如äø‹äøŗē›øå…³ę•°ę®é›†ēš„ē»Ÿč®”ę•°ę®ļ¼š - -| Dataset | Dialogs | Utterances | Domains | Task Definition | Entity KG | Word KG | -| :----------------------------------------------------------: | :-----: | :--------: | :----------: | :-------------: | :--------: | :--------: | -| [ReDial](https://redialdata.github.io/website/) | 10,006 | 182,150 | Movie | -- | DBpedia | ConceptNet | -| [TG-ReDial](https://github.com/RUCAIBox/TG-ReDial) | 10,000 | 129,392 | Movie | Topic Guide | CN-DBpedia | HowNet | -| [GoRecDial](https://arxiv.org/abs/1909.03922) | 9,125 | 170,904 | Movie | Action Choice | DBpedia | ConceptNet | -| [DuRecDial](https://arxiv.org/abs/2005.03954) | 10,200 | 156,000 | Movie, Music | Goal Plan | CN-DBpedia | HowNet | -| [INSPIRED](https://github.com/sweetpeach/Inspired) | 1,001 | 35,811 | Movie | Social Strategy | DBpedia | ConceptNet | -| [OpenDialKG](https://github.com/facebookresearch/opendialkg) | 13,802 | 91,209 | Movie, Book | Path Generate | DBpedia | ConceptNet | - - - -## čÆ„ęµ‹ē»“ęžœ - -ęˆ‘ä»¬åœØ TG-ReDial ę•°ę®é›†äøŠåƹęؔ型čæ›č”Œäŗ†č®­ē»ƒå’Œęµ‹čƕļ¼Œčæ™é‡Œęˆ‘ä»¬å°†ę•°ę®é›†ęŒ‰ē…§ 8:1:1 åˆ‡åˆ†ć€‚å…¶äø­åƹäŗŽęÆę”ę•°ę®ļ¼Œęˆ‘们从åƹčƝēš„ē¬¬äø€č½®å¼€å§‹ļ¼Œäø€č½®äø€č½®ēš„čæ›č”ŒęŽØ荐态ē­–ē•„ē”Ÿęˆć€å›žå¤ē”Ÿęˆä»»åŠ”怂äø‹č”Øč®°å½•äŗ†ē›ø关ēš„čÆ„ęµ‹ē»“ęžœć€‚ - -### ęŽØčä»»åŠ” - -| ęؔ型 | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :-------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | :---------: | :--------: | :--------: | -| SASRec | 0.000446 | 0.00134 | 0.0160 | 0.000446 | 0.000576 | 0.00114 | 0.000445 | 0.00075 | 0.00380 | -| TextCNN | 0.00267 | 0.0103 | 0.0236 | 0.00267 | 0.00434 | 0.00493 | 0.00267 | 0.00570 | 0.00860 | -| BERT | 0.00722 | 0.00490 | 0.0281 | 0.00722 | 0.0106 | 0.0124 | 0.00490 | 0.0147 | 0.0239 | -| KBRD | 0.00401 | 0.0254 | 0.0588 | 0.00401 | 0.00891 | 0.0103 | 0.00401 | 0.0127 | 0.0198 | -| KGSF | 0.00535 | **0.0285** | **0.0771** | 0.00535 | 0.0114 | **0.0135** | 0.00535 | **0.0154** | **0.0259** | -| TG-ReDial | **0.00793** | 0.0251 | 0.0524 | **0.00793** | **0.0122** | 0.0134 | **0.00793** | 0.0152 | 0.0211 | - - - -### åƹčƝ任劔 - -| ęؔ型 | BLEU@1 | BLEU@2 | BLEU@3 | BLEU@4 | Dist@1 | Dist@2 | Dist@3 | Dist@4 | Average | Extreme | Greedy | PPL | -| :---------: | :-------: | :-------: | :--------: | :--------: | :------: | :------: | :------: | :------: | :-------: | :-------: | :-------: | :------: | -| HERD | 0.120 | 0.0141 | 0.00136 | 0.000350 | 0.181 | 0.369 | 0.847 | 1.30 | 0.697 | 0.382 | 0.639 | 472 | -| Transformer | 0.266 | 0.0440 | 0.0145 | 0.00651 | 0.324 | 0.837 | 2.02 | 3.06 | 0.879 | 0.438 | 0.680 | 30.9 | -| GPT2 | 0.0858 | 0.0119 | 0.00377 | 0.0110 | **2.35** | **4.62** | **8.84** | **12.5** | 0.763 | 0.297 | 
0.583 | 9.26 | -| KBRD | 0.267 | 0.0458 | 0.0134 | 0.00579 | 0.469 | 1.50 | 3.40 | 4.90 | 0.863 | 0.398 | 0.710 | 52.5 | -| KGSF | **0.383** | **0.115** | **0.0444** | **0.0200** | 0.340 | 0.910 | 3.50 | 6.20 | **0.888** | **0.477** | **0.767** | 50.1 | -| TG-ReDial | 0.125 | 0.0204 | 0.00354 | 0.000803 | 0.881 | 1.75 | 7.00 | 12.0 | 0.810 | 0.332 | 0.598 | **7.41** | - - - -### ē­–ē•„任劔 - -| ęؔ型 | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 | -| :--------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | -| MGCG | 0.591 | 0.818 | 0.883 | 0.591 | 0.680 | 0.683 | 0.591 | 0.712 | 0.729 | -| Conv-BERT | 0.597 | 0.814 | 0.881 | 0.597 | 0.684 | 0.687 | 0.597 | 0.716 | 0.731 | -| Topic-BERT | 0.598 | 0.828 | 0.885 | 0.598 | 0.690 | 0.693 | 0.598 | 0.724 | 0.737 | -| TG-ReDial | **0.600** | **0.830** | **0.893** | **0.600** | **0.693** | **0.696** | **0.600** | **0.727** | **0.741** | - -äøŠčæ°ē»“ęžœę˜Æęˆ‘ä»¬ä½æē”Ø CRSLab čæ›č”Œå®žéŖŒå¾—到ēš„怂ē„¶č€Œļ¼Œčæ™äŗ›ē®—ę³•ę˜Æę ¹ę®ęˆ‘ä»¬ēš„ē»éŖŒå’Œē†č§£ę„实ēŽ°å’Œč°ƒå‚ēš„ļ¼ŒåÆčƒ½čæ˜ę²”ęœ‰č¾¾åˆ°å®ƒä»¬ēš„ęœ€ä½³ę€§čƒ½ć€‚å¦‚ęžœę‚Øčƒ½åœØꟐäøŖ具体ē®—ę³•äøŠå¾—åˆ°ę›“儽ēš„ē»“ęžœļ¼ŒčƷ告ēŸ„ęˆ‘ä»¬ć€‚éŖŒčƁē»“ęžœåŽļ¼Œęˆ‘ä»¬ä¼šę›“ꖰčÆ„č”Ø怂 - -## å‘č”Œē‰ˆęœ¬ - -| ē‰ˆęœ¬å· | å‘č”Œę—„ęœŸ | ē‰¹ę€§ | -| :----: | :-----------: | :----------: | -| v0.1.1 | 1 / 4 / 2021 | Basic CRSLab | -| v0.1.2 | 3 / 28 / 2021 | CRSLab | - - - -## č“”ēŒ® - -å¦‚ęžœę‚Ø遇到错čÆÆęˆ–ęœ‰ä»»ä½•å»ŗč®®ļ¼ŒčƷ通čæ‡ [Issue](https://github.com/RUCAIBox/CRSLab/issues) čæ›č”Œåé¦ˆ - -ęˆ‘ä»¬ę¬¢čæŽå…³äŗŽäæ®å¤é”™čÆÆć€ę·»åŠ ę–°ē‰¹ę€§ēš„ä»»ä½•č“”ēŒ®ć€‚ - -å¦‚ęžœęƒ³č“”ēŒ®ä»£ē ļ¼ŒčƷ先åœØ Issue äø­ęå‡ŗ问题ļ¼Œē„¶åŽå†ę PR怂 - -ęˆ‘ä»¬ę„Ÿč°¢ [@shubaoyu](https://github.com/shubaoyu), [@ToheartZhang](https://github.com/ToheartZhang) 通čæ‡ PR äøŗ锹ē›®č“”ēŒ®ēš„ę–°ē‰¹ę€§ć€‚ - - - -## 引ē”Ø - -å¦‚ęžœä½ č§‰å¾— CRSLab åƹ你ēš„ē§‘ē ”å·„ä½œęœ‰åø®åŠ©ļ¼ŒčƷ引ē”Øęˆ‘ä»¬ēš„[č®ŗꖇ](https://arxiv.org/pdf/2101.00939.pdf)ļ¼š - -``` -@article{crslab, - title={CRSLab: An Open-Source Toolkit for Building Conversational Recommender System}, - author={Kun Zhou, Xiaolei Wang, Yuanhang Zhou, Chenzhan Shang, Yuan Cheng, Wayne Xin Zhao, Yaliang Li, Ji-Rong Wen}, - year={2021}, - journal={arXiv preprint arXiv:2101.00939} -} -``` - - - -## 锹ē›®å›¢é˜Ÿ - -**CRSLab** ē”±äø­å›½äŗŗę°‘å¤§å­¦ [AI Box](http://aibox.ruc.edu.cn/) 小ē»„开发和ē»“ꊤ怂 - - - -## å…č“£å£°ę˜Ž - -**CRSLab** åŸŗäŗŽ [MIT License](./LICENSE) čæ›č”Œå¼€å‘ļ¼Œęœ¬é”¹ē›®ēš„ę‰€ęœ‰ę•°ę®å’Œä»£ē åŖčƒ½č¢«ē”ØäŗŽå­¦ęœÆē›®ēš„怂 diff --git a/chatgpt_ask.sh b/chatgpt_ask.sh new file mode 100644 index 0000000..89f3d35 --- /dev/null +++ b/chatgpt_ask.sh @@ -0,0 +1 @@ +python run_crslab.py --config config/iEvaLM/chatgpt/redial.yaml --mode ask \ No newline at end of file diff --git a/chatgpt_chat.sh b/chatgpt_chat.sh new file mode 100644 index 0000000..4e8478c --- /dev/null +++ b/chatgpt_chat.sh @@ -0,0 +1 @@ +python run_crslab.py --config config/iEvaLM/chatgpt/redial.yaml --mode chat \ No newline at end of file diff --git a/config/conversation/gpt2/durecdial.yaml b/config/conversation/gpt2/durecdial.yaml index c3aa503..5762297 100644 --- a/config/conversation/gpt2/durecdial.yaml +++ b/config/conversation/gpt2/durecdial.yaml @@ -2,17 +2,13 @@ dataset: DuRecDial tokenize: conv: gpt2 -# tokenize path -conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 -# pretrained path -conv_pretrained_path: 'GPT2-chitchat' # 
optim conv: epoch: 1 diff --git a/config/conversation/gpt2/gorecdial.yaml b/config/conversation/gpt2/gorecdial.yaml index 4623943..7dd0a6d 100644 --- a/config/conversation/gpt2/gorecdial.yaml +++ b/config/conversation/gpt2/gorecdial.yaml @@ -2,17 +2,13 @@ dataset: GoRecDial tokenize: conv: gpt2 -# tokenize path -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 -# pretrained path -conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/inspired.yaml b/config/conversation/gpt2/inspired.yaml index f150198..b620579 100644 --- a/config/conversation/gpt2/inspired.yaml +++ b/config/conversation/gpt2/inspired.yaml @@ -2,8 +2,6 @@ dataset: Inspired tokenize: conv: gpt2 -# tokenize path -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 -# pretrained path -conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/opendialkg.yaml b/config/conversation/gpt2/opendialkg.yaml index 091fa9f..bf93096 100644 --- a/config/conversation/gpt2/opendialkg.yaml +++ b/config/conversation/gpt2/opendialkg.yaml @@ -2,17 +2,13 @@ dataset: OpenDialKG tokenize: conv: gpt2 -# tokenize path -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 -# pretrained path -conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/redial.yaml b/config/conversation/gpt2/redial.yaml index 07b4e2b..d6db3d0 100644 --- a/config/conversation/gpt2/redial.yaml +++ b/config/conversation/gpt2/redial.yaml @@ -2,17 +2,13 @@ dataset: ReDial tokenize: conv: gpt2 -# tokenize path -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model conv_model: GPT2 -# pretrained path -conv_pretrained_path: 'gpt2' # optim conv: epoch: 1 diff --git a/config/conversation/gpt2/tgredial.yaml b/config/conversation/gpt2/tgredial.yaml index f747e2e..378d9af 100644 --- a/config/conversation/gpt2/tgredial.yaml +++ b/config/conversation/gpt2/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: conv: gpt2 -# tokenize path -conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model conv_model: GPT2 -# pretrained path -conv_pretrained_path: 'GPT2-chitchat' # optim conv: epoch: 50 diff --git a/config/conversation/transformer/durecdial.yaml b/config/conversation/transformer/durecdial.yaml index c61973c..a9f92fe 100644 --- a/config/conversation/transformer/durecdial.yaml +++ b/config/conversation/transformer/durecdial.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/gorecdial.yaml b/config/conversation/transformer/gorecdial.yaml index c05ddb4..fd578c5 100644 --- a/config/conversation/transformer/gorecdial.yaml +++ b/config/conversation/transformer/gorecdial.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/inspired.yaml b/config/conversation/transformer/inspired.yaml index 90b86d8..fc6ea05 100644 --- 
a/config/conversation/transformer/inspired.yaml +++ b/config/conversation/transformer/inspired.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/opendialkg.yaml b/config/conversation/transformer/opendialkg.yaml index 2971704..208adce 100644 --- a/config/conversation/transformer/opendialkg.yaml +++ b/config/conversation/transformer/opendialkg.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/conversation/transformer/redial.yaml b/config/conversation/transformer/redial.yaml index 25f2ab4..ae1c7c6 100644 --- a/config/conversation/transformer/redial.yaml +++ b/config/conversation/transformer/redial.yaml @@ -5,7 +5,7 @@ tokenize: # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model conv_model: Transformer token_emb_dim: 300 diff --git a/config/crs/inspired/durecdial.yaml b/config/crs/inspired/durecdial.yaml index cc8fa80..6984c40 100644 --- a/config/crs/inspired/durecdial.yaml +++ b/config/crs/inspired/durecdial.yaml @@ -3,9 +3,6 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-chinese' -conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,12 +11,8 @@ scale: 1 # model # rec rec_model: InspiredRec -# pretrained path -rec_pretrained_path: 'bert-base-chinese' # conv conv_model: InspiredConv -# pretrained path -conv_pretrained_path: 'GPT2-chitchat' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/gorecdial.yaml b/config/crs/inspired/gorecdial.yaml index 250edc2..e44800b 100644 --- a/config/crs/inspired/gorecdial.yaml +++ b/config/crs/inspired/gorecdial.yaml @@ -3,9 +3,6 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,12 +11,8 @@ scale: 1 # model # rec rec_model: InspiredRec -# pretrained path -rec_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv -# pretrained path -conv_pretrained_path: 'gpt2' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/inspired.yaml b/config/crs/inspired/inspired.yaml index 0c1887a..a992737 100644 --- a/config/crs/inspired/inspired.yaml +++ b/config/crs/inspired/inspired.yaml @@ -3,9 +3,6 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,12 +11,8 @@ scale: 1 # model # rec rec_model: InspiredRec -# pretrained path -rec_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv -# pretrained path -conv_pretrained_path: 'gpt2' # optim rec: epoch: 1 diff --git a/config/crs/inspired/opendialkg.yaml b/config/crs/inspired/opendialkg.yaml index eb440d9..ff3c13a 100644 --- a/config/crs/inspired/opendialkg.yaml +++ b/config/crs/inspired/opendialkg.yaml @@ -3,9 +3,6 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,12 +11,8 @@ scale: 1 # model # rec rec_model: InspiredRec -# pretrained path 
-conv_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv -# pretrained path -conv_pretrained_path: 'gpt2' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/redial.yaml b/config/crs/inspired/redial.yaml index 48c3112..df25019 100644 --- a/config/crs/inspired/redial.yaml +++ b/config/crs/inspired/redial.yaml @@ -3,9 +3,6 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,12 +11,8 @@ scale: 1 # model # rec rec_model: InspiredRec -# pretrained path -conv_pretrained_path: 'bert-base-uncased' # conv conv_model: InspiredConv -# pretrained path -conv_pretrained_path: 'gpt2' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/inspired/tgredial.yaml b/config/crs/inspired/tgredial.yaml index cfca337..892eb20 100644 --- a/config/crs/inspired/tgredial.yaml +++ b/config/crs/inspired/tgredial.yaml @@ -3,9 +3,6 @@ dataset: TGReDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-chinese' -conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,12 +11,8 @@ scale: 1 # model # rec rec_model: InspiredRec -# pretrained path -rec_pretrained_path: 'bert-base-chinese' # conv conv_model: InspiredConv -# pretrained path -conv_pretrained_path: 'GPT2-chitchat' # embedding: word2vec embedding_dim: 300 use_dropout: False diff --git a/config/crs/kbrd/durecdial.yaml b/config/crs/kbrd/durecdial.yaml index 4a115f2..fb62d7e 100644 --- a/config/crs/kbrd/durecdial.yaml +++ b/config/crs/kbrd/durecdial.yaml @@ -4,7 +4,7 @@ tokenize: jieba # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.1 +scale: 1 # model model: KBRD token_emb_dim: 300 diff --git a/config/crs/kbrd/gorecdial.yaml b/config/crs/kbrd/gorecdial.yaml index be78096..98360a5 100644 --- a/config/crs/kbrd/gorecdial.yaml +++ b/config/crs/kbrd/gorecdial.yaml @@ -4,7 +4,7 @@ tokenize: nltk # dataloader context_truncate: 1024 response_truncate: 1024 -scale: 0.01 +scale: 1 # model model: KBRD token_emb_dim: 300 diff --git a/config/crs/kgsf/durecdial.yaml b/config/crs/kgsf/durecdial.yaml index bd97fe1..481eb1f 100644 --- a/config/crs/kgsf/durecdial.yaml +++ b/config/crs/kgsf/durecdial.yaml @@ -1,11 +1,11 @@ # dataset dataset: DuRecDial tokenize: jieba -embedding: True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 0.01 +scale: 1 # model model: KGSF token_emb_dim: 300 @@ -21,7 +21,6 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 -copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/gorecdial.yaml b/config/crs/kgsf/gorecdial.yaml index 8c90308..b9b1ad1 100644 --- a/config/crs/kgsf/gorecdial.yaml +++ b/config/crs/kgsf/gorecdial.yaml @@ -1,11 +1,11 @@ # dataset dataset: GoRecDial tokenize: nltk -embedding: True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 0.01 +scale: 1 # model model: KGSF token_emb_dim: 300 @@ -21,7 +21,6 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 -copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/inspired.yaml b/config/crs/kgsf/inspired.yaml index 79f9f22..f087ca3 100644 --- a/config/crs/kgsf/inspired.yaml +++ b/config/crs/kgsf/inspired.yaml @@ -1,7 +1,7 @@ # dataset dataset: Inspired tokenize: nltk -embedding: 
True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 @@ -21,7 +21,6 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 -copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/opendialkg.yaml b/config/crs/kgsf/opendialkg.yaml index a3ff91e..5d3df93 100644 --- a/config/crs/kgsf/opendialkg.yaml +++ b/config/crs/kgsf/opendialkg.yaml @@ -1,11 +1,11 @@ # dataset dataset: OpenDialKG tokenize: nltk -embedding: True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 0.01 +scale: 1 # model model: KGSF token_emb_dim: 300 @@ -21,7 +21,6 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 -copy: true # optim pretrain: epoch: 1 diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index 26b1b11..9b235da 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -1,11 +1,11 @@ # dataset dataset: ReDial tokenize: nltk -embedding: True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 1.0 +scale: 1 # model model: KGSF token_emb_dim: 300 @@ -21,7 +21,6 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 -copy: true # optim pretrain: epoch: 3 diff --git a/config/crs/kgsf/tgredial.yaml b/config/crs/kgsf/tgredial.yaml index 6718d19..981dbcf 100644 --- a/config/crs/kgsf/tgredial.yaml +++ b/config/crs/kgsf/tgredial.yaml @@ -1,11 +1,11 @@ # dataset dataset: TGReDial tokenize: pkuseg -embedding: True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 1.0 +scale: 1 # model model: KGSF token_emb_dim: 300 @@ -21,7 +21,6 @@ learn_positional_embeddings: false embeddings_scale: true reduction: false n_positions: 1024 -copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/ntrd/tgredial.yaml b/config/crs/ntrd/tgredial.yaml index 7193b0b..91f38c7 100644 --- a/config/crs/ntrd/tgredial.yaml +++ b/config/crs/ntrd/tgredial.yaml @@ -1,11 +1,11 @@ # dataset dataset: TGReDial tokenize: pkuseg -embedding: True +embedding: word2vec.npy # dataloader context_truncate: 256 response_truncate: 30 -scale: 1.0 +scale: 1 # model model: NTRD token_emb_dim: 300 @@ -24,7 +24,6 @@ n_positions: 1024 gen_loss_weight: 5 n_movies: 62287 replace_token: '[ITEM]' -copy: true # optim pretrain: epoch: 50 diff --git a/config/crs/redial/opendialkg.yaml b/config/crs/redial/opendialkg.yaml index 061616b..7cf0637 100644 --- a/config/crs/redial/opendialkg.yaml +++ b/config/crs/redial/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: # dataloader utterance_truncate: 80 conversation_truncate: 40 -scale: 0.01 +scale: 1 # model # rec rec_model: ReDialRec diff --git a/config/crs/redial/redial.yaml b/config/crs/redial/redial.yaml index c6e848d..2cf569f 100644 --- a/config/crs/redial/redial.yaml +++ b/config/crs/redial/redial.yaml @@ -6,7 +6,7 @@ tokenize: # dataloader utterance_truncate: 80 conversation_truncate: 40 -scale: 0.01 +scale: 1 # model # rec rec_model: ReDialRec diff --git a/config/crs/redial/tgredial.yaml b/config/crs/redial/tgredial.yaml index a87d3c5..fa388f9 100644 --- a/config/crs/redial/tgredial.yaml +++ b/config/crs/redial/tgredial.yaml @@ -6,7 +6,7 @@ tokenize: # dataloader utterance_truncate: 80 conversation_truncate: 40 -scale: 0.01 +scale: 1 # model # rec rec_model: ReDialRec diff --git a/config/crs/tgredial/durecdial.yaml b/config/crs/tgredial/durecdial.yaml index 8085430..0bd4e9e 100644 --- 
a/config/crs/tgredial/durecdial.yaml +++ b/config/crs/tgredial/durecdial.yaml @@ -3,20 +3,14 @@ dataset: DuRecDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-chinese' -conv_tokenize_path: 'GPT2-chitchat' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv -# pretrained path -rec_pretrained_path: 'bert-base-chinese' -conv_pretrained_path: 'GPT2-chitchat' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/gorecdial.yaml b/config/crs/tgredial/gorecdial.yaml index 8e1e982..f766def 100644 --- a/config/crs/tgredial/gorecdial.yaml +++ b/config/crs/tgredial/gorecdial.yaml @@ -3,20 +3,14 @@ dataset: GoRecDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv -# pretrained path -rec_pretrained_path: 'bert-base-uncased' -conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/inspired.yaml b/config/crs/tgredial/inspired.yaml index c7a4a12..f4ace12 100644 --- a/config/crs/tgredial/inspired.yaml +++ b/config/crs/tgredial/inspired.yaml @@ -3,9 +3,6 @@ dataset: Inspired tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 @@ -14,9 +11,6 @@ scale: 1 # model rec_model: TGRec conv_model: TGConv -# pretrained path -rec_pretrained_path: 'bert-base-uncased' -conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/opendialkg.yaml b/config/crs/tgredial/opendialkg.yaml index 15ee74f..563bbcc 100644 --- a/config/crs/tgredial/opendialkg.yaml +++ b/config/crs/tgredial/opendialkg.yaml @@ -3,20 +3,14 @@ dataset: OpenDialKG tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv -# pretrained path -rec_pretrained_path: 'bert-base-uncased' -conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/redial.yaml b/config/crs/tgredial/redial.yaml index ae5f037..60943d8 100644 --- a/config/crs/tgredial/redial.yaml +++ b/config/crs/tgredial/redial.yaml @@ -3,20 +3,14 @@ dataset: ReDial tokenize: rec: bert conv: gpt2 -# tokenize path -rec_tokenize_path: 'bert-base-uncased' -conv_tokenize_path: 'gpt2' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TGRec conv_model: TGConv -# pretrained path -rec_pretrained_path: 'bert-base-uncased' -conv_pretrained_path: 'gpt2' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/crs/tgredial/tgredial.yaml b/config/crs/tgredial/tgredial.yaml index 3a8081b..0e1c956 100644 --- a/config/crs/tgredial/tgredial.yaml +++ b/config/crs/tgredial/tgredial.yaml @@ -4,10 +4,6 @@ tokenize: rec: bert conv: gpt2 policy: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' -conv_tokenize_path: 'GPT2-chitchat' -policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 
response_truncate: 30 @@ -17,10 +13,6 @@ scale: 1 rec_model: TGRec conv_model: TGConv policy_model: TGPolicy -# pretrained path -rec_pretrained_path: 'bert-base-chinese' -conv_pretrained_path: 'GPT2-chitchat' -policy_pretrained_path: 'bert-base-chinese' hidden_dropout_prob: 0.2 initializer_range: 0.02 hidden_size: 50 diff --git a/config/iEvaLM/chatgpt/redial.yaml b/config/iEvaLM/chatgpt/redial.yaml new file mode 100644 index 0000000..01367fc --- /dev/null +++ b/config/iEvaLM/chatgpt/redial.yaml @@ -0,0 +1,11 @@ +model: ChatGPT +tokenize: nltk +dataset: ReDial +api_key: your_api_key +turn_num: 5 +rec: + batch_size: 1 +conv: + batch_size: 1 +cache_item: + batch_size: 1000 \ No newline at end of file diff --git a/config/policy/conv_bert/tgredial.yaml b/config/policy/conv_bert/tgredial.yaml index 393c0b6..78e5c58 100644 --- a/config/policy/conv_bert/tgredial.yaml +++ b/config/policy/conv_bert/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: policy: bert -# tokenize path -policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model policy_model: ConvBERT -# pretrained path -policy_pretrained_path: 'bert-base-chinese' # optim policy: epoch: 50 diff --git a/config/policy/mgcg/tgredial.yaml b/config/policy/mgcg/tgredial.yaml index 5bb42f0..7cd78ec 100644 --- a/config/policy/mgcg/tgredial.yaml +++ b/config/policy/mgcg/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: policy: bert -# tokenize path -policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/policy/pmi/tgredial.yaml b/config/policy/pmi/tgredial.yaml index 7658e86..f6ba96c 100644 --- a/config/policy/pmi/tgredial.yaml +++ b/config/policy/pmi/tgredial.yaml @@ -2,13 +2,11 @@ dataset: TGReDial tokenize: policy: bert -# tokenize path -policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model policy_model: PMI # optim diff --git a/config/policy/profile_bert/tgredial.yaml b/config/policy/profile_bert/tgredial.yaml index 9a35942..39f9ae8 100644 --- a/config/policy/profile_bert/tgredial.yaml +++ b/config/policy/profile_bert/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: policy: bert -# tokenize path -policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model policy_model: ProfileBERT -# pretrained path -policy_pretrained_path: 'bert-base-chinese' n_sent: 10 # optim policy: diff --git a/config/policy/topic_bert/tgredial.yaml b/config/policy/topic_bert/tgredial.yaml index 6884468..c3a5253 100644 --- a/config/policy/topic_bert/tgredial.yaml +++ b/config/policy/topic_bert/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: policy: bert -# tokenize path -policy_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model policy_model: TopicBERT -# pretrained path -policy_pretrained_path: 'bert-base-chinese' # optim policy: epoch: 50 diff --git a/config/recommendation/bert/durecdial.yaml b/config/recommendation/bert/durecdial.yaml index b0edb13..051e8d1 100644 --- a/config/recommendation/bert/durecdial.yaml +++ b/config/recommendation/bert/durecdial.yaml @@ -2,17 +2,13 @@ dataset: DuRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 
response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT -# pretrained path -rec_pretrained_path: 'bert-base-chinese' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/gorecdial.yaml b/config/recommendation/bert/gorecdial.yaml index 727f15e..cdbb30e 100644 --- a/config/recommendation/bert/gorecdial.yaml +++ b/config/recommendation/bert/gorecdial.yaml @@ -2,17 +2,13 @@ dataset: GoRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT -# pretrained path -rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/inspired.yaml b/config/recommendation/bert/inspired.yaml index b492f0d..d2d9d18 100644 --- a/config/recommendation/bert/inspired.yaml +++ b/config/recommendation/bert/inspired.yaml @@ -2,8 +2,6 @@ dataset: Inspired tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model rec_model: BERT -# pretrained path -rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/opendialkg.yaml b/config/recommendation/bert/opendialkg.yaml index 7972304..ae9f12f 100644 --- a/config/recommendation/bert/opendialkg.yaml +++ b/config/recommendation/bert/opendialkg.yaml @@ -2,17 +2,13 @@ dataset: OpenDialKG tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT -# pretrained path -rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/redial.yaml b/config/recommendation/bert/redial.yaml index 696bd96..48d664c 100644 --- a/config/recommendation/bert/redial.yaml +++ b/config/recommendation/bert/redial.yaml @@ -2,17 +2,13 @@ dataset: ReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: BERT -# pretrained path -rec_pretrained_path: 'bert-base-uncased' # optim rec: epoch: 1 diff --git a/config/recommendation/bert/tgredial.yaml b/config/recommendation/bert/tgredial.yaml index 679667d..717a2ab 100644 --- a/config/recommendation/bert/tgredial.yaml +++ b/config/recommendation/bert/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 @@ -11,8 +9,6 @@ item_truncate: 100 scale: 1 # model rec_model: BERT -# pretrained path -rec_pretrained_path: 'bert-base-chinese' # optim rec: epoch: 20 diff --git a/config/recommendation/gru4rec/durecdial.yaml b/config/recommendation/gru4rec/durecdial.yaml index ca29808..aa73472 100644 --- a/config/recommendation/gru4rec/durecdial.yaml +++ b/config/recommendation/gru4rec/durecdial.yaml @@ -2,13 +2,11 @@ dataset: DuRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/gorecdial.yaml b/config/recommendation/gru4rec/gorecdial.yaml index 814ebd6..5d48dd5 100644 --- 
a/config/recommendation/gru4rec/gorecdial.yaml +++ b/config/recommendation/gru4rec/gorecdial.yaml @@ -2,13 +2,11 @@ dataset: GoRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.1 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/inspired.yaml b/config/recommendation/gru4rec/inspired.yaml index f2508db..8ef81fe 100644 --- a/config/recommendation/gru4rec/inspired.yaml +++ b/config/recommendation/gru4rec/inspired.yaml @@ -2,8 +2,6 @@ dataset: Inspired tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/gru4rec/opendialkg.yaml b/config/recommendation/gru4rec/opendialkg.yaml index 8cd8865..3aebaf8 100644 --- a/config/recommendation/gru4rec/opendialkg.yaml +++ b/config/recommendation/gru4rec/opendialkg.yaml @@ -2,13 +2,11 @@ dataset: OpenDialKG tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/redial.yaml b/config/recommendation/gru4rec/redial.yaml index 1af0f3f..b4f9c75 100644 --- a/config/recommendation/gru4rec/redial.yaml +++ b/config/recommendation/gru4rec/redial.yaml @@ -2,13 +2,11 @@ dataset: ReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: GRU4REC gru_hidden_size: 50 diff --git a/config/recommendation/gru4rec/tgredial.yaml b/config/recommendation/gru4rec/tgredial.yaml index 8bac5d8..7caf3d0 100644 --- a/config/recommendation/gru4rec/tgredial.yaml +++ b/config/recommendation/gru4rec/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/popularity/durecdial.yaml b/config/recommendation/popularity/durecdial.yaml index cb05935..ca8759a 100644 --- a/config/recommendation/popularity/durecdial.yaml +++ b/config/recommendation/popularity/durecdial.yaml @@ -2,13 +2,11 @@ dataset: DuRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/gorecdial.yaml b/config/recommendation/popularity/gorecdial.yaml index 187d552..e6334b0 100644 --- a/config/recommendation/popularity/gorecdial.yaml +++ b/config/recommendation/popularity/gorecdial.yaml @@ -2,13 +2,11 @@ dataset: GoRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/inspired.yaml b/config/recommendation/popularity/inspired.yaml index 20714e8..4c9a821 100644 --- a/config/recommendation/popularity/inspired.yaml +++ b/config/recommendation/popularity/inspired.yaml @@ -2,8 +2,6 @@ dataset: Inspired tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 
30 diff --git a/config/recommendation/popularity/opendialkg.yaml b/config/recommendation/popularity/opendialkg.yaml index cbc319f..3ff691b 100644 --- a/config/recommendation/popularity/opendialkg.yaml +++ b/config/recommendation/popularity/opendialkg.yaml @@ -2,13 +2,11 @@ dataset: OpenDialKG tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/redial.yaml b/config/recommendation/popularity/redial.yaml index 2265d2c..27ad004 100644 --- a/config/recommendation/popularity/redial.yaml +++ b/config/recommendation/popularity/redial.yaml @@ -2,13 +2,11 @@ dataset: ReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/popularity/tgredial.yaml b/config/recommendation/popularity/tgredial.yaml index 973d1e4..95b9247 100644 --- a/config/recommendation/popularity/tgredial.yaml +++ b/config/recommendation/popularity/tgredial.yaml @@ -2,13 +2,11 @@ dataset: TGReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: Popularity # optim diff --git a/config/recommendation/sasrec/durecdial.yaml b/config/recommendation/sasrec/durecdial.yaml index ac1fcfe..2a32693 100644 --- a/config/recommendation/sasrec/durecdial.yaml +++ b/config/recommendation/sasrec/durecdial.yaml @@ -2,13 +2,11 @@ dataset: DuRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/gorecdial.yaml b/config/recommendation/sasrec/gorecdial.yaml index 1c88d58..98b6dfe 100644 --- a/config/recommendation/sasrec/gorecdial.yaml +++ b/config/recommendation/sasrec/gorecdial.yaml @@ -2,13 +2,11 @@ dataset: GoRecDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/inspired.yaml b/config/recommendation/sasrec/inspired.yaml index 7c1ab9f..d79ff24 100644 --- a/config/recommendation/sasrec/inspired.yaml +++ b/config/recommendation/sasrec/inspired.yaml @@ -2,8 +2,6 @@ dataset: Inspired tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/sasrec/opendialkg.yaml b/config/recommendation/sasrec/opendialkg.yaml index c380597..0efb55e 100644 --- a/config/recommendation/sasrec/opendialkg.yaml +++ b/config/recommendation/sasrec/opendialkg.yaml @@ -2,13 +2,11 @@ dataset: OpenDialKG tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/redial.yaml b/config/recommendation/sasrec/redial.yaml index 1809617..2dea48e 100644 --- a/config/recommendation/sasrec/redial.yaml 
+++ b/config/recommendation/sasrec/redial.yaml @@ -2,13 +2,11 @@ dataset: ReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-uncased' # dataloader context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: SASREC hidden_dropout_prob: 0.2 diff --git a/config/recommendation/sasrec/tgredial.yaml b/config/recommendation/sasrec/tgredial.yaml index 0751b7c..9888002 100644 --- a/config/recommendation/sasrec/tgredial.yaml +++ b/config/recommendation/sasrec/tgredial.yaml @@ -2,8 +2,6 @@ dataset: TGReDial tokenize: rec: bert -# tokenize path -rec_tokenize_path: 'bert-base-chinese' # dataloader context_truncate: 256 response_truncate: 30 diff --git a/config/recommendation/textcnn/durecdial.yaml b/config/recommendation/textcnn/durecdial.yaml index 3040132..c244237 100644 --- a/config/recommendation/textcnn/durecdial.yaml +++ b/config/recommendation/textcnn/durecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/gorecdial.yaml b/config/recommendation/textcnn/gorecdial.yaml index 6791031..f2fc0d5 100644 --- a/config/recommendation/textcnn/gorecdial.yaml +++ b/config/recommendation/textcnn/gorecdial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/opendialkg.yaml b/config/recommendation/textcnn/opendialkg.yaml index 88eba55..5f9972b 100644 --- a/config/recommendation/textcnn/opendialkg.yaml +++ b/config/recommendation/textcnn/opendialkg.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/redial.yaml b/config/recommendation/textcnn/redial.yaml index 0466f51..6c43c0a 100644 --- a/config/recommendation/textcnn/redial.yaml +++ b/config/recommendation/textcnn/redial.yaml @@ -6,7 +6,7 @@ tokenize: context_truncate: 256 response_truncate: 30 item_truncate: 100 -scale: 0.01 +scale: 1 # model rec_model: TextCNN hidden_dropout_prob: 0.2 diff --git a/config/recommendation/textcnn/tgredial.yaml b/config/recommendation/textcnn/tgredial.yaml index 0de66df..0d5c708 100644 --- a/config/recommendation/textcnn/tgredial.yaml +++ b/config/recommendation/textcnn/tgredial.yaml @@ -1,7 +1,7 @@ # dataset dataset: TGReDial tokenize: - rec: jieba + rec: sougou # dataloader context_truncate: 256 response_truncate: 30 diff --git a/crslab/config/config.py b/crslab/config/config.py index 3ba6fe7..1bcd061 100644 --- a/crslab/config/config.py +++ b/crslab/config/config.py @@ -38,7 +38,7 @@ def __init__(self, config_file, gpu='-1', debug=False): # gpu os.environ['CUDA_VISIBLE_DEVICES'] = gpu if gpu != '-1': - self.opt['gpu'] = [i for i in range(len(gpu.split(',')))] + self.opt['gpu'] = [i for i in gpu.split(',')] else: self.opt['gpu'] = [-1] # dataset diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index cb5fad5..299d4e5 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -22,15 +22,6 @@ from crslab.data.dataloader import * from crslab.data.dataset import * -from crslab.data.dataset.tokenizer import * - -tokenizer_register_table = { - 'nltk': NltkTokenizer, - 'jieba': JiebaTokenizer, - 'gpt2': Gpt2Tokenizer, - 'bert': BertTokenizer, - 'pkuseg': 
PkusegTokenizer -} dataset_register_table = { 'ReDial': ReDialDataset, @@ -75,18 +66,12 @@ 'ProfileBERT': TGReDialDataLoader, 'MGCG': TGReDialDataLoader, 'PMI': TGReDialDataLoader, - 'NTRD': NTRDDataLoader + 'NTRD': NTRDDataLoader, + 'ChatGPT': ChatGPTDataLoader } -def get_tokenizer(tokenize, path=None) -> BaseTokenizer: - """ - get tokenizer from opt - """ - return tokenizer_register_table[tokenize](path) - - -def get_dataset(opt, tokenize, crs_tokenizer, restore, save) -> BaseDataset: +def get_dataset(opt, tokenize, restore, save) -> BaseDataset: """get and process dataset Args: @@ -101,10 +86,9 @@ def get_dataset(opt, tokenize, crs_tokenizer, restore, save) -> BaseDataset: """ dataset = opt['dataset'] if dataset in dataset_register_table: - return dataset_register_table[dataset](opt, tokenize, crs_tokenizer, restore, save) + return dataset_register_table[dataset](opt, tokenize, restore, save) else: - raise NotImplementedError( - f'The dataloader [{dataset}] has not been implemented') + raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') def get_dataloader(opt, dataset, vocab) -> BaseDataLoader: @@ -123,5 +107,4 @@ def get_dataloader(opt, dataset, vocab) -> BaseDataLoader: if model_name in dataloader_register_table: return dataloader_register_table[model_name](opt, dataset, vocab) else: - raise NotImplementedError( - f'The dataloader [{model_name}] has not been implemented') + raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented') diff --git a/crslab/data/dataloader/__init__.py b/crslab/data/dataloader/__init__.py index 7b4ce12..dc83a89 100644 --- a/crslab/data/dataloader/__init__.py +++ b/crslab/data/dataloader/__init__.py @@ -5,3 +5,4 @@ from .redial import ReDialDataLoader from .tgredial import TGReDialDataLoader from .ntrd import NTRDDataLoader +from .chatgpt import ChatGPTDataLoader diff --git a/crslab/data/dataloader/chatgpt.py b/crslab/data/dataloader/chatgpt.py new file mode 100644 index 0000000..41dd82b --- /dev/null +++ b/crslab/data/dataloader/chatgpt.py @@ -0,0 +1,60 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +from tqdm import tqdm + +from crslab.data.dataloader.base import BaseDataLoader + +class ChatGPTDataLoader(BaseDataLoader): + + def __init__(self, opt, dataset, vocab): + super().__init__(opt, dataset) + + def process_fn(self): + augment_dataset = [] + for conv_dict in tqdm(self.dataset): + if conv_dict['role'] == 'Recommender' and len(conv_dict['items']) > 0: + augment_conv_dict = { + 'dialog_id': conv_dict['dialog_id'], + 'role': conv_dict['role'], + 'entity': conv_dict['context_entities'], + 'context': conv_dict['context'], + 'item': conv_dict['items'] + } + augment_dataset.append(augment_conv_dict) + return augment_dataset + + def batchify(self, batch): + batch_dialog_id = [] + batch_role = [] + batch_context = [] + batch_movies = [] + batch_entities = [] + + for conv_dict in batch: + batch_dialog_id.append(conv_dict['dialog_id']) + batch_role.append(conv_dict['role']) + batch_context.append(conv_dict['context']) + batch_movies.append(conv_dict['item']) + batch_entities.append(conv_dict['entity']) + + return { + 'dialog_id': batch_dialog_id, + 'role': batch_role, + 'context': batch_context, + 'item': batch_movies, + 'entity': batch_entities + } + + def rec_process_fn(self): + return self.process_fn() + + def rec_batchify(self, batch): + return self.batchify(batch) + + def conv_process_fn(self): + return self.process_fn() + + def conv_batchify(self, batch): + return
self.batchify(batch) \ No newline at end of file diff --git a/crslab/data/dataloader/inspired.py b/crslab/data/dataloader/inspired.py index a5983e6..3881113 100644 --- a/crslab/data/dataloader/inspired.py +++ b/crslab/data/dataloader/inspired.py @@ -5,11 +5,11 @@ from copy import deepcopy import torch -from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt, - padded_tensor, truncate) from tqdm import tqdm +from crslab.data.dataloader.base import BaseDataLoader +from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt + class InspiredDataLoader(BaseDataLoader): """Dataloader for model Inspired. @@ -56,20 +56,20 @@ def __init__(self, opt, dataset, vocab): super().__init__(opt, dataset) self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.unk_token_idx = vocab['special_token_idx']['unk'] - self.conv_bos_id = vocab['special_token_idx']['start'] - self.cls_id = vocab['special_token_idx']['start'] - self.sep_id = vocab['special_token_idx']['end'] - if 'sent_split' in vocab['special_token_idx']: - self.sent_split_idx = vocab['special_token_idx']['sent_split'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.unk_token_idx = vocab['unk'] + self.conv_bos_id = vocab['start'] + self.cls_id = vocab['start'] + self.sep_id = vocab['end'] + if 'sent_split' in vocab: + self.sent_split_idx = vocab['sent_split'] else: - self.sent_split_idx = vocab['special_token_idx']['end'] + self.sent_split_idx = vocab['end'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] + self.pad_entity_idx = vocab['pad_entity'] + self.pad_word_idx = vocab['pad_word'] self.tok2ind = vocab['tok2ind'] self.ind2tok = vocab['ind2tok'] diff --git a/crslab/data/dataloader/kbrd.py b/crslab/data/dataloader/kbrd.py index ee90381..182044d 100644 --- a/crslab/data/dataloader/kbrd.py +++ b/crslab/data/dataloader/kbrd.py @@ -8,11 +8,11 @@ # @Email : wxl1999@foxmail.com import torch -from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt, - padded_tensor, truncate) from tqdm import tqdm +from crslab.data.dataloader.base import BaseDataLoader +from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt + class KBRDDataLoader(BaseDataLoader): """Dataloader for model KBRD. 
@@ -45,10 +45,10 @@ def __init__(self, opt, dataset, vocab): """ super().__init__(opt, dataset) - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.pad_entity_idx = vocab['pad_entity'] self.context_truncate = opt.get('context_truncate', None) self.response_truncate = opt.get('response_truncate', None) self.entity_truncate = opt.get('entity_truncate', None) @@ -58,20 +58,27 @@ def rec_process_fn(self): for conv_dict in tqdm(self.dataset): if conv_dict['role'] == 'Recommender': for movie in conv_dict['items']: - augment_conv_dict = {'context_entities': conv_dict['context_entities'], 'item': movie} + augment_conv_dict = { + 'role': conv_dict['role'], + 'context': conv_dict['context'], + 'item': movie + } augment_dataset.append(augment_conv_dict) return augment_dataset def rec_batchify(self, batch): - batch_context_entities = [] + batch_role = [] + batch_context = [] batch_movies = [] for conv_dict in batch: - batch_context_entities.append(conv_dict['context_entities']) + batch_role.append(conv_dict['role']) + batch_context.append(conv_dict['context']) batch_movies.append(conv_dict['item']) return { - "context_entities": batch_context_entities, - "item": torch.tensor(batch_movies, dtype=torch.long) + "role": batch_role, + 'context': batch_context, + "item": batch_movies } def conv_process_fn(self, *args, **kwargs): diff --git a/crslab/data/dataloader/kgsf.py b/crslab/data/dataloader/kgsf.py index c43e933..6bbcac4 100644 --- a/crslab/data/dataloader/kgsf.py +++ b/crslab/data/dataloader/kgsf.py @@ -10,11 +10,11 @@ from copy import deepcopy import torch -from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import (add_start_end_token_idx, get_onehot, - merge_utt, padded_tensor, truncate) from tqdm import tqdm +from crslab.data.dataloader.base import BaseDataLoader +from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, get_onehot, truncate, merge_utt + class KGSFDataLoader(BaseDataLoader): """Dataloader for model KGSF. 
@@ -52,11 +52,11 @@ def __init__(self, opt, dataset, vocab): """ super().__init__(opt, dataset) self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.pad_entity_idx = vocab['pad_entity'] + self.pad_word_idx = vocab['pad_word'] self.context_truncate = opt.get('context_truncate', None) self.response_truncate = opt.get('response_truncate', None) self.entity_truncate = opt.get('entity_truncate', None) diff --git a/crslab/data/dataloader/ntrd.py b/crslab/data/dataloader/ntrd.py index 603be05..bbf1e80 100644 --- a/crslab/data/dataloader/ntrd.py +++ b/crslab/data/dataloader/ntrd.py @@ -5,12 +5,11 @@ from copy import deepcopy import torch -from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import (add_start_end_token_idx, get_onehot, - merge_utt, merge_utt_replace, - padded_tensor, truncate) from tqdm import tqdm +from crslab.data.dataloader.base import BaseDataLoader +from crslab.data.dataloader.utils import add_start_end_token_idx, merge_utt_replace, padded_tensor, get_onehot, truncate, merge_utt + class NTRDDataLoader(BaseDataLoader): def __init__(self, opt, dataset, vocab): @@ -24,11 +23,11 @@ def __init__(self, opt, dataset, vocab): """ super().__init__(opt, dataset) self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.pad_entity_idx = vocab['pad_entity'] + self.pad_word_idx = vocab['pad_word'] self.context_truncate = opt.get('context_truncate', None) self.response_truncate = opt.get('response_truncate', None) self.entity_truncate = opt.get('entity_truncate', None) @@ -114,4 +113,4 @@ def conv_batchify(self, batch): padded_tensor(batch_all_movies, self.pad_entity_idx, pad_tail=False)) def policy_batchify(self, *args, **kwargs): - pass + pass \ No newline at end of file diff --git a/crslab/data/dataloader/redial.py b/crslab/data/dataloader/redial.py index 84c47b6..6cd1289 100644 --- a/crslab/data/dataloader/redial.py +++ b/crslab/data/dataloader/redial.py @@ -11,10 +11,11 @@ from copy import copy import torch -from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import get_onehot, padded_tensor, truncate from tqdm import tqdm +from crslab.data.dataloader.base import BaseDataLoader +from crslab.data.dataloader.utils import padded_tensor, get_onehot, truncate + movie_pattern = re.compile(r'^@\d{5,6}$') @@ -54,10 +55,10 @@ def __init__(self, opt, dataset, vocab): super().__init__(opt, dataset) self.ind2tok = vocab['ind2tok'] self.n_entity = vocab['n_entity'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.unk_token_idx = vocab['special_token_idx']['unk'] + self.pad_token_idx = vocab['pad'] + 
self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.unk_token_idx = vocab['unk'] self.item_token_idx = vocab['vocab_size'] self.conversation_truncate = self.opt.get('conversation_truncate', None) self.utterance_truncate = self.opt.get('utterance_truncate', None) diff --git a/crslab/data/dataloader/tgredial.py b/crslab/data/dataloader/tgredial.py index e8457ec..bef2ca7 100644 --- a/crslab/data/dataloader/tgredial.py +++ b/crslab/data/dataloader/tgredial.py @@ -11,11 +11,11 @@ from copy import deepcopy import torch -from crslab.data.dataloader.base import BaseDataLoader -from crslab.data.dataloader.utils import (add_start_end_token_idx, merge_utt, - padded_tensor, truncate) from tqdm import tqdm +from crslab.data.dataloader.base import BaseDataLoader +from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt + class TGReDialDataLoader(BaseDataLoader): """Dataloader for model TGReDial. @@ -65,26 +65,26 @@ def __init__(self, opt, dataset, vocab): self.n_entity = vocab['n_entity'] self.item_size = self.n_entity - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.unk_token_idx = vocab['special_token_idx']['unk'] - self.conv_bos_id = vocab['special_token_idx']['start'] - self.cls_id = vocab['special_token_idx']['start'] - self.sep_id = vocab['special_token_idx']['end'] - if 'sent_split' in vocab['special_token_idx']: - self.sent_split_idx = vocab['special_token_idx']['sent_split'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.unk_token_idx = vocab['unk'] + self.conv_bos_id = vocab['start'] + self.cls_id = vocab['start'] + self.sep_id = vocab['end'] + if 'sent_split' in vocab: + self.sent_split_idx = vocab['sent_split'] else: - self.sent_split_idx = vocab['special_token_idx']['end'] - if 'word_split' in vocab['special_token_idx']: - self.word_split_idx = vocab['special_token_idx']['word_split'] + self.sent_split_idx = vocab['end'] + if 'word_split' in vocab: + self.word_split_idx = vocab['word_split'] else: - self.word_split_idx = vocab['special_token_idx']['end'] + self.word_split_idx = vocab['end'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] - if 'pad_topic' in vocab['special_token_idx']: - self.pad_topic_idx = vocab['special_token_idx']['pad_topic'] + self.pad_entity_idx = vocab['pad_entity'] + self.pad_word_idx = vocab['pad_word'] + if 'pad_topic' in vocab: + self.pad_topic_idx = vocab['pad_topic'] self.tok2ind = vocab['tok2ind'] self.ind2tok = vocab['ind2tok'] diff --git a/crslab/data/dataset/base.py b/crslab/data/dataset/base.py index 080caf7..a04c672 100644 --- a/crslab/data/dataset/base.py +++ b/crslab/data/dataset/base.py @@ -12,9 +12,12 @@ from abc import ABC, abstractmethod import numpy as np -from crslab.download import build from loguru import logger +import json + +from crslab.download import build + class BaseDataset(ABC): """Abstract class of dataset @@ -51,7 +54,7 @@ def __init__(self, opt, dpath, resource, restore=False, save=False): test_data) embedding = opt.get('embedding', None) if embedding: - self.side_data["embedding"] = self.vocab['word2vec'] + self.side_data["embedding"] = np.load(os.path.join(self.dpath, embedding)) logger.debug(f'[Load pretrained embedding {embedding}]') logger.info('[Finish data preprocess]') else: @@ 
-133,6 +136,14 @@ def _data_preprocess(self, train_data, valid_data, test_data): """ pass + + @abstractmethod + def get_attr_list(self): + """ + Returns: + (list of str): attributes + """ + pass def _load_from_restore(self, file_name="all_data.pkl"): """Restore saved dataset. diff --git a/crslab/data/dataset/durecdial/durecdial.py b/crslab/data/dataset/durecdial/durecdial.py index 16640a6..9b0db5a 100644 --- a/crslab/data/dataset/durecdial/durecdial.py +++ b/crslab/data/dataset/durecdial/durecdial.py @@ -7,11 +7,6 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - r""" DuRecDial ========= @@ -27,14 +22,11 @@ import os from copy import copy -import gensim -import numpy as np from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.config import DATASET_PATH from crslab.data.dataset.base import BaseDataset - from .resources import resources @@ -63,7 +55,7 @@ class DuRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False): """ Args: @@ -73,26 +65,14 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. """ - if 'copy' in opt: - self.copy = True - else: - self.copy = False - - if 'embedding' in opt: - self.generate_embedding = True - else: - self.generate_embedding = False - - resource = resources['resource'] - self.special_token_idx = crs_tokenizer.special_token_idx + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - self.tokenize = tokenize - self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'durecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() + train_data, valid_data, test_data = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,73 +85,40 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': word2vec, - 'copy_mask': copy_mask, - 'special_token_idx': self.special_token_idx, } - + vocab.update(self.special_token_idx) + return train_data, valid_data, test_data, vocab def _load_raw_data(self): with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug( - f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split text - processed_train_data = self.split_text(train_data) - logger.info("[Finish train data split]") - # generate tok2ind - self.tok2ind = self.generate_tok2ind(processed_train_data) - logger.info("[Finish generate train tok2ind]") - # generate word2vec - word_embedding = None - if self.generate_embedding: - word_embedding = self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') - # build copy_mask - copy_mask = None - if self.copy: - copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) - logger.info('[Finish generate copy_mask]') - + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - 
logger.debug( - f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_text - processed_valid_data = self.split_text(valid_data) - logger.info("[Finish valid data split]") - + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug( - f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_text - processed_test_data = self.split_text(test_data) - logger.info("[Finish test data split]") + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask + return train_data, valid_data, test_data def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug( - f"[Load vocab from token2id]") - logger.debug( - f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # entity kg with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f: self.entity2id = json.load(f) # {entity: entity_id} - self.id2entity = {idx: entity for entity, - idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join( - self.dpath, 'entity_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") @@ -181,8 +128,7 @@ def _load_other_data(self): self.word2id = json.load(f) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join( - self.dpath, 'hownet_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]") @@ -198,8 +144,7 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id( - conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -211,14 +156,10 @@ def _convert_to_id(self, conversation): for utt in conversation['dialog']: assert utt['role'] != last_role, print(utt) - text_token_ids = [self.tok2ind.get( - word, self.unk_token_idx) for word in utt["text"]] - item_ids = [self.entity2id[movie] - for movie in utt['item'] if movie in self.entity2id] - entity_ids = 
[self.entity2id[entity] - for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] - for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] augmented_convs.append({ "role": utt["role"], @@ -321,93 +262,3 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - - def split_text(self, data): - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - for one in each['dialog']: - text_str = one['text'] - text_list = self.tokenizer.tokenize(text_str) - one['text'] = text_list - each_data.append(one) - each_dict['dialog'] = each_data - all_data.append(each_dict) - - return all_data - - def generate_tok2ind(self, processed_train_data): - cnt = 0 - tok2ind = {} - if self.tokenize == 'nltk' or self.tokenize == 'jieba': - tok2ind['__pad__'] = cnt - cnt += 1 - tok2ind['__start__'] = cnt - cnt += 1 - tok2ind['__end__'] = cnt - cnt += 1 - tok2ind['__unk__'] = cnt - cnt += 1 - elif self.tokenize == 'bert': - tok2ind['[PAD]'] = cnt - cnt += 1 - for i in tqdm(processed_train_data): - dialog = i['dialog'] - for each_dialog in dialog: - text = each_dialog['text'] - for each_word in text: - if each_word not in tok2ind: - tok2ind[each_word] = cnt - cnt += 1 - - if self.tokenize == 'nltk': - tok2ind['_split_'] = cnt - cnt += 1 - - return tok2ind - - def generate_copy_mask(self, tok2ind, processed_train_data): - copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processed_train_data): - for dialog in each_data['dialog']: - match_list = [] - text = dialog['text'] - for word in dialog['word']: - word_list = self.tokenizer.tokenize(word) - match_list += word_list - for item in dialog['item']: - word_list = self.tokenizer.tokenize(item) - match_list += word_list - for entity in dialog['entity']: - word_list = self.tokenizer.tokenize(entity) - match_list += word_list - match_list = list(set(match_list)) - for each_word in text: - if each_word in match_list: - token_id = tok2ind[each_word] - copy_mask[token_id] = True - - return copy_mask - - def generate_word2vec(self, processed_train_data): - - corpus = [] - for each_data in processed_train_data: - for dialog in each_data['dialog']: - text = dialog['text'] - corpus.append(text) - model = gensim.models.word2vec.Word2Vec( - corpus, vector_size=300, min_count=1) - if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] + [[0] * 300] - elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] - - return word2embedding diff --git a/crslab/data/dataset/durecdial/resources.py b/crslab/data/dataset/durecdial/resources.py index c226269..6bb858f 100644 --- a/crslab/data/dataset/durecdial/resources.py +++ b/crslab/data/dataset/durecdial/resources.py @@ -8,20 +8,63 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { - 'resource': 
{ - 'version': '1.0', + 'jieba': { + 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/ERN4GhkC-fBLk1gRKZeHgo4BnQglDxv7VTVmbqgPdL108A?download=1', - 'durecdial.zip', - '9b781f82a9192e96a1e7a9f7501edc930e0e13c0732faf8e3964360a6d5c6ca5', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', + 'durecdial_jieba.zip', + 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', + 'durecdial_bert.zip', + '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, + }, + 'gpt2': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', + 'durecdial_gpt2.zip', + 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } } diff --git a/crslab/data/dataset/gorecdial/gorecdial.py b/crslab/data/dataset/gorecdial/gorecdial.py index b9ea152..fe3e288 100644 --- a/crslab/data/dataset/gorecdial/gorecdial.py +++ b/crslab/data/dataset/gorecdial/gorecdial.py @@ -7,11 +7,6 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - r""" GoRecDial ========= @@ -27,14 +22,11 @@ import os from copy import copy -import gensim -import numpy as np from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.config import DATASET_PATH from crslab.data.dataset.base import BaseDataset - from .resources import resources @@ -63,7 +55,7 @@ class GoRecDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -73,26 +65,14 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
""" - if 'copy' in opt: - self.copy = True - else: - self.copy = False - - if 'embedding' in opt: - self.generate_embedding = True - else: - self.generate_embedding = False - - resource = resources['resource'] - self.special_token_idx = crs_tokenizer.special_token_idx + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - self.tokenize = tokenize - self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'gorecdial') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() + train_data, valid_data, test_data = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,10 +85,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': word2vec, - 'copy_mask': copy_mask, - 'special_token_idx': self.special_token_idx, } + vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab @@ -116,74 +94,41 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug( - f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split text - processed_train_data = self.split_text(train_data) - logger.info("[Finish train data split]") - # generate tok2ind - self.tok2ind = self.generate_tok2ind(processed_train_data) - logger.info("[Finish generate train tok2ind]") - # generate word2vec - word_embedding = None - if self.generate_embedding: - word_embedding = self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') - # build copy_mask - copy_mask = None - if self.copy: - copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) - logger.info('[Finish generate copy_mask]') - + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug( - f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_text - processed_valid_data = self.split_text(valid_data) - logger.info("[Finish valid data split]") - + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug( - f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_text - processed_test_data = self.split_text(test_data) - logger.info("[Finish test data split]") + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask + return train_data, valid_data, test_data def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug( - f"[Load vocab from token2id]") - logger.debug( - f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + 
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # dbpedia self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, - idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join( - self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") # conceptnet # {concept: concept_id} - self.word2id = json.load( - open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join( - self.dpath, 'conceptnet_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") @@ -199,8 +144,7 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id( - conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -212,14 +156,10 @@ def _convert_to_id(self, conversation): for utt in conversation['dialog']: assert utt['role'] != last_role - text_token_ids = [self.tok2ind.get( - word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] - for movie in utt['movies'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] - for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] - for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] policy = utt['decide'] augmented_convs.append({ @@ -240,7 +180,7 @@ def _augment_and_add(self, raw_conv_dict): entity_set, word_set = set(), set() for i, conv in enumerate(raw_conv_dict): text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \ - conv['policy'] + conv['policy'] if len(context_tokens) > 0 and len(text_tokens) > 0: conv_dict = { 'role': conv['role'], @@ -273,8 +213,7 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load( - open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + 
movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -327,102 +266,3 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - - def split_text(self, data): - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - for one in each['dialog']: - text_str = one['text'] - text_list = self.tokenizer.tokenize(text_str) - one['text'] = text_list - each_data.append(one) - each_dict['dialog'] = each_data - all_data.append(each_dict) - - return all_data - - def generate_tok2ind(self, processed_train_data): - - cnt = 0 - tok2ind = {} - - if self.tokenize == 'nltk' or self.tokenize == 'jieba': - tok2ind['__pad__'] = cnt - cnt += 1 - tok2ind['__start__'] = cnt - cnt += 1 - tok2ind['__end__'] = cnt - cnt += 1 - tok2ind['__unk__'] = cnt - cnt += 1 - elif self.tokenize == 'bert': - tok2ind['[PAD]'] = cnt - cnt += 1 - - for i in tqdm(processed_train_data): - dialog = i['dialog'] - for each_dialog in dialog: - text = each_dialog['text'] - for each_word in text: - if each_word not in tok2ind: - tok2ind[each_word] = cnt - cnt += 1 - - if self.tokenize == 'nltk': - tok2ind['_split_'] = cnt - cnt += 1 - return tok2ind - - def generate_copy_mask(self, tok2ind, processed_train_data): - - copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processed_train_data): - for dialog in each_data['dialog']: - match_list = [] - text = dialog['text'] - for word in dialog['word']: - word_list = self.tokenizer.tokenize(word) - match_list += word_list - for movie in dialog['movies']: - word_list = self.tokenizer.tokenize(movie) - match_list += word_list - - for entity in dialog['entity']: - word_list = self.tokenizer.tokenize(entity) - match_list += word_list - - match_list = list(set(match_list)) - - for each_word in text: - if each_word in match_list: - token_id = tok2ind[each_word] - copy_mask[token_id] = True - - return copy_mask - - def generate_word2vec(self, processed_train_data): - - corpus = [] - for each_data in processed_train_data: - for dialog in each_data['dialog']: - text = dialog['text'] - corpus.append(text) - - model = gensim.models.word2vec.Word2Vec( - corpus, vector_size=300, min_count=1) - - if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] + [[0] * 300] - - elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] - - return word2embedding diff --git a/crslab/data/dataset/gorecdial/resources.py b/crslab/data/dataset/gorecdial/resources.py index 57c8614..030202e 100644 --- a/crslab/data/dataset/gorecdial/resources.py +++ b/crslab/data/dataset/gorecdial/resources.py @@ -8,20 +8,61 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { - 'resource': { - 'version': '1.0', + 'nltk': { + 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EYmobnFBox1LnGKGW4TMCk8BW6rnjdAZNVsNo8uJ8ZsJLg?download=1', - 'gorecdial.zip', - '66035bf24862535a072cc6778a3affd541ae0a4aa1fe31455d4fb063b301f087', + 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESIqjwAg0ItAu7WGfukIt3cBXjzi7AZ9L_lcbFT1aS1qYQ?download=1', + 'gorecdial_nltk.zip', + '58cd368f8f83c0c8555becc314a0017990545f71aefb7e93a52581c97d1b8e9b', ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'bert': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed1HT8gzvRpDosVT83BEj5QBnzKpjR3Zbf5u49yyWP-k6Q?download=1', + 'gorecdial_bert.zip', + '4fa10c3fe8ba538af0f393c99892739fcb376d832616aa7028334c594b3fec10' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + } + }, + 'gpt2': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EUJOHmX8v79DkZMq0x5r9d4B0UJlfw85v-VdciwKfAhpng?download=1', + 'gorecdial_gpt2.zip', + '44a15637e014b2e6628102ff654e1aef7ec1cbfa34b7ada1a03f294f72ddd4b1' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } diff --git a/crslab/data/dataset/inspired/inspired.py b/crslab/data/dataset/inspired/inspired.py index 0be8da5..c44f0d7 100644 --- a/crslab/data/dataset/inspired/inspired.py +++ b/crslab/data/dataset/inspired/inspired.py @@ -7,11 +7,6 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - r""" Inspired ======== @@ -27,14 +22,11 @@ import os from copy import copy -import gensim -import numpy as np from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.config import DATASET_PATH from crslab.data.dataset.base import BaseDataset - from .resources import resources @@ -63,7 +55,7 @@ class InspiredDataset(BaseDataset): """ - def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -73,26 +65,14 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
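The per-tokenizer entries added above pair each archive with its own special_token_idx table. A minimal sketch (not part of the patch) of how a dataset class can pick the entry for the configured tokenizer and read the ids it needs; the import path simply follows the gorecdial resources file modified here:

from crslab.data.dataset.gorecdial.resources import resources

def lookup_special_tokens(tokenize):
    # tokenize is one of the keys defined above: 'nltk', 'bert' or 'gpt2'
    resource = resources[tokenize]
    special_token_idx = resource['special_token_idx']
    return special_token_idx['unk'], special_token_idx['pad']

# lookup_special_tokens('nltk') -> (3, 0) with the table above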
""" - if 'copy' in opt: - self.copy = True - else: - self.copy = False - - if 'embedding' in opt: - self.generate_embedding = True - else: - self.generate_embedding = False - - resource = resources['resource'] - self.special_token_idx = crs_tokenizer.special_token_idx + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - self.tokenize = tokenize - self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'inspired') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() + train_data, valid_data, test_data = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -105,10 +85,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': word2vec, - 'copy_mask': copy_mask, - 'special_token_idx': self.special_token_idx, } + vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab @@ -116,63 +94,33 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug( - f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split text - processed_train_data = self.split_text(train_data) - logger.info("[Finish train data split]") - # generate tok2ind - self.tok2ind = self.generate_tok2ind(processed_train_data) - logger.info("[Finish generate train tok2ind]") - # generate word2vec - word_embedding = None - if self.generate_embedding: - word_embedding = self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') - # build copy_mask - copy_mask = None - if self.copy: - copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) - logger.info('[Finish generate copy_mask]') - + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug( - f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_text - processed_valid_data = self.split_text(valid_data) - logger.info("[Finish valid data split]") - + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug( - f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_text - processed_test_data = self.split_text(test_data) - logger.info("[Finish test data split]") + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask + return train_data, valid_data, test_data def _load_vocab(self): + with open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8') as f: + self.tok2ind = json.load(f) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug( - f"[Load vocab from token2id]") - logger.debug( - f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is 
{len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # dbpedia with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f: self.entity2id = json.load(f) # {entity: entity_id} - self.id2entity = {idx: entity for entity, - idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join( - self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join(self.dpath, 'dbpedia_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") @@ -182,8 +130,7 @@ def _load_other_data(self): self.word2id = json.load(f) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join( - self.dpath, 'concept_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") @@ -199,8 +146,7 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id( - conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -210,14 +156,10 @@ def _convert_to_id(self, conversation): augmented_convs = [] last_role = None for utt in conversation['dialog']: - text_token_ids = [self.tok2ind.get( - word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] - for movie in utt['movies'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] - for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] - for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] if utt["role"] == last_role: augmented_convs[-1]["text"] += text_token_ids @@ -326,112 +268,3 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - - def split_text(self, data): - - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - for one in each['dialog']: - text_str = one['text'] - text_list = self.tokenizer.tokenize(text_str) - one['text'] = text_list - each_data.append(one) - each_dict['dialog'] = each_data - all_data.append(each_dict) - - return all_data - - def generate_tok2ind(self, processed_train_data): - - cnt = 0 - tok2ind = {} - - if self.tokenize == 'nltk' or self.tokenize == 'jieba': - tok2ind['__pad__'] = cnt - cnt += 1 - tok2ind['__start__'] = cnt - cnt += 1 - tok2ind['__end__'] = cnt - cnt += 1 - tok2ind['__unk__'] = cnt - cnt += 1 - elif self.tokenize == 
'bert': - tok2ind['[PAD]'] = cnt - cnt += 1 - - for i in tqdm(processed_train_data): - dialog = i['dialog'] - for each_dialog in dialog: - text = each_dialog['text'] - for each_word in text: - if each_word not in tok2ind: - tok2ind[each_word] = cnt - cnt += 1 - - if self.tokenize == 'nltk': - tok2ind['_split_'] = cnt - cnt += 1 - - return tok2ind - - def generate_copy_mask(self, tok2ind, processed_train_data): - - copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processed_train_data): - for dialog in each_data['dialog']: - match_list = [] - text = dialog['text'] - for word in dialog['word']: - word_list = self.tokenizer.tokenize(word) - match_list += word_list - for movie in dialog['movies']: - word_list = self.tokenizer.tokenize(movie) - match_list += word_list - - for entity in dialog['entity']: - word_list = self.tokenizer.tokenize(entity) - match_list += word_list - - for genre in dialog['genre']: - word_list = self.tokenizer.tokenize(genre) - match_list += word_list - - for people in dialog['people']: - word_list = self.tokenizer.tokenize(people) - match_list += word_list - - match_list = list(set(match_list)) - - for each_word in text: - if each_word in match_list: - token_id = tok2ind[each_word] - copy_mask[token_id] = True - - return copy_mask - - def generate_word2vec(self, processed_train_data): - - corpus = [] - for each_data in processed_train_data: - for dialog in each_data['dialog']: - text = dialog['text'] - corpus.append(text) - - model = gensim.models.word2vec.Word2Vec( - corpus, vector_size=300, min_count=1) - - if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] + [[0] * 300] - - elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] - - return word2embedding diff --git a/crslab/data/dataset/inspired/resources.py b/crslab/data/dataset/inspired/resources.py index 38fb3be..504a760 100644 --- a/crslab/data/dataset/inspired/resources.py +++ b/crslab/data/dataset/inspired/resources.py @@ -8,20 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { - 'resource': { - 'version': '1.0', + 'nltk': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdDgeChYguFLvz8hmkNdRhABmQF-LBfYtdb7rcdnB3kUgA?download=1', + 'inspired_nltk.zip', + '776cadc7585abdbca2738addae40488826c82de3cfd4c2dc13dcdd63aefdc5c4', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EfBfyxLideBDsupMWb2tANgB6WxySTPQW11uM1F4UV5mTQ?download=1', + 'inspired_bert.zip', + '9affea30978a6cd48b8038dddaa36f4cb4d8491cf8ae2de44a6d3dde2651f29c' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.3', 'file': DownloadableFile( - 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXv8zwgCOY1EstHNjjs194cBqMIrdg4yxcyNsHKltTzyig?download=1', - 'inspired.zip', - '1085c2ab31fd7691f24531f9beef9016b0f3137366495784569a63f82ddd95ed', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EVwbqtjDReZHnvb_l9TxaaIBAC63BjbqkN5ZKb24Mhsm_A?download=1', + 'inspired_gpt2.zip', + '261ad7e5325258d5cb8ffef0751925a58270fb6d9f17490f8552f6b86ef1eed2' ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, } } diff --git a/crslab/data/dataset/opendialkg/opendialkg.py b/crslab/data/dataset/opendialkg/opendialkg.py index bfe689f..102fd8b 100644 --- a/crslab/data/dataset/opendialkg/opendialkg.py +++ b/crslab/data/dataset/opendialkg/opendialkg.py @@ -7,11 +7,6 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - r""" OpenDialKG ========== @@ -27,16 +22,12 @@ import os from collections import defaultdict from copy import copy -from http.client import NotConnected -import gensim -import numpy as np from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.config import DATASET_PATH from crslab.data.dataset.base import BaseDataset - from .resources import resources @@ -65,7 +56,7 @@ class OpenDialKGDataset(BaseDataset): """ - def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -75,26 +66,14 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
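Illustrative only (values are the nltk ids from the tables above): after this change _load_data() merges the special token ids into the top level of the vocab dict via vocab.update(...) instead of nesting them under a 'special_token_idx' key, so downstream code can read vocab['unk'] directly.

special_token_idx = {'pad': 0, 'start': 1, 'end': 2, 'unk': 3,
                     'pad_entity': 0, 'pad_word': 0}
vocab = {'tok2ind': {}, 'ind2tok': {}, 'entity2id': {}, 'id2entity': {},
         'word2id': {}, 'vocab_size': 0, 'n_entity': 0, 'n_word': 0}
vocab.update(special_token_idx)
assert vocab['unk'] == 3 and 'special_token_idx' not in vocab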
""" - if 'copy' in opt: - self.copy = True - else: - self.copy = False - - if 'embedding' in opt: - self.generate_embedding = True - else: - self.generate_embedding = False - - resource = resources['resource'] - self.special_token_idx = crs_tokenizer.special_token_idx + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - self.tokenize = tokenize - self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'opendialkg') super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() + train_data, valid_data, test_data = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -107,10 +86,8 @@ def _load_data(self): 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': word2vec, - 'copy_mask': copy_mask, - 'special_token_idx': self.special_token_idx } + vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab @@ -118,74 +95,41 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug( - f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split text - processed_train_data = self.split_text(train_data) - logger.info("[Finish train data split]") - # generate tok2ind - self.tok2ind = self.generate_tok2ind(processed_train_data) - logger.info("[Finish generate train tok2ind]") - # generate word2vec - word_embedding = None - if self.generate_embedding: - word_embedding = self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') - # build copy_mask - copy_mask = None - if self.copy: - copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) - logger.info('[Finish generate copy_mask]') - + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug( - f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_text - processed_valid_data = self.split_text(valid_data) - logger.info("[Finish valid data split]") - + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug( - f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_text - processed_test_data = self.split_text(test_data) - logger.info("[Finish test data split]") + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask + return train_data, valid_data, test_data def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug( - f"[Load vocab from token2id]") - logger.debug( - f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + 
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # opendialkg self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, - idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join( - self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join(self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'opendialkg_subkg.json')} and {os.path.join(self.dpath, 'opendialkg_triples.txt')}]") # conceptnet # {concept: concept_id} - self.word2id = json.load( - open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {concept \t relation\t concept} - self.word_kg = open(os.path.join( - self.dpath, 'concept_subkg.txt'), encoding='utf-8') + self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") @@ -201,8 +145,7 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id( - conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -212,14 +155,10 @@ def _convert_to_id(self, conversation): augmented_convs = [] last_role = None for utt in conversation['dialog']: - text_token_ids = [self.tok2ind.get( - word, self.unk_token_idx) for word in utt["text"]] - item_ids = [self.entity2id[movie] - for movie in utt['item'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] - for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] - for word in utt['word'] if word in self.word2id] + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] if utt["role"] == last_role: augmented_convs[-1]["text"] += text_token_ids @@ -274,8 +213,7 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - item_entity_ids = json.load( - open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8')) + item_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load item entity ids]') side_data = { @@ -300,8 +238,7 @@ def _entity_kg_process(self): if e1 != e0: edge_list.append((e1, e1, 'SELF_LOOP')) - relation_cnt, relation2id, edges, entities = defaultdict( - int), dict(), set(), 
set() + relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() for h, t, r in edge_list: relation_cnt[r] += 1 for h, t, r in edge_list: @@ -334,104 +271,7 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - - def split_text(self, data): - - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - for one in each['dialog']: - text_str = one['text'] - text_list = self.tokenizer.tokenize(text_str) - one['text'] = text_list - each_data.append(one) - each_dict['dialog'] = each_data - all_data.append(each_dict) - - return all_data - - def generate_tok2ind(self, processed_train_data): - - cnt = 0 - tok2ind = {} - - if self.tokenize == 'nltk' or self.tokenize == 'jieba': - tok2ind['__pad__'] = cnt - cnt += 1 - tok2ind['__start__'] = cnt - cnt += 1 - tok2ind['__end__'] = cnt - cnt += 1 - tok2ind['__unk__'] = cnt - cnt += 1 - elif self.tokenize == 'bert': - tok2ind['[PAD]'] = cnt - cnt += 1 - - for i in tqdm(processed_train_data): - dialog = i['dialog'] - for each_dialog in dialog: - text = each_dialog['text'] - for each_word in text: - if each_word not in tok2ind: - tok2ind[each_word] = cnt - cnt += 1 - - if self.tokenize == 'nltk': - tok2ind['_split_'] = cnt - cnt += 1 - - return tok2ind - - def generate_copy_mask(self, tok2ind, processed_train_data): - - copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processed_train_data): - for dialog in each_data['dialog']: - match_list = [] - text = dialog['text'] - for word in dialog['word']: - word_list = self.tokenizer.tokenize(word) - match_list += word_list - for entity in dialog['entity']: - word_list = self.tokenizer.tokenize(entity) - match_list += word_list - - for item in dialog['item']: - word_list = self.tokenizer.tokenize(item) - match_list += word_list - - match_list = list(set(match_list)) - - for each_word in text: - if each_word in match_list: - token_id = tok2ind[each_word] - copy_mask[token_id] = True - - return copy_mask - - def generate_word2vec(self, processed_train_data): - - corpus = [] - for each_data in processed_train_data: - for dialog in each_data['dialog']: - text = dialog['text'] - corpus.append(text) - - model = gensim.models.word2vec.Word2Vec( - corpus, vector_size=300, min_count=1) - - if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] + [[0] * 300] - - elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] - - return word2embedding + + def get_attr_list(self): + attr_list = ['genre', 'actor', 'director', 'writer'] + return attr_list \ No newline at end of file diff --git a/crslab/data/dataset/opendialkg/resources.py b/crslab/data/dataset/opendialkg/resources.py index 2fc13db..9f7fb62 100644 --- a/crslab/data/dataset/opendialkg/resources.py +++ b/crslab/data/dataset/opendialkg/resources.py @@ -8,20 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { - 'resource': { - 'version': '1.0', + 'nltk': { + 'version': '0.3', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUknGWqDp15OoI2U7DE6EHkBoZVaK273DJfxCdXuluqQjA?download=1', - 'opendialkg.zip', - 
'73c2632ddf27d15a9f89cd288dae4e200a6a7a2487edc303f881077bc6884671', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', + 'opendialkg_nltk.zip', + '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f' ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', + 'opendialkg_bert.zip', + '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', + 'opendialkg_gpt2.zip', + 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index ec4d807..6ceb036 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -7,10 +7,10 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com +# UPDATE: +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com r""" ReDial @@ -24,18 +24,16 @@ """ import json +import re import os from collections import defaultdict from copy import copy -import gensim -import numpy as np from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.config import DATASET_PATH from crslab.data.dataset.base import BaseDataset - from .resources import resources @@ -64,7 +62,7 @@ class ReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -74,119 +72,92 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
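Each resource entry above records a version, a download URL, an archive name and a sha256 digest. A sketch of how such an archive could be checked after download; this is a generic hashlib comparison offered as an assumption, not CRSLab's actual downloader logic.

import hashlib

def sha256_matches(path, expected_hex):
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            h.update(chunk)
    return h.hexdigest() == expected_hex

# e.g. sha256_matches('opendialkg_nltk.zip',
#                     '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f')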
""" - if 'copy' in opt: - self.copy = True - else: - self.copy = False - - if 'embedding' in opt: - self.generate_embedding = True - else: - self.generate_embedding = False - - resource = resources['resource'] - self.special_token_idx = crs_tokenizer.special_token_idx + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] - self.tokenize = tokenize - self.tokenizer = crs_tokenizer - dpath = os.path.join(DATASET_PATH, "redial") + dpath = os.path.join(DATASET_PATH, opt['dataset']) super().__init__(opt, dpath, resource, restore, save) def _load_data(self): - train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() - self._load_vocab() + train_data, valid_data, test_data = self._load_raw_data() + # self._load_vocab() self._load_other_data() vocab = { - 'tok2ind': self.tok2ind, - 'ind2tok': self.ind2tok, + # 'tok2ind': self.tok2ind, + # 'ind2tok': self.ind2tok, 'entity2id': self.entity2id, 'id2entity': self.id2entity, 'word2id': self.word2id, - 'vocab_size': len(self.tok2ind), + # 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': word2vec, - 'copy_mask': copy_mask, - 'special_token_idx': self.special_token_idx } + vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab def _load_raw_data(self): # load train/valid/test data - with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: - train_data = json.load(f) - logger.debug( - f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split text - processed_train_data = self.split_text(train_data) - logger.info("[Finish train data split]") - # generate tok2ind - self.tok2ind = self.generate_tok2ind(processed_train_data) - logger.info("[Finish generate train tok2ind]") - # generate word2vec - word_embedding = None - if self.generate_embedding: - word_embedding = self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') - # build copy_mask - copy_mask = None - if self.copy: - copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) - logger.info('[Finish generate copy_mask]') - - with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: - valid_data = json.load(f) - logger.debug( - f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_text - processed_valid_data = self.split_text(valid_data) - logger.info("[Finish valid data split]") - - with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: - test_data = json.load(f) - logger.debug( - f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_text - processed_test_data = self.split_text(test_data) - logger.info("[Finish test data split]") - - return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask + train_data = [] + valid_data = [] + test_data = [] + + with open(os.path.join(self.dpath, 'train_data.jsonl'), 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + train_data.append(data) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with open(os.path.join(self.dpath, 'valid_data.jsonl'), 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + valid_data.append(data) + valid_data.append(data) + logger.debug(f"[Load valid data from 
{os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.jsonl'), 'r', encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + test_data.append(data) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} - logger.debug( - f"[Load vocab from token2id]") - logger.debug( - f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2tok)}]") + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") def _load_other_data(self): # dbpedia self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, - idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = json.load( - open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8')) + self.entity_kg = json.load(open(os.path.join(self.dpath, 'kg.json'), 'r', encoding='utf-8')) logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]") # conceptNet # {concept: concept_id} - self.word2id = json.load( - open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {relation\t concept \t concept} - self.word_kg = open(os.path.join( - self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') + self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]") + with open(os.path.join(self.dpath, 'id2info.json'), 'r', encoding='utf-8') as f: + id2info = json.load(f) + self.id2name = {} + for id, info in id2info.items(): + self.id2name[id] = info['name'] + def _data_preprocess(self, train_data, valid_data, test_data): processed_train_data = self._raw_data_process(train_data) @@ -200,54 +171,69 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._merge_conv_data( - conversation["dialog"]) for conversation in tqdm(raw_data)] + augmented_convs = [self._merge_conv_data(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) return augmented_conv_dicts def _merge_conv_data(self, dialog): + movie_pattern = re.compile(r'@\d+') augmented_convs = [] last_role = None - for utt in dialog: - text_token_ids = [self.tok2ind.get( - word, self.unk_token_idx) for word in utt["text"]] - 
movie_ids = [self.entity2id[movie] - for movie in utt['movies'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] - for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] - for word in utt['word'] if word in self.word2id] - - if utt["role"] == last_role: - augmented_convs[-1]["text"] += text_token_ids - augmented_convs[-1]["movie"] += movie_ids + user_id = dialog['initiatorWorkerId'] + recommender_id = dialog['respondentWorkerId'] + conversation_id = dialog['conversationId'] + + for utt in dialog["messages"]: + # text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + turn_id = utt['turn_id'] + dialog_turn_id = str(conversation_id) + '_' + str(turn_id) + + text = utt['text'] + for pattern in re.findall(movie_pattern, text): + if pattern.strip('@') in self.id2name: + text = text.replace(pattern, self.id2name[pattern.strip('@')]) + movie_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + + role_id = utt['senderWorkerId'] + if role_id == recommender_id: + role = 'Recommender' + elif role_id == user_id: + role = 'User' + + if role == last_role: + augmented_convs[-1]["text"] += text + augmented_convs[-1]["item"] += movie_ids augmented_convs[-1]["entity"] += entity_ids augmented_convs[-1]["word"] += word_ids else: augmented_convs.append({ - "role": utt["role"], - "text": text_token_ids, + "dialog_id": dialog_turn_id, + "role": role, + "text": text, "entity": entity_ids, - "movie": movie_ids, + "item": movie_ids, "word": word_ids }) - last_role = utt["role"] + last_role = role return augmented_convs def _augment_and_add(self, raw_conv_dict): augmented_conv_dicts = [] - context_tokens, context_entities, context_words, context_items = [], [], [], [] + context, context_entities, context_words, context_items = [], [], [], [] entity_set, word_set = set(), set() for i, conv in enumerate(raw_conv_dict): - text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"] - if len(context_tokens) > 0: + text, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"] + if len(context) > 0: conv_dict = { + "dialog_id": conv['dialog_id'], "role": conv['role'], - "context_tokens": copy(context_tokens), - "response": text_tokens, + "context": copy(context), + "response": text, "context_entities": copy(context_entities), "context_words": copy(context_words), "context_items": copy(context_items), @@ -255,7 +241,7 @@ def _augment_and_add(self, raw_conv_dict): } augmented_conv_dicts.append(conv_dict) - context_tokens.append(text_tokens) + context.append(text) context_items += movies for entity in entities + movies: if entity not in entity_set: @@ -273,8 +259,7 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load( - open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + movie_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -292,13 +277,10 @@ def _entity_kg_process(self, SELF_LOOP_ID=185): edge_list.append((entity, entity, SELF_LOOP_ID)) # add self loop for tail_and_relation in self.entity_kg[str(entity)]: if 
entity != tail_and_relation[1] and tail_and_relation[0] != SELF_LOOP_ID: - edge_list.append( - (entity, tail_and_relation[1], tail_and_relation[0])) - edge_list.append( - (tail_and_relation[1], entity, tail_and_relation[0])) + edge_list.append((entity, tail_and_relation[1], tail_and_relation[0])) + edge_list.append((tail_and_relation[1], entity, tail_and_relation[0])) - relation_cnt, relation2id, edges, entities = defaultdict( - int), dict(), set(), set() + relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() for h, t, r in edge_list: relation_cnt[r] += 1 for h, t, r in edge_list: @@ -330,103 +312,36 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - - def split_text(self, data): - - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - for one in each['dialog']: - text_str = one['text'] - text_list = self.tokenizer.tokenize(text_str) - one['text'] = text_list - each_data.append(one) - each_dict['dialog'] = each_data - all_data.append(each_dict) - - return all_data - - def generate_tok2ind(self, processed_train_data): - - cnt = 0 - tok2ind = {} - - if self.tokenize == 'nltk' or self.tokenize == 'jieba': - tok2ind['__pad__'] = cnt - cnt += 1 - tok2ind['__start__'] = cnt - cnt += 1 - tok2ind['__end__'] = cnt - cnt += 1 - tok2ind['__unk__'] = cnt - cnt += 1 - elif self.tokenize == 'bert': - tok2ind['[PAD]'] = cnt - cnt += 1 - - for i in tqdm(processed_train_data): - dialog = i['dialog'] - for each_dialog in dialog: - text = each_dialog['text'] - for each_word in text: - if each_word not in tok2ind: - tok2ind[each_word] = cnt - cnt += 1 - - if self.tokenize == 'nltk': - tok2ind['_split_'] = cnt - cnt += 1 - - return tok2ind - - def generate_copy_mask(self, tok2ind, processed_train_data): - - copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processed_train_data): - for dialog in each_data['dialog']: - match_list = [] - text = dialog['text'] - for word in dialog['word']: - word_list = self.tokenizer.tokenize(word) - match_list += word_list - for movie in dialog['movies']: - word_list = self.tokenizer.tokenize(movie) - match_list += word_list - for entity in dialog['entity']: - word_list = self.tokenizer.tokenize(entity) - match_list += word_list - - match_list = list(set(match_list)) - - for each_word in text: - if each_word in match_list: - token_id = tok2ind[each_word] - copy_mask[token_id] = True - - return copy_mask - - def generate_word2vec(self, processed_train_data): - - corpus = [] - for each_data in processed_train_data: - for dialog in each_data['dialog']: - text = dialog['text'] - corpus.append(text) - - model = gensim.models.word2vec.Word2Vec( - corpus, vector_size=300, min_count=1) - - if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] + [[0] * 300] - - elif self.tokenize == 'jieba': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] - - return word2embedding + + def get_attr_list(self): + attr_list = ['genre', 'star', 'director'] + return attr_list + + def get_ask_instruction(self): + ask_instruction = '''To recommend me items that I will accept, you can choose one of the following options. 
+A: ask my preference for genre +B: ask my preference for actor +C: ask my preference for director +D: I can directly give recommendations +Please enter the option character. Please only response a character.''' + option2attr = { + 'A': 'genre', + 'B': 'star', + 'C': 'director', + 'D': 'recommend' + } + option2template = { + 'A': 'Which genre do you like?', + 'B': 'Which star do you like?', + 'C': 'Which director do you like?', + } + rec_instruction = 'Please give me 10 recommendations according to my preference (Format: no. title. No other things except the item list in your response). You can recommend mentioned items in our dialog.' + + ask_instruction_dict = { + 'ask_instruction': ask_instruction, + 'option2attr': option2attr, + 'option2template': option2template, + 'rec_instruction': rec_instruction + } + + return ask_instruction_dict \ No newline at end of file diff --git a/crslab/data/dataset/redial/resources.py b/crslab/data/dataset/redial/resources.py index 9809a6e..551e7ab 100644 --- a/crslab/data/dataset/redial/resources.py +++ b/crslab/data/dataset/redial/resources.py @@ -8,20 +8,59 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { - 'resource': { - 'version': '1.0', + 'nltk': { + 'version': '0.31', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ea4PEMnyyqxAl6tiAC17BcgBW8fZ6eveNKAbAU5sYt8-PQ?download=1', - 'redial.zip', - '9fcccc47095c6c8764a3f92e9ec993a2f5f635458836ac3314dcf007ad80d639', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', + 'redial_nltk.zip', + '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, }, + 'bert': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', + 'redial_bert.zip', + 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', + 'redial_gpt2.zip', + '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', + ), + 'special_token_idx': { + 'pad': -100, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } } diff --git a/crslab/data/dataset/tgredial/resources.py b/crslab/data/dataset/tgredial/resources.py index 721afae..b46e73b 100644 --- a/crslab/data/dataset/tgredial/resources.py +++ b/crslab/data/dataset/tgredial/resources.py @@ -8,20 +8,64 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { - 'resource': { - 'version': '1.0', + 'pkuseg': { + 'version': '0.3', 'file': DownloadableFile( - 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EUmmYbQ6BytMrQjmgRWuElMBZ2yv7v10wLzuwxHe9wxnYg?download=1', - 'tgredial.zip', - '9895809dcceffc01da932716a5dc8e113917c7680d0fdf5c79169add2ec0d3a8', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', + 'tgredial_pkuseg.zip', + '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', + 'tgredial_bert.zip', + 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, + }, + 'gpt2': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', + 'tgredial_gpt2.zip', + '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } } diff --git a/crslab/data/dataset/tgredial/tgredial.py b/crslab/data/dataset/tgredial/tgredial.py index c30bbd8..e23f1c4 100644 --- a/crslab/data/dataset/tgredial/tgredial.py +++ b/crslab/data/dataset/tgredial/tgredial.py @@ -7,11 +7,6 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, sdzyh002@gmail -# UPDATE -# @Time : 2022/9/26 -# @Author : Xinyu Tang -# @email : txy20010310@163.com - r""" TGReDial ======== @@ -27,15 +22,12 @@ import os from collections import defaultdict from copy import copy - -import gensim import numpy as np from loguru import logger from tqdm import tqdm -from crslab.config import DATASET_PATH, MODEL_PATH +from crslab.config import DATASET_PATH from crslab.data.dataset.base import BaseDataset - from .resources import resources @@ -67,7 +59,7 @@ class TGReDialDataset(BaseDataset): """ - def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): + def __init__(self, opt, tokenize, restore=False, save=False): """Specify tokenized resource and init base dataset. Args: @@ -77,37 +69,23 @@ def __init__(self, opt, tokenize, crs_tokenizer, restore=False, save=False): save (bool): whether to save dataset after processing. Defaults to False. 
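In the reworked _merge_conv_data() for ReDial above, raw '@<movie-id>' mentions are rewritten to movie titles through the id2name map built from id2info.json. A standalone sketch of that substitution; the id and title in the usage comment are hypothetical.

import re

movie_pattern = re.compile(r'@\d+')

def replace_movie_mentions(text, id2name):
    # id2name: {movie_id (str): title}, as built in _load_other_data() above
    for mention in re.findall(movie_pattern, text):
        if mention.strip('@') in id2name:
            text = text.replace(mention, id2name[mention.strip('@')])
    return text

# replace_movie_mentions("Have you seen @111776?", {'111776': 'Fargo (1996)'})
# -> "Have you seen Fargo (1996)?"   (hypothetical id/title)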
""" - if 'copy' in opt: - self.copy = True - else: - self.copy = False - - if 'embedding' in opt: - self.generate_embedding = True - else: - self.generate_embedding = False - - resource = resources['resource'] - self.special_token_idx = crs_tokenizer.special_token_idx + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] self.unk_token_idx = self.special_token_idx['unk'] self.pad_topic_idx = self.special_token_idx['pad_topic'] - - self.tokenize = tokenize - self.tokenizer = crs_tokenizer dpath = os.path.join(DATASET_PATH, 'tgredial') - - self.replace_token = opt.get('replace_token', None) - self.replace_token_idx = opt.get('replace_token_idx', None) + self.replace_token = opt.get('replace_token',None) + self.replace_token_idx = opt.get('replace_token_idx',None) super().__init__(opt, dpath, resource, restore, save) if self.replace_token: if self.replace_token_idx: self.side_data["embedding"][self.replace_token_idx] = self.side_data['embedding'][0] else: - self.side_data["embedding"] = np.insert(self.side_data["embedding"], len( - self.side_data["embedding"]), self.side_data['embedding'][0], axis=0) + self.side_data["embedding"] = np.insert(self.side_data["embedding"],len(self.side_data["embedding"]),self.side_data['embedding'][0],axis=0) + def _load_data(self): - train_data, valid_data, test_data, word2vec, copy_mask = self._load_raw_data() + train_data, valid_data, test_data = self._load_raw_data() self._load_vocab() self._load_other_data() @@ -123,10 +101,8 @@ def _load_data(self): 'n_topic': len(self.topic2ind) + 1, 'n_entity': self.n_entity, 'n_word': self.n_word, - 'word2vec': word2vec, - 'copy_mask': copy_mask, - 'special_token_idx': self.special_token_idx } + vocab.update(self.special_token_idx) return train_data, valid_data, test_data, vocab @@ -134,44 +110,18 @@ def _load_raw_data(self): # load train/valid/test data with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: train_data = json.load(f) - logger.debug( - f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") - # split text - processed_train_data = self.split_text(train_data) - logger.info("[Finish train data split]") - # generate tok2ind - self.tok2ind = self.generate_tok2ind(processed_train_data) - logger.info("[Finish generate train tok2ind]") - # generate word2vec - word_embedding = None - if self.generate_embedding: - word_embedding = self.generate_word2vec(processed_train_data) - logger.info('[Finish generate word2vec]') - # build copy_mask - copy_mask = None - if self.copy: - copy_mask = self.generate_copy_mask(self.tok2ind, processed_train_data) - logger.info('[Finish generate copy_mask]') - + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: valid_data = json.load(f) - logger.debug( - f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") - # split_text - processed_valid_data = self.split_text(valid_data) - logger.info("[Finish valid data split]") - + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: test_data = json.load(f) - logger.debug( - f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") - # split_text - processed_test_data = self.split_text(test_data) - logger.info("[Finish test data split]") + logger.debug(f"[Load test data from {os.path.join(self.dpath, 
'test_data.json')}]") - return processed_train_data, processed_valid_data, processed_test_data, word_embedding, copy_mask + return train_data, valid_data, test_data def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} # add special tokens if self.replace_token: @@ -183,61 +133,46 @@ def _load_vocab(self): else: self.ind2tok[len(self.tok2ind)] = self.replace_token self.tok2ind[self.replace_token] = len(self.tok2ind) - self.special_token_idx[self.replace_token] = len( - self.tok2ind)-1 - logger.debug( - f"[Load vocab from token2id]") - logger.debug( - f"[The size of token2index dictionary is {len(self.tok2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2tok)}]") + self.special_token_idx[self.replace_token] = len(self.tok2ind)-1 + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") - self.topic2ind = json.load( - open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8')) + self.topic2ind = json.load(open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8')) self.ind2topic = {idx: word for word, idx in self.topic2ind.items()} - logger.debug( - f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]") - logger.debug( - f"[The size of token2index dictionary is {len(self.topic2ind)}]") - logger.debug( - f"[The size of index2token dictionary is {len(self.ind2topic)}]") + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.topic2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2topic)}]") def _load_other_data(self): # cn-dbpedia self.entity2id = json.load( open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} - self.id2entity = {idx: entity for entity, - idx in self.entity2id.items()} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} self.n_entity = max(self.entity2id.values()) + 1 # {head_entity_id: [(relation_id, tail_entity_id)]} - self.entity_kg = open(os.path.join( - self.dpath, 'cn-dbpedia.txt'), encoding='utf-8') + self.entity_kg = open(os.path.join(self.dpath, 'cn-dbpedia.txt'), encoding='utf-8') logger.debug( f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'cn-dbpedia.txt')}]") # hownet # {concept: concept_id} - self.word2id = json.load( - open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) self.n_word = max(self.word2id.values()) + 1 # {relation\t concept \t concept} - self.word_kg = open(os.path.join( - self.dpath, 'hownet.txt'), encoding='utf-8') + self.word_kg = open(os.path.join(self.dpath, 'hownet.txt'), encoding='utf-8') logger.debug( f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet.txt')}]") # user interaction history dictionary - self.conv2history = json.load( - open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8')) - logger.debug( - f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]") + self.conv2history = 
json.load(open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8')) + logger.debug(f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]") # user profile - self.user2profile = json.load( - open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8')) - logger.debug( - f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}") + self.user2profile = json.load(open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8')) + logger.debug(f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}") + def _data_preprocess(self, train_data, valid_data, test_data): processed_train_data = self._raw_data_process(train_data) @@ -251,8 +186,7 @@ def _data_preprocess(self, train_data, valid_data, test_data): return processed_train_data, processed_valid_data, processed_test_data, processed_side_data def _raw_data_process(self, raw_data): - augmented_convs = [self._convert_to_id( - conversation) for conversation in tqdm(raw_data)] + augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] augmented_conv_dicts = [] for conv in tqdm(augmented_convs): augmented_conv_dicts.extend(self._augment_and_add(conv)) @@ -266,19 +200,14 @@ def _convert_to_id(self, conversation): # change movies into slots if self.replace_token: if len(utt['movie']) != 0: - while '怊' in utt['text']: + while '怊' in utt['text'] : begin = utt['text'].index("怊") end = utt['text'].index("怋") - utt['text'] = utt['text'][:begin] + \ - [self.replace_token] + utt['text'][end+1:] - text_token_ids = [self.tok2ind.get( - word, self.unk_token_idx) for word in utt["text"]] - movie_ids = [self.entity2id[movie] - for movie in utt['movie'] if movie in self.entity2id] - entity_ids = [self.entity2id[entity] - for entity in utt['entity'] if entity in self.entity2id] - word_ids = [self.word2id[word] - for word in utt['word'] if word in self.word2id] + utt['text'] = utt['text'][:begin] + [self.replace_token] + utt['text'][end+1:] + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] for movie in utt['movie'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] policy = [] for action, kw in zip(utt['target'][1::2], utt['target'][2::2]): if kw is None or action == 'ęŽØ荐ē”µå½±': @@ -287,15 +216,12 @@ def _convert_to_id(self, conversation): kw = [kw] kw = [self.topic2ind.get(k, self.pad_topic_idx) for k in kw] policy.append([action, kw]) - final_kws = [ - self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]] + final_kws = [self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]] final = [utt['final'][0], final_kws] - conv_utt_id = str( - conversation['conv_id']) + '/' + str(utt['local_id']) + conv_utt_id = str(conversation['conv_id']) + '/' + str(utt['local_id']) interaction_history = self.conv2history.get(conv_utt_id, []) user_profile = self.user2profile[conversation['user_id']] - user_profile = [[self.tok2ind.get( - token, self.unk_token_idx) for token in sent] for sent in user_profile] + user_profile = [[self.tok2ind.get(token, self.unk_token_idx) for token in sent] for sent in user_profile] augmented_convs.append({ "role": utt["role"], @@ -318,11 +244,11 @@ def _augment_and_add(self, raw_conv_dict): entity_set, word_set = set(), set() 
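The slot handling in _convert_to_id() above collapses every 怊 ... 怋 span in the tokenized TGReDial text into the configured replace token ('[ITEM]' in config/crs/ntrd/tgredial.yaml). A standalone sketch of the same loop, with a hypothetical utterance in the usage comment.

def replace_movie_spans(tokens, replace_token='[ITEM]'):
    # tokens is a list of tokens in which 怊 and 怋 delimit a movie name
    while '怊' in tokens:
        begin = tokens.index('怊')
        end = tokens.index('怋')
        tokens = tokens[:begin] + [replace_token] + tokens[end + 1:]
    return tokens

# replace_movie_spans(['ꈑ', '喜欢', '怊', '流ęµŖ', '地ēƒ', '怋']) -> ['ꈑ', '喜欢', '[ITEM]']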
for i, conv in enumerate(raw_conv_dict): text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \ - conv['policy'] - if self.replace_token is not None: + conv['policy'] + if self.replace_token is not None: if text_tokens.count(30000) != len(movies): - continue # the number of slots doesn't equal to the number of movies - + continue # the number of slots doesn't equal to the number of movies + if len(context_tokens) > 0: conv_dict = { 'role': conv['role'], @@ -359,8 +285,7 @@ def _side_data_process(self): logger.debug("[Finish entity KG process]") processed_word_kg = self._word_kg_process() logger.debug("[Finish word KG process]") - movie_entity_ids = json.load( - open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) logger.debug('[Load movie entity ids]') side_data = { @@ -383,8 +308,7 @@ def _entity_kg_process(self): if e1 != e0: edge_list.append((e1, e1, 'SELF_LOOP')) - relation_cnt, relation2id, edges, entities = defaultdict( - int), dict(), set(), set() + relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() for h, t, r in edge_list: relation_cnt[r] += 1 for h, t, r in edge_list: @@ -416,107 +340,3 @@ def _word_kg_process(self): 'edge': list(edges), 'entity': list(entities) } - - def split_text(self, data): - - all_data = [] - for each in tqdm(data): - each_dict = {} - each_data = [] - each_dict['conv_id'] = each['conv_id'] - for one in each['messages']: - text_str = one['text'] - text_list = self.tokenizer.tokenize(text_str) - one['text'] = text_list - each_data.append(one) - each_dict['messages'] = each_data - each_dict['user_id'] = each['user_id'] - all_data.append(each_dict) - - return all_data - - def generate_tok2ind(self, processed_train_data): - - cnt = 0 - tok2ind = {} - - if self.tokenize == 'nltk' or self.tokenize == 'jieba' or self.tokenize == 'pkuseg': - tok2ind['__pad__'] = cnt - cnt += 1 - tok2ind['__start__'] = cnt - cnt += 1 - tok2ind['__end__'] = cnt - cnt += 1 - tok2ind['__unk__'] = cnt - cnt += 1 - elif self.tokenize == 'bert': - tok2ind['[PAD]'] = cnt - cnt += 1 - - for i in tqdm(processed_train_data): - dialog = i['messages'] - for each_dialog in dialog: - text = each_dialog['text'] - for each_word in text: - if each_word not in tok2ind: - tok2ind[each_word] = cnt - cnt += 1 - - if self.tokenize == 'nltk': - tok2ind['_split_'] = cnt - cnt += 1 - - return tok2ind - - def generate_copy_mask(self, tok2ind, processed_train_data): - - copy_mask = np.zeros((len(tok2ind)), dtype=bool) - for each_data in tqdm(processed_train_data): - for dialog in each_data['messages']: - match_list = [] - text = dialog['text'] - for word in dialog['word']: - word_list = self.tokenizer.tokenize(word) - match_list += word_list - - for movie in dialog['movie']: - word_list = self.tokenizer.tokenize(movie) - match_list += word_list - - for entity in dialog['entity']: - word_list = self.tokenizer.tokenize(entity) - match_list += word_list - - match_list = list(set(match_list)) - - for each_word in text: - if each_word in match_list: - token_id = tok2ind[each_word] - copy_mask[token_id] = True - - return copy_mask - - def generate_word2vec(self, processed_train_data): - - corpus = [] - for each_data in processed_train_data: - for dialog in each_data['messages']: - text = dialog['text'] - corpus.append(text) - - model = gensim.models.word2vec.Word2Vec( - corpus, vector_size=300, min_count=1) 
- - if self.tokenize == 'nltk': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] + [[0] * 300] - - elif self.tokenize == 'jieba' or self.tokenize == 'pkuseg': - word2index = {word: i + 4 for i, - word in enumerate(model.wv.index_to_key)} - word2embedding = [[0] * 300] * 4 + [model.wv[word] - for word in word2index] - - return word2embedding diff --git a/crslab/data/dataset/tokenizer/__init__.py b/crslab/data/dataset/tokenizer/__init__.py deleted file mode 100644 index d5c67be..0000000 --- a/crslab/data/dataset/tokenizer/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .base import BaseTokenizer -from .bert import BertTokenizer -from .gpt2 import Gpt2Tokenizer -from .jieba import JiebaTokenizer -from .nltk import NltkTokenizer -from .pkuseg import PkusegTokenizer diff --git a/crslab/data/dataset/tokenizer/base.py b/crslab/data/dataset/tokenizer/base.py deleted file mode 100644 index b5966bb..0000000 --- a/crslab/data/dataset/tokenizer/base.py +++ /dev/null @@ -1,14 +0,0 @@ -# @Time : 2022/9/30 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -class BaseTokenizer: - - def __init__(self, path=None) -> None: - pass - - def tokenize(self, text): - ''' - split token - ''' - pass diff --git a/crslab/data/dataset/tokenizer/bert.py b/crslab/data/dataset/tokenizer/bert.py deleted file mode 100644 index 73f268a..0000000 --- a/crslab/data/dataset/tokenizer/bert.py +++ /dev/null @@ -1,28 +0,0 @@ -# @Time : 2022/9/30 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -from transformers import AutoTokenizer - -from crslab.data.dataset.tokenizer.base import BaseTokenizer - - -class BertTokenizer(BaseTokenizer): - - def __init__(self, path=None) -> None: - self.special_token_idx = { - 'pad': 0, - 'start': 101, - 'end': 102, - 'unk': 100, - 'sent_split': 2, - 'word_split': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - self.my_tokenizer = AutoTokenizer.from_pretrained(path) - super().__init__(path) - - def tokenize(self, text): - return self.my_tokenizer.tokenize(text) diff --git a/crslab/data/dataset/tokenizer/gpt2.py b/crslab/data/dataset/tokenizer/gpt2.py deleted file mode 100644 index 196d238..0000000 --- a/crslab/data/dataset/tokenizer/gpt2.py +++ /dev/null @@ -1,30 +0,0 @@ -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -from transformers import AutoTokenizer - -from crslab.data.dataset.tokenizer.base import BaseTokenizer - - -class Gpt2Tokenizer(BaseTokenizer): - - def __init__(self, path=None) -> None: - self.special_token_idx = { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'cls': 101, - 'sep': 102, - 'sent_split': 4, - 'word_split': 5, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - } - self.my_tokenizer = AutoTokenizer.from_pretrained(path) - super().__init__(path) - - def tokenize(self, text): - return self.my_tokenizer.tokenize(text) diff --git a/crslab/data/dataset/tokenizer/jieba.py b/crslab/data/dataset/tokenizer/jieba.py deleted file mode 100644 index 8354098..0000000 --- a/crslab/data/dataset/tokenizer/jieba.py +++ /dev/null @@ -1,26 +0,0 @@ -# @Time : 2022/9/30 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -import jieba - -from crslab.data.dataset.tokenizer.base import BaseTokenizer - - -class JiebaTokenizer(BaseTokenizer): - - def __init__(self, path=None) -> None: - self.special_token_idx = { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - } - 
super().__init__(path) - - def tokenize(self, text): - split_text = jieba.cut(text) - text_list = ' '.join(split_text).split() - return text_list diff --git a/crslab/data/dataset/tokenizer/nltk.py b/crslab/data/dataset/tokenizer/nltk.py deleted file mode 100644 index 01ff902..0000000 --- a/crslab/data/dataset/tokenizer/nltk.py +++ /dev/null @@ -1,27 +0,0 @@ -# @Time : 2022/9/30 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -import nltk -from nltk import word_tokenize - -from crslab.data.dataset.tokenizer.base import BaseTokenizer - - -class NltkTokenizer(BaseTokenizer): - - def __init__(self, path=None) -> None: - self.special_token_idx = { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0, - } - super().__init__(path) - - def tokenize(self, text): - nltk.download('punkt') - return word_tokenize(text) diff --git a/crslab/data/dataset/tokenizer/pkuseg.py b/crslab/data/dataset/tokenizer/pkuseg.py deleted file mode 100644 index c362f5f..0000000 --- a/crslab/data/dataset/tokenizer/pkuseg.py +++ /dev/null @@ -1,26 +0,0 @@ -# @Time : 2022/9/30 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - -import pkuseg - -from crslab.data.dataset.tokenizer.base import BaseTokenizer - - -class PkusegTokenizer(BaseTokenizer): - - def __init__(self, path=None) -> None: - self.PkusegTokenizerr = pkuseg.pkuseg() - self.special_token_idx = { - 'pad': 0, - 'start': 1, - 'end': 2, - 'unk': 3, - 'pad_entity': 0, - 'pad_word': 0, - 'pad_topic': 0 - } - super().__init__(path) - - def tokenize(self, text): - return self.PkusegTokenizerr.cut(text) diff --git a/crslab/evaluator/ask.py b/crslab/evaluator/ask.py new file mode 100644 index 0000000..c6097ea --- /dev/null +++ b/crslab/evaluator/ask.py @@ -0,0 +1,297 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import argparse +import copy +import json +import os +import re +import random +import time +import typing +import warnings +import tiktoken + +import numpy as np +import openai +import nltk + +from copy import copy +from tqdm import tqdm +from loguru import logger +from thefuzz import fuzz +from tenacity import Retrying, retry_if_not_exception_type, _utils +from tenacity.stop import stop_base +from tenacity.wait import wait_base + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.evaluator.base import BaseEvaluator +from crslab.evaluator.utils import get_entity, get_instruction + +def get_exist_dialog_set(save_dir): + exist_id_set = set() + for file in os.listdir(save_dir): + file_id = os.path.splitext(file)[0] + exist_id_set.add(file_id) + return exist_id_set + +def my_before_sleep(retry_state): + logger.debug( + f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + +class my_wait_exponential(wait_base): + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + def __call__(self, retry_state: "RetryCallState") -> float: + if retry_state.outcome == openai.error.Timeout: + return 0 + + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) 
+ + +class my_stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome == openai.error.Timeout: + retry_state.attempt_number -= 1 + return retry_state.attempt_number >= self.max_attempt_number + +def annotate_completion(prompt, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, + retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)) + ): + with attempt: + response = openai.Completion.create( + model='text-davinci-003', prompt=prompt, temperature=0, max_tokens=128, stop='Recommender', + logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['text'] + request_timeout = min(300, request_timeout * 2) + + return response + +class Ask(): + + def __init__(self, turn_num, crs_model, dataset, ask_instruction_dict) -> None: + self.turn_num = turn_num + self.crs_model = crs_model + self.ask_instruction_dict = ask_instruction_dict + self.dataset_path = os.path.join(DATASET_PATH, dataset) + self.item_embedding_path = os.path.join(SAVE_PATH, dataset, 'embed') + item_emb_list = [] + id2item_id = [] + + with open(f"{self.dataset_path}/entity2id.json", 'r', encoding="utf-8") as f: + self.entity2id = json.load(f) + self.id2entity = {} + for entity, idx in self.entity2id.items(): + self.id2entity[idx] = entity + with open(f"{self.dataset_path}/id2info.json", 'r', encoding="utf-8") as f: + self.id2info = json.load(f) + + self.id2entity = {} + for k, v in self.entity2id.items(): + self.id2entity[int(v)] = k + + self.id2entityid = {} + for id, info in self.id2info.items(): + if info['name'] in self.entity2id: + self.id2entityid[id] = self.entity2id[info['name']] + + self.entityid2id = {} + for id, entityid in self.id2entityid.items(): + self.entityid2id[entityid] = id + + for i, file in tqdm(enumerate(os.listdir(self.item_embedding_path))): + item_id = os.path.splitext(file)[0] + if item_id in self.id2entityid: + id2item_id.append(item_id) + + with open(f'{self.item_embedding_path}/{file}', encoding='utf-8') as f: + embed = json.load(f) + item_emb_list.append(embed) + + self.id2item_id_arr = np.asarray(id2item_id) + self.item_emb_arr = np.asarray(item_emb_list) + + def ask(self, batch, turn_num): + + ask_instruction = self.ask_instruction_dict['ask_instruction'] + option2attr = self.ask_instruction_dict['option2attr'] + option2template = self.ask_instruction_dict['option2template'] + rec_instruction = self.ask_instruction_dict['rec_instruction'] + recommendation_template = "I would recommend the following items:\n{}" + + contexts_batch = batch['context'] + items_batch = batch['item'] + entity_batch = batch['entity'] + + for context, items, entities in zip(contexts_batch, items_batch, entity_batch): + + context_list = [] + + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + role_str = 'user' + else: + role_str = 'assistant' + context_list.append({ + 'role': role_str, + 'content': text + }) + + rec_success = False + option2index = { + 'A': 0, + 'B': 1, + 'C': 2, + 'D': 3, + 'E': 4 + } + + options = list(option2attr.keys()) + state = [0 for _ in range(len(options))] + + for i in 
range(0, turn_num): + # seeker + + context_list.append({ + 'role': 'user', + 'content': ask_instruction + }) + + context.append(ask_instruction) + batch['context'] = [copy(context)] + + # recommender + # choose option + # options (list of str): available options, generate one of them + gen_inputs, recommender_text = self.crs_model.converse(batch) + recommender_choose = self.crs_model.choose(gen_inputs, options, state, batch) + selected_option = recommender_choose + + if selected_option == options[-1]: # choose to rec + # recommender + _, item_rank_arr = self.crs_model.recommend(batch) + pred_items = item_rank_arr[0] + + rec_items_str = '' + for j, rec_item in enumerate(pred_items): + rec_items_str += f"{j + 1}: {self.id2entity[rec_item]}\n" + recommender_text = recommendation_template.format(rec_items_str) + + # judge whether success + for rec_label in items: + if rec_label in pred_items: + rec_success = True + break + + recommender_resp_entity = get_entity(recommender_text, self.entity2id) + + context.append(recommender_text) + entities += recommender_resp_entity + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + context_list.append({ + 'role': 'assistant', + 'content': recommender_text, + 'entity': recommender_resp_entity, + 'pred_items': pred_items, + 'rec_items': items, + 'rec_success': rec_success + }) + + # seeker + if rec_success is True: + seeker_text = "That's perfect, thank you!" + else: + seeker_text = "I don't like them." + + context_list.append({ + 'role': 'user', + 'content': seeker_text + }) + + context.append(seeker_text) + batch['context'] = [copy(context)] + + else: # choose to ask + recommender_text = option2template[selected_option] + context_list.append({ + 'role': 'assistant', + 'content': recommender_text, + }) + context.append(recommender_text) + batch['context'] = [copy(context)] + + # seeker + ask_attr = option2attr[selected_option] + + # update state + state[option2index[selected_option]] = -1e5 + + ans_attr_list = [] + id2info_items = [self.entityid2id[item] for item in items] + for id2info_item in id2info_items: + if str(id2info_item) in self.id2info and ask_attr in self.id2info[str(id2info_item)]: + ans_attr_list.extend(self.id2info[str(id2info_item)][ask_attr]) + if len(ans_attr_list) > 0: + seeker_text = ', '.join(list(set(ans_attr_list))) + else: + seeker_text = 'Sorry, no information about this, please choose another option.' 
+ + context_list.append({ + 'role': 'user', + 'content': seeker_text, + 'entity': ans_attr_list, + }) + + seeker_resp_entities = get_entity(seeker_text, self.entity2id) + + context.append(seeker_text) + entities += seeker_resp_entities + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + if rec_success is True: + break + + return context_list + \ No newline at end of file diff --git a/crslab/evaluator/chat.py b/crslab/evaluator/chat.py new file mode 100644 index 0000000..f81c297 --- /dev/null +++ b/crslab/evaluator/chat.py @@ -0,0 +1,275 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import argparse +import copy +import json +import os +import re +import random +import time +import typing +import warnings +import tiktoken + +import numpy as np +import openai +import nltk + +from copy import copy +from tqdm import tqdm +from loguru import logger +from thefuzz import fuzz +from tenacity import Retrying, retry_if_not_exception_type, _utils +from tenacity.stop import stop_base +from tenacity.wait import wait_base + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.evaluator.base import BaseEvaluator +from crslab.evaluator.utils import get_entity, get_instruction + +def get_exist_dialog_set(save_dir): + exist_id_set = set() + for file in os.listdir(save_dir): + file_id = os.path.splitext(file)[0] + exist_id_set.add(file_id) + return exist_id_set + +def my_before_sleep(retry_state): + logger.debug( + f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + +class my_wait_exponential(wait_base): + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + def __call__(self, retry_state: "RetryCallState") -> float: + if retry_state.outcome == openai.error.Timeout: + return 0 + + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class my_stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome == openai.error.Timeout: + retry_state.attempt_number -= 1 + return retry_state.attempt_number >= self.max_attempt_number + +def annotate_completion(prompt, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, + retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)) + ): + with attempt: + response = openai.Completion.create( + model='text-davinci-003', prompt=prompt, temperature=0, max_tokens=128, stop='Recommender', + logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['text'] + request_timeout = min(300, request_timeout * 2) + + return response + +class Chat(): + + def __init__(self, turn_num, crs_model, dataset) -> None: + self.turn_num = turn_num + 
self.crs_model = crs_model + self.dataset_path = os.path.join(DATASET_PATH, dataset) + self.item_embedding_path = os.path.join(SAVE_PATH, dataset, 'embed') + item_emb_list = [] + id2item_id = [] + + with open(f"{self.dataset_path}/entity2id.json", 'r', encoding="utf-8") as f: + self.entity2id = json.load(f) + self.id2entity = {} + for entity, idx in self.entity2id.items(): + self.id2entity[idx] = entity + with open(f"{self.dataset_path}/id2info.json", 'r', encoding="utf-8") as f: + self.id2info = json.load(f) + + self.id2entity = {} + for k, v in self.entity2id.items(): + self.id2entity[int(v)] = k + + self.id2entityid = {} + for id, info in self.id2info.items(): + if info['name'] in self.entity2id: + self.id2entityid[id] = self.entity2id[info['name']] + + for i, file in tqdm(enumerate(os.listdir(self.item_embedding_path))): + item_id = os.path.splitext(file)[0] + if item_id in self.id2entityid: + id2item_id.append(item_id) + + with open(f'{self.item_embedding_path}/{file}', encoding='utf-8') as f: + embed = json.load(f) + item_emb_list.append(embed) + + self.id2item_id_arr = np.asarray(id2item_id) + self.item_emb_arr = np.asarray(item_emb_list) + + + def chat(self, batch, turn_num): + + recommender_instruction, seeker_instruction_template = get_instruction() + + contexts_batch = batch['context'] + items_batch = batch['item'] + entity_batch = batch['entity'] + + for context, items, entities in zip(contexts_batch, items_batch, entity_batch): + + goal_item_list = [self.id2entity[item] for item in items] + goal_item_str = ', '.join(goal_item_list) + seeker_prompt = seeker_instruction_template.format(goal_item_str, goal_item_str, goal_item_str, goal_item_str) + + context_list = [] + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + seeker_prompt += f'Seeker: {text}\n' + context_list.append({ + 'role': 'user', + 'content': text, + }) + else: + seeker_prompt += f'Recommender: {text}\n' + context_list.append({ + 'role': 'assistant', + 'content': text + }) + + rec_success = False + recommendation_template = "I would recommend the following items: {}:" + + for i in range(0, turn_num): + # rec only + _, item_rank_arr = self.crs_model.recommend(batch) + + pred_items = item_rank_arr[0] + + for rec_label in items: + if rec_label in pred_items: + rec_success = True + break + + _, recommender_text = self.crs_model.converse(batch) + + if rec_success == True or i == turn_num - 1: + rec_items_str = '' + for j, rec_item in enumerate(pred_items): + rec_items_str += f"{j+1}: {self.id2entity[rec_item]}\n" + recommendation_template = recommendation_template.format(rec_items_str) + recommender_text = recommendation_template + recommender_text + + recommender_resp_entity = get_entity(recommender_text, self.entity2id) + + context.append(recommender_text) + entities += recommender_resp_entity + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + context_list.append({ + 'role': 'assistant', + 'content': recommender_text, + 'entity': recommender_resp_entity, + 'pred_items': pred_items, + 'rec_items': items, + 'rec_success': rec_success + }) + + seeker_prompt += f'Recommender: {recommender_text}\nSeeker:' + + # seeker + year_pattern = re.compile(r'\(\d+\)') + goal_item_no_year_list = [year_pattern.sub('', rec_item).strip() for rec_item in goal_item_list] + seeker_text = annotate_completion(seeker_prompt).strip() + + seeker_response_no_movie_list = [] + for sent in 
nltk.sent_tokenize(seeker_text): + use_sent = True + for rec_item_str in goal_item_list + goal_item_no_year_list: + if fuzz.partial_ratio(rec_item_str.lower(), sent.lower()) > 90: + use_sent = False + break + if use_sent is True: + seeker_response_no_movie_list.append(sent) + seeker_response = ' '.join(seeker_response_no_movie_list) + if not rec_success: + seeker_response = 'Sorry, ' + seeker_response + seeker_prompt += f' {seeker_response}\n' + + # public + seeker_resp_entity = get_entity(seeker_text, self.entity2id) + + context_list.append({ + 'role': 'user', + 'content': seeker_text, + 'entity': seeker_resp_entity, + }) + + context.append(seeker_text) + entities += seeker_resp_entity + entities = list(set(entities)) + + batch['context'] = [copy(context)] + batch['entity'] = [copy(entities)] + + if rec_success: + break + + # score persuativeness + encoding = tiktoken.encoding_for_model("text-davinci-003") + logit_bias = {encoding.encode(str(score))[0]: 10 for score in range(3)} + + persuasiveness_template = '''Does the explanation make you want to accept the recommendation? Please give your score. + If mention one of [{}], give 2. + Else if you think recommended items are worse than [{}], give 0. + Else if you think recommended items are comparable to [{}] according to the explanation, give 1. + Else if you think recommended items are better than [{}] according to the explanation, give 2. + Only answer the score number.''' + + persuasiveness_template = persuasiveness_template.format(goal_item_str, goal_item_str, goal_item_str, goal_item_str) + prompt_str_for_persuasiveness = seeker_prompt + persuasiveness_template + prompt_str_for_persuasiveness += "\nSeeker:" + persuasiveness_score = annotate_completion(prompt_str_for_persuasiveness, logit_bias).strip() + + context_list.append({ + 'persuasiveness_score': persuasiveness_score + }) + + return context_list \ No newline at end of file diff --git a/crslab/evaluator/embeddings.py b/crslab/evaluator/embeddings.py index a37adda..33044a6 100644 --- a/crslab/evaluator/embeddings.py +++ b/crslab/evaluator/embeddings.py @@ -8,26 +8,21 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - from crslab.download import DownloadableFile resources = { 'zh': { - 'version': '1.0', + 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', 'cc.zh.300.zip', 'effd9806809a1db106b5166b817aaafaaf3f005846f730d4c49f88c7a28a0ac3' ) }, 'en': { - 'version': '1.0', + 'version': '0.2', 'file': DownloadableFile( - 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', 'cc.en.300.zip', '96a06a77da70325997eaa52bfd9acb1359a7c3754cb1c1aed2fc27c04936d53e' ) diff --git a/crslab/evaluator/rec.py b/crslab/evaluator/rec.py index ce2a6c4..dc15c67 100644 --- a/crslab/evaluator/rec.py +++ b/crslab/evaluator/rec.py @@ -25,17 +25,18 @@ class RecEvaluator(BaseEvaluator): optim_metrics: the metrics to optimize in training """ - def __init__(self, tensorboard=False): + def __init__(self, k_list=[1, 
10, 50], tensorboard=False): super(RecEvaluator, self).__init__() self.rec_metrics = Metrics() self.optim_metrics = Metrics() self.tensorboard = tensorboard + self.k_list = k_list if self.tensorboard: self.writer = SummaryWriter(log_dir='runs/' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) self.reports_name = ['Recommendation Metrics', 'Optimization Metrics'] def rec_evaluate(self, ranks, label): - for k in [1, 10, 50]: + for k in self.k_list: if len(ranks) >= k: self.rec_metrics.add(f"hit@{k}", HitMetric.compute(ranks, label, k)) self.rec_metrics.add(f"ndcg@{k}", NDCGMetric.compute(ranks, label, k)) diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 67ac256..7341aba 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -7,31 +7,26 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - import os import time from collections import defaultdict import fasttext -from crslab.evaluator.base import BaseEvaluator -from crslab.evaluator.utils import nice_report from loguru import logger from nltk import ngrams from torch.utils.tensorboard import SummaryWriter -from ..config import EMBEDDING_PATH -from ..download import build +from crslab.evaluator.base import BaseEvaluator +from crslab.evaluator.utils import nice_report from .embeddings import resources from .metrics import * +from ..config import EMBEDDING_PATH +from ..download import build class StandardEvaluator(BaseEvaluator): """The evaluator for all kind of model(recommender, conversation, policy) - + Args: rec_metrics: the metrics to evaluate recommender model, including hit@K, ndcg@K and mrr@K dist_set: the set to record dist n-gram @@ -54,10 +49,8 @@ def __init__(self, language, tensorboard=False): # tensorboard self.tensorboard = tensorboard if self.tensorboard: - self.writer = SummaryWriter( - log_dir='runs/' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) - self.reports_name = ['Recommendation Metrics', - 'Generation Metrics', 'Optimization Metrics'] + self.writer = SummaryWriter(log_dir='runs/' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) + self.reports_name = ['Recommendation Metrics', 'Generation Metrics', 'Optimization Metrics'] def _load_embedding(self, language): resource = resources[language] @@ -74,44 +67,34 @@ def _get_sent_embedding(self, sent): def rec_evaluate(self, ranks, label): for k in [1, 10, 50]: if len(ranks) >= k: - self.rec_metrics.add( - f"hit@{k}", HitMetric.compute(ranks, label, k)) - self.rec_metrics.add( - f"ndcg@{k}", NDCGMetric.compute(ranks, label, k)) - self.rec_metrics.add( - f"mrr@{k}", MRRMetric.compute(ranks, label, k)) + self.rec_metrics.add(f"hit@{k}", HitMetric.compute(ranks, label, k)) + self.rec_metrics.add(f"ndcg@{k}", NDCGMetric.compute(ranks, label, k)) + self.rec_metrics.add(f"mrr@{k}", MRRMetric.compute(ranks, label, k)) def gen_evaluate(self, hyp, refs): if hyp: self.gen_metrics.add("f1", F1Metric.compute(hyp, refs)) for k in range(1, 5): - self.gen_metrics.add( - f"bleu@{k}", BleuMetric.compute(hyp, refs, k)) + self.gen_metrics.add(f"bleu@{k}", BleuMetric.compute(hyp, refs, k)) for token in ngrams(hyp, k): self.dist_set[f"dist@{k}"].add(token) self.dist_cnt += 1 hyp_emb = self._get_sent_embedding(hyp) ref_embs = [self._get_sent_embedding(ref) for ref in refs] - if len(ref_embs[0]) > 0: - self.gen_metrics.add( - 'greedy', GreedyMatch.compute(hyp_emb, ref_embs)) - self.gen_metrics.add( - 'average', 
EmbeddingAverage.compute(hyp_emb, ref_embs)) - self.gen_metrics.add( - 'extreme', VectorExtrema.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('greedy', GreedyMatch.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) + self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) def report(self, epoch=-1, mode='test'): for k, v in self.dist_set.items(): self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt)) - reports = [self.rec_metrics.report(), self.gen_metrics.report(), - self.optim_metrics.report()] + reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()] if self.tensorboard and mode != 'test': for idx, task_report in enumerate(reports): for each_metric, value in task_report.items(): - self.writer.add_scalars( - f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) + self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) def reset_metrics(self): diff --git a/crslab/evaluator/utils.py b/crslab/evaluator/utils.py index 0f5f89c..dce24fa 100644 --- a/crslab/evaluator/utils.py +++ b/crslab/evaluator/utils.py @@ -16,6 +16,7 @@ import math import torch from typing import Union, Tuple +from rapidfuzz import fuzz, process from .metrics import Metric @@ -158,3 +159,23 @@ def nice_report(report) -> str: for k, v in output.items() } ) + +def get_entity(text, entity2id): + entity_list = list(entity2id.keys()) + extractions = process.extract(text, entity_list, scorer=fuzz.WRatio, limit=20) + extractions = [extraction[0] for extraction in extractions if extraction[1] >= 90] + extractions_ids = [entity2id[extraction] for extraction in extractions] + return extractions_ids + +def get_instruction(): + + recommender_instruction = '''You are a recommender chatting with the user to provide recommendation. You must follow the instructions below during chat. + If you do not have enough information about user preference, you should ask the user for his preference. + If you have enough information about user preference, you can give recommendation.\n''' + seeker_instruction_template = '''You are a seeker chatting with a recommender for recommendation. Your target items: {}. You must follow the instructions below during chat. + If the recommender recommend {}, you should accept. + If the recommender recommend other items, you should refuse them and provide the information about {}. You should never directly tell the target item title. + If the recommender asks for your preference, you should provide the information about {}. You should never directly tell the target item title\n. 
+ ''' + + return recommender_instruction, seeker_instruction_template \ No newline at end of file diff --git a/crslab/model/__init__.py b/crslab/model/__init__.py index 9b32a78..419b4e3 100644 --- a/crslab/model/__init__.py +++ b/crslab/model/__init__.py @@ -41,7 +41,8 @@ 'GRU4REC': GRU4RECModel, 'Popularity': PopularityModel, 'TextCNN': TextCNNModel, - 'NTRD': NTRDModel + 'NTRD': NTRDModel, + 'ChatGPT': ChatGPTModel, } diff --git a/crslab/model/conversation/gpt2/gpt2.py b/crslab/model/conversation/gpt2/gpt2.py index 94dc405..c93badb 100644 --- a/crslab/model/conversation/gpt2/gpt2.py +++ b/crslab/model/conversation/gpt2/gpt2.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" GPT2 ==== @@ -23,16 +18,21 @@ """ +import os import torch -from crslab.model.base import BaseModel from torch.nn import CrossEntropyLoss from transformers import GPT2LMHeadModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources + class GPT2Model(BaseModel): """ - + Attributes: context_truncate: A integer indicating the length of dialogue context. response_truncate: A integer indicating the length of dialogue response. @@ -52,10 +52,12 @@ def __init__(self, opt, device, vocab, side_data): """ self.context_truncate = opt['context_truncate'] self.response_truncate = opt['response_truncate'] - self.pad_id = vocab['special_token_idx']['pad'] + self.pad_id = vocab['pad'] - self.dpath = opt['conv_pretrained_path'] - super(GPT2Model, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['gpt2'][language] + dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) + super(GPT2Model, self).__init__(opt, device, dpath, resource) def build_model(self): """build model""" @@ -93,8 +95,7 @@ def generate(self, context): context = context[..., -self.response_truncate + 1:] for i in range(self.response_truncate - 1): - outputs = self.model( - context, former_hidden_state) # (bs, c_t, v_s), + outputs = self.model(context, former_hidden_state) # (bs, c_t, v_s), last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values next_token_logits = last_hidden_state[:, -1, :] # (bs, v_s) @@ -138,10 +139,8 @@ def generate_bs(self, context, beam=4): next_token_logits = last_hidden_state[:, -1, :] next_token_probs = torch.nn.functional.softmax(next_token_logits) topk = torch.topk(next_token_probs, beam, dim=-1) - probs = topk.values.reshape( - [batch_size, -1, beam]) # (bs, candidate, beam) - preds = topk.indices.reshape( - [batch_size, -1, beam]) # (bs, candidate, beam) + probs = topk.values.reshape([batch_size, -1, beam]) # (bs, candidate, beam) + preds = topk.indices.reshape([batch_size, -1, beam]) # (bs, candidate, beam) for j in range(batch_size): all_candidates = [] @@ -153,8 +152,7 @@ def generate_bs(self, context, beam=4): seq_tmp.append(preds[j][n][k]) candidate = [seq_tmp, prob * probs[j][n][k]] all_candidates.append(candidate) - ordered = sorted( - all_candidates, key=lambda tup: tup[1], reverse=True) + ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True) sequences[j] = ordered[:beam] res = [] diff --git a/crslab/model/conversation/transformer/transformer.py b/crslab/model/conversation/transformer/transformer.py index e0f7a8b..e36260e 100644 --- a/crslab/model/conversation/transformer/transformer.py 
+++ b/crslab/model/conversation/transformer/transformer.py @@ -20,13 +20,13 @@ import torch import torch.nn.functional as F -from crslab.model.base import BaseModel -from crslab.model.utils.functions import edge_to_pyg_format -from crslab.model.utils.modules.transformer import (TransformerDecoder, - TransformerEncoder) from loguru import logger from torch import nn +from crslab.model.base import BaseModel +from crslab.model.utils.functions import edge_to_pyg_format +from crslab.model.utils.modules.transformer import TransformerEncoder, TransformerDecoder + class TransformerModel(BaseModel): """ @@ -70,16 +70,16 @@ def __init__(self, opt, device, vocab, side_data): """ # vocab self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrain_embedding = side_data.get('embedding', None) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['pad_word'] + self.pad_entity_idx = vocab['pad_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] diff --git a/crslab/model/crs/__init__.py b/crslab/model/crs/__init__.py index 0633769..ab75ed7 100644 --- a/crslab/model/crs/__init__.py +++ b/crslab/model/crs/__init__.py @@ -4,3 +4,4 @@ from .redial import * from .tgredial import * from .ntrd import * +from .chatgpt import * diff --git a/crslab/model/crs/chatgpt/__init__.py b/crslab/model/crs/chatgpt/__init__.py new file mode 100644 index 0000000..0ce0a8c --- /dev/null +++ b/crslab/model/crs/chatgpt/__init__.py @@ -0,0 +1 @@ +from .chatgpt import ChatGPTModel diff --git a/crslab/model/crs/chatgpt/chatgpt.py b/crslab/model/crs/chatgpt/chatgpt.py new file mode 100644 index 0000000..c786b38 --- /dev/null +++ b/crslab/model/crs/chatgpt/chatgpt.py @@ -0,0 +1,279 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import json +import os +import numpy as np +import openai +import typing +import tiktoken + +from tqdm import tqdm +from loguru import logger +from copy import copy +from tenacity import Retrying, retry_if_not_exception_type, _utils +from tenacity.stop import stop_base +from tenacity.wait import wait_base +from sklearn.metrics.pairwise import cosine_similarity + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.model.base import BaseModel + +def my_before_sleep(retry_state): + logger.debug( + f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + +class my_wait_exponential(wait_base): + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + def __call__(self, retry_state: "RetryCallState") -> float: + if retry_state.outcome == openai.error.Timeout: + return 0 + + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = 
self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class my_stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome == openai.error.Timeout: + retry_state.attempt_number -= 1 + return retry_state.attempt_number >= self.max_attempt_number + +def annotate(item_text_list): + request_timeout = 6 + for attempt in Retrying( + reraise=True, retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)), before_sleep=my_before_sleep + ): + with attempt: + response = openai.Embedding.create( + model='text-embedding-ada-002', input=item_text_list, request_timeout=request_timeout + ) + request_timeout = min(30, request_timeout * 2) + + return response + +def annotate_completion(prompt, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, + retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)) + ): + with attempt: + response = openai.Completion.create( + model='text-davinci-003', prompt=prompt, temperature=0, max_tokens=128, stop='Recommender', + logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['text'] + request_timeout = min(300, request_timeout * 2) + + return response + +def annotate_chat(messages, logit_bias=None): + if logit_bias is None: + logit_bias = {} + + request_timeout = 20 + for attempt in Retrying( + reraise=True, retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)), before_sleep=my_before_sleep + ): + with attempt: + response = openai.ChatCompletion.create( + model='gpt-3.5-turbo', messages=messages, temperature=0, logit_bias=logit_bias, + request_timeout=request_timeout, + )['choices'][0]['message']['content'] + request_timeout = min(300, request_timeout * 2) + + return response + +class ChatGPTModel(BaseModel): + + def __init__(self, opt, device, vocab=None, side_data=None): + self.dataset = opt['dataset'] + self.dataset_path = os.path.join(DATASET_PATH, self.dataset) + self.item_embedding_path = os.path.join(SAVE_PATH, self.dataset, 'embed') + + with open(f"{self.dataset_path}/entity2id.json", 'r', encoding="utf-8") as f: + self.entity2id = json.load(f) + self.id2entity = {} + for entity, idx in self.entity2id.items(): + self.id2entity[idx] = entity + with open(f"{self.dataset_path}/id2info.json", 'r', encoding="utf-8") as f: + self.id2info = json.load(f) + + self.id2entityid = {} + for id, info in self.id2info.items(): + if info['name'] in self.entity2id: + self.id2entityid[id] = self.entity2id[info['name']] + + self.get_item_embedding() + super(ChatGPTModel, self).__init__(opt, device) + + def build_model(self, *args, **kwargs): + return super().build_model(*args, **kwargs) + + def get_item_embedding(self): + + item_emb_list = [] + id2item_id = [] + + if os.path.exists(self.item_embedding_path): + for i, file in tqdm(enumerate(os.listdir(self.item_embedding_path))): + item_id = os.path.splitext(file)[0] + if item_id in 
self.id2entityid: + id2item_id.append(item_id) + + with open(f'{self.item_embedding_path}/{file}', encoding='utf-8') as f: + embed = json.load(f) + item_emb_list.append(embed) + + self.id2item_id_arr = np.asarray(id2item_id) + self.item_emb_arr = np.asarray(item_emb_list) + + def get_instruction(self): + + recommender_instruction = '''You are a recommender chatting with the user to provide recommendation. You must follow the instructions below during chat. +If you do not have enough information about user preference, you should ask the user for his preference. +If you have enough information about user preference, you can give recommendation. The recommendation list must contain 10 items that are consistent with user preference. The recommendation list can contain items that the dialog mentioned before. The format of the recommendation list is: no. title. Don't mention anything other than the title of items in your recommendation list. +''' + + seeker_instruction_template = '''You are a seeker chatting with a recommender for recommendation. Your target items: {}. You must follow the instructions below during chat. + If the recommender recommend {}, you should accept. + If the recommender recommend other items, you should refuse them and provide the information about {}. You should never directly tell the target item title. + If the recommender asks for your preference, you should provide the information about {}. You should never directly tell the target item title\n. + ''' + + return recommender_instruction, seeker_instruction_template + + def recommend(self, batch, mode='test'): + + context_batch = batch['context'] + item_rank_arr_batch = [] + + for context in context_batch: + + if len(context) % 2 == 0: + context = [""] + context + + conv_str = "" + + for i, text in enumerate(context[-2:]): + if len(text) == 0: + continue + if i % 2 == 0: + conv_str += f'Seeker: {text}\n' + else: + conv_str += f'Recommender: {text}\n' + conv_embed = annotate(conv_str)['data'][0]['embedding'] + conv_embed = np.asarray(conv_embed).reshape(1, -1) + + sim_mat = cosine_similarity(conv_embed, self.item_emb_arr) + rank_arr = np.argsort(sim_mat, axis=-1).tolist() + rank_arr = np.flip(rank_arr, axis=-1)[:, :50] + item_rank_arr = self.id2item_id_arr[rank_arr].tolist() + # modify item_rank_arr, item_id -> entity_id + item_rank_arr = [self.id2entityid[item_id] for item_id in item_rank_arr[0]] + item_rank_arr_batch.append(item_rank_arr) + + loss = None + + return loss, item_rank_arr_batch + + def converse(self, batch, mode='test'): + recommender_instruction, seeker_instruction_template = self.get_instruction() + + context_batch = batch['context'] + item_batch = batch['item'] + + for context, items in zip(context_batch, item_batch): + + context_list = [{ + 'role': 'system', + 'content': recommender_instruction + }] + + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + context_list.append({ + 'role': 'user', + 'content': text, + }) + else: + context_list.append({ + 'role': 'assistant', + 'content': text + }) + + gen_str = annotate_chat(context_list) + gen_inputs = None + + return gen_inputs, gen_str + + def choose(self, gen_inputs, options, state, batch, mode='test'): + + context_batch = batch['context'] + + for context in context_batch: + + updated_options = [] + for i, st in enumerate(state): + if st >= 0: + updated_options.append(options[i]) + + logger.info(updated_options) + + encoding = 
tiktoken.encoding_for_model("gpt-3.5-turbo") + logit_bias = {encoding.encode(option)[0]: 20 for option in updated_options} + + context_list = [] + if len(context) % 2 == 0: + context = [""] + context + + for i, text in enumerate(context): + if len(text) == 0: + continue + if i % 2 == 0: + role_str = 'user' + else: + role_str = 'assistant' + context_list.append({ + 'role': role_str, + 'content': text + }) + + logger.info(context_list) + + response_op = annotate_chat(context_list, logit_bias=logit_bias) + return response_op[0] \ No newline at end of file diff --git a/crslab/model/crs/inspired/inspired_conv.py b/crslab/model/crs/inspired/inspired_conv.py index 10f79f6..99e7ca9 100644 --- a/crslab/model/crs/inspired/inspired_conv.py +++ b/crslab/model/crs/inspired/inspired_conv.py @@ -2,17 +2,15 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - - +import os import torch -from crslab.model.base import BaseModel -from transformers import GPT2Config, GPT2LMHeadModel +from transformers import GPT2LMHeadModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources from .modules import SequenceCrossEntropyLoss @@ -38,11 +36,13 @@ def __init__(self, opt, device, vocab, side_data): """ self.context_truncate = opt['context_truncate'] self.response_truncate = opt['response_truncate'] - self.pad_id = vocab['special_token_idx']['pad'] + self.pad_id = vocab['pad'] self.label_smoothing = opt['conv']['label_smoothing'] if 'label_smoothing' in opt['conv'] else -1 - self.dpath = opt['conv_pretrained_path'] - super(InspiredConvModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['gpt2'][language] + dpath = os.path.join(PRETRAIN_PATH, "gpt2", language) + super(InspiredConvModel, self).__init__(opt, device, dpath, resource) def build_model(self): """build model for seeker and recommender separately""" @@ -67,26 +67,24 @@ def converse(self, batch, mode): past = None lm_logits_all = [] - support_up_limits = self.model_sk.config.n_positions if mode != 'test': for turn, iter in enumerate(input_ids_iters): if (roles[turn] == 0): # considering that gpt2 only supports up to 1024 tokens - if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: + if past is not None and past[0].shape[3] + iter.shape[1] > 1024: past = None outputs = self.model_sk(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) else: - if past is not None and past[0][0].shape[-2] + iter.shape[1] > support_up_limits: + if past is not None and past[0].shape[3] + iter.shape[1] > 1024: past = None outputs = self.model_rm(iter, past_key_values=past) lm_logits, past = outputs.logits, outputs.past_key_values lm_logits_all.append(lm_logits) - # (b_s, seq_len, vocab_size) - lm_logits_all = torch.cat(lm_logits_all, dim=0) + lm_logits_all = torch.cat(lm_logits_all, dim=0) # (b_s, seq_len, vocab_size) # index from 1 to self.reponse_truncate is valid response loss = self.calculate_loss( @@ -118,11 +116,9 @@ def generate(self, roles, context): context_iters = context.unsqueeze(1) for turn, iter in enumerate(context_iters): if roles[turn] == 0: - outputs = self.model_sk( - iter, former_hidden_state) # (1, s_l, v_s), + outputs = self.model_sk(iter, former_hidden_state) # (1, s_l, v_s), 
else: - outputs = self.model_rm( - iter, former_hidden_state) # (1, s_l, v_s), + outputs = self.model_rm(iter, former_hidden_state) # (1, s_l, v_s), last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values last_hidden_state_all.append(last_hidden_state) diff --git a/crslab/model/crs/inspired/inspired_rec.py b/crslab/model/crs/inspired/inspired_rec.py index bebbd83..67948f5 100644 --- a/crslab/model/crs/inspired/inspired_rec.py +++ b/crslab/model/crs/inspired/inspired_rec.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" BERT ==== @@ -23,12 +18,17 @@ """ +import os -from crslab.model.base import BaseModel from loguru import logger from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources + class InspiredRecModel(BaseModel): """ @@ -50,8 +50,10 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - self.dpath = opt['rec_pretrained_path'] - super(InspiredRecModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['bert'][language] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + super(InspiredRecModel, self).__init__(opt, device, dpath, resource) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/model/crs/kbrd/kbrd.py b/crslab/model/crs/kbrd/kbrd.py index 0e92e18..987151d 100644 --- a/crslab/model/crs/kbrd/kbrd.py +++ b/crslab/model/crs/kbrd/kbrd.py @@ -21,15 +21,15 @@ import torch import torch.nn.functional as F -from crslab.model.base import BaseModel -from crslab.model.utils.functions import edge_to_pyg_format -from crslab.model.utils.modules.attention import SelfAttentionBatch -from crslab.model.utils.modules.transformer import (TransformerDecoder, - TransformerEncoder) from loguru import logger from torch import nn from torch_geometric.nn import RGCNConv +from crslab.model.base import BaseModel +from crslab.model.utils.functions import edge_to_pyg_format +from crslab.model.utils.modules.attention import SelfAttentionBatch +from crslab.model.utils.modules.transformer import TransformerDecoder, TransformerEncoder + class KBRDModel(BaseModel): """ @@ -74,9 +74,9 @@ def __init__(self, opt, device, vocab, side_data): self.device = device self.gpu = opt.get("gpu", [-1]) # vocab - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] self.vocab_size = vocab['vocab_size'] self.token_emb_dim = opt.get('token_emb_dim', 300) self.pretrain_embedding = side_data.get('embedding', None) @@ -84,8 +84,7 @@ def __init__(self, opt, device, vocab, side_data): self.n_entity = vocab['n_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] - self.edge_idx, self.edge_type = edge_to_pyg_format( - entity_kg['edge'], 'RGCN') + self.edge_idx, self.edge_type = edge_to_pyg_format(entity_kg['edge'], 'RGCN') self.edge_idx = self.edge_idx.to(device) self.edge_type = self.edge_type.to(device) self.num_bases = opt.get('num_bases', 8) @@ 
-99,8 +98,7 @@ def __init__(self, opt, device, vocab, side_data): self.attention_dropout = opt.get('attention_dropout', 0.0) self.relu_dropout = opt.get('relu_dropout', 0.1) self.embeddings_scale = opt.get('embedding_scale', True) - self.learn_positional_embeddings = opt.get( - 'learn_positional_embeddings', False) + self.learn_positional_embeddings = opt.get('learn_positional_embeddings', False) self.reduction = opt.get('reduction', False) self.n_positions = opt.get('n_positions', 1024) self.longest_label = opt.get('longest_label', 1) @@ -120,17 +118,13 @@ def _build_embedding(self): torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False, padding_idx=self.pad_token_idx) else: - self.token_embedding = nn.Embedding( - self.vocab_size, self.token_emb_dim, self.pad_token_idx) - nn.init.normal_(self.token_embedding.weight, - mean=0, std=self.kg_emb_dim ** -0.5) - nn.init.constant_( - self.token_embedding.weight[self.pad_token_idx], 0) + self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) logger.debug('[Build embedding]') def _build_kg_layer(self): - self.kg_encoder = RGCNConv( - self.n_entity, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) + self.kg_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim) logger.debug('[Build kg layer]') @@ -140,8 +134,7 @@ def _build_recommendation_layer(self): logger.debug('[Build recommendation layer]') def _build_conversation_layer(self): - self.register_buffer('START', torch.tensor( - [self.start_token_idx], dtype=torch.long)) + self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) self.dialog_encoder = TransformerEncoder( self.n_heads, self.n_layers, @@ -182,8 +175,7 @@ def encode_user(self, entity_lists, kg_embedding): user_repr_list = [] for entity_list in entity_lists: if entity_list is None: - user_repr_list.append(torch.zeros( - self.user_emb_dim, device=self.device)) + user_repr_list.append(torch.zeros(self.user_emb_dim, device=self.device)) continue user_repr = kg_embedding[entity_list] user_repr = self.kg_attn(user_repr) @@ -209,8 +201,7 @@ def decode_forced(self, encoder_states, user_embedding, resp): inputs = torch.cat([self._starts(bsz), inputs], 1) latent, _ = self.decoder(inputs, encoder_states) token_logits = F.linear(latent, self.token_embedding.weight) - user_logits = self.user_proj_2(torch.relu( - self.user_proj_1(user_embedding))).unsqueeze(1) + user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) sum_logits = token_logits + user_logits _, preds = sum_logits.max(dim=-1) return sum_logits, preds @@ -222,19 +213,16 @@ def decode_greedy(self, encoder_states, user_embedding): incr_state = None logits = [] for i in range(self.longest_label): - scores, incr_state = self.decoder( - xs, encoder_states, incr_state) # incr_state is always None + scores, incr_state = self.decoder(xs, encoder_states, incr_state) # incr_state is always None scores = scores[:, -1:, :] token_logits = F.linear(scores, self.token_embedding.weight) - user_logits = self.user_proj_2(torch.relu( - self.user_proj_1(user_embedding))).unsqueeze(1) + user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) sum_logits = token_logits + user_logits 
probs, preds = sum_logits.max(dim=-1) logits.append(scores) xs = torch.cat([xs, preds], dim=1) # check if everyone has generated an end token - all_finished = ((xs == self.end_token_idx).sum( - dim=1) > 0).sum().item() == bsz + all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz if all_finished: break logits = torch.cat(logits, 1) @@ -252,8 +240,7 @@ def decode_beam_search(self, encoder_states, user_embedding, beam=4): for j in range(bsz): text = sequences[j][d][0] xs.append(text) - xs = torch.stack(xs).reshape( - beam, bsz, -1) # (beam, batch_size, _) + xs = torch.stack(xs).reshape(beam, bsz, -1) # (beam, batch_size, _) with torch.no_grad(): if i == 1: @@ -261,18 +248,15 @@ def decode_beam_search(self, encoder_states, user_embedding, beam=4): encoder_states = (encoder_states[0].repeat(beam, 1, 1), encoder_states[1].repeat(beam, 1, 1)) - scores, _ = self.decoder(xs.reshape( - len(sequences[0]) * bsz, -1), encoder_states) + scores, _ = self.decoder(xs.reshape(len(sequences[0]) * bsz, -1), encoder_states) scores = scores[:, -1:, :] token_logits = F.linear(scores, self.token_embedding.weight) - user_logits = self.user_proj_2(torch.relu( - self.user_proj_1(user_embedding))).unsqueeze(1) + user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) sum_logits = token_logits + user_logits logits = sum_logits.reshape(len(sequences[0]), bsz, 1, -1) scores = scores.reshape(len(sequences[0]), bsz, 1, -1) - # turn into probabilities,in case of negative numbers - logits = torch.nn.functional.softmax(logits) + logits = torch.nn.functional.softmax(logits) # turn into probabilities,in case of negative numbers probs, preds = logits.topk(beam, dim=-1) # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam @@ -285,20 +269,15 @@ def decode_beam_search(self, encoder_states, user_embedding, beam=4): if score == []: score_tmp = scores[n][j][0].unsqueeze(0) else: - score_tmp = torch.cat( - (score, scores[n][j][0].unsqueeze(0)), dim=0) - seq_tmp = torch.cat( - (xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) - candidate = [seq_tmp, score_tmp, - prob * probs[n][j][0][k]] + score_tmp = torch.cat((score, scores[n][j][0].unsqueeze(0)), dim=0) + seq_tmp = torch.cat((xs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) + candidate = [seq_tmp, score_tmp, prob * probs[n][j][0][k]] all_candidates.append(candidate) - ordered = sorted( - all_candidates, key=lambda tup: tup[2], reverse=True) + ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True) sequences[j] = ordered[:beam] # check if everyone has generated an end token - all_finished = ((xs == self.end_token_idx).sum( - dim=1) > 0).sum().item() == bsz + all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz if all_finished: break logits = torch.stack([seq[0][1] for seq in sequences]) @@ -313,8 +292,7 @@ def converse(self, batch, mode): encoder_state = self.dialog_encoder(context_tokens) if mode != 'test': self.longest_label = max(self.longest_label, response.shape[1]) - logits, preds = self.decode_forced( - encoder_state, user_embedding, response) + logits, preds = self.decode_forced(encoder_state, user_embedding, response) logits = logits.view(-1, logits.shape[-1]) labels = response.view(-1) return self.conv_loss(logits, labels), preds @@ -335,4 +313,4 @@ def freeze_parameters(self): freeze_models = [self.kg_encoder, self.kg_attn, self.rec_bias] for model in freeze_models: for p in model.parameters(): - p.requires_grad = False + 
p.requires_grad = False \ No newline at end of file diff --git a/crslab/model/crs/kgsf/kgsf.py b/crslab/model/crs/kgsf/kgsf.py index a428aaf..57590f4 100644 --- a/crslab/model/crs/kgsf/kgsf.py +++ b/crslab/model/crs/kgsf/kgsf.py @@ -7,11 +7,6 @@ # @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" KGSF ==== @@ -37,8 +32,8 @@ from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder - from .modules import GateLayer, TransformerDecoderKG +from .resources import resources class KGSFModel(BaseModel): @@ -86,22 +81,20 @@ def __init__(self, opt, device, vocab, side_data): self.gpu = opt.get("gpu", [-1]) # vocab self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) - self.copy_mask = torch.as_tensor(vocab['copy_mask'], dtype=torch.bool, device=self.device) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['pad_word'] + self.pad_entity_idx = vocab['pad_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] - self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format( - entity_edges, 'RGCN') + self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN') self.entity_edge_idx = self.entity_edge_idx.to(device) self.entity_edge_type = self.entity_edge_type.to(device) word_edges = side_data['word_kg']['edge'] @@ -123,9 +116,10 @@ def __init__(self, opt, device, vocab, side_data): self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) # copy mask - self.dataset = opt['dataset'] - self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) - super(KGSFModel, self).__init__(opt, device, self.dpath) + dataset = opt['dataset'] + dpath = os.path.join(MODEL_PATH, "kgsf", dataset) + resource = resources[dataset] + super(KGSFModel, self).__init__(opt, device, dpath, resource) def build_model(self): self._init_embeddings() @@ -140,32 +134,24 @@ def _init_embeddings(self): torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False, padding_idx=self.pad_token_idx) else: - self.token_embedding = nn.Embedding( - self.vocab_size, self.token_emb_dim, self.pad_token_idx) - nn.init.normal_(self.token_embedding.weight, - mean=0, std=self.kg_emb_dim ** -0.5) - nn.init.constant_( - self.token_embedding.weight[self.pad_token_idx], 0) - - self.word_kg_embedding = nn.Embedding( - self.n_word, self.kg_emb_dim, self.pad_word_idx) - nn.init.normal_(self.word_kg_embedding.weight, - mean=0, std=self.kg_emb_dim ** -0.5) + self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + 
nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) + + self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx) + nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0) logger.debug('[Finish init embeddings]') def _build_kg_layer(self): # db encoder - self.entity_encoder = RGCNConv( - self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) - self.entity_self_attn = SelfAttentionSeq( - self.kg_emb_dim, self.kg_emb_dim) + self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) + self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) # concept encoder self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim) - self.word_self_attn = SelfAttentionSeq( - self.kg_emb_dim, self.kg_emb_dim) + self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) # gate mechanism self.gate_layer = GateLayer(self.kg_emb_dim) @@ -186,8 +172,7 @@ def _build_recommendation_layer(self): logger.debug('[Finish build rec layer]') def _build_conversation_layer(self): - self.register_buffer('START', torch.tensor( - [self.start_token_idx], dtype=torch.long)) + self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) self.conv_encoder = TransformerEncoder( n_heads=self.n_heads, n_layers=self.n_layers, @@ -212,6 +197,8 @@ def _build_conversation_layer(self): self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim) self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size) + self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), + ).to(self.device) self.conv_decoder = TransformerDecoderKG( self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size, @@ -239,19 +226,15 @@ def pretrain_infomax(self, batch): if loss_mask.item() == 0: return None - entity_graph_representations = self.entity_encoder( - None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder( - self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) word_representations = word_graph_representations[words] word_padding_mask = words.eq(self.pad_word_idx) # (bs, seq_len) - word_attn_rep = self.word_self_attn( - word_representations, word_padding_mask) + word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) - info_predict = F.linear( - word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) + info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) loss = self.infomax_loss(info_predict, entity_labels) / loss_mask return loss @@ -263,27 +246,20 @@ def recommend(self, batch, mode): """ context_entities, context_words, entities, movie = batch - entity_graph_representations = self.entity_encoder( - None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder( - self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, 
self.word_edges) - entity_padding_mask = context_entities.eq( - self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq( - self.pad_word_idx) # (bs, word_len) + entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, word_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn( - entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn( - word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) user_rep = self.gate_layer(entity_attn_rep, word_attn_rep) - rec_scores = F.linear( - user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) + rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) rec_loss = self.rec_loss(rec_scores, movie) @@ -294,8 +270,7 @@ def recommend(self, batch, mode): word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) - info_loss = self.infomax_loss( - info_predict, entities) / info_loss_mask + info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask return rec_loss, info_loss, rec_scores @@ -325,8 +300,7 @@ def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze( 0) # (bs, seq_len, vocab_size) - # (bs, seq_len, vocab_size) - gen_logits = F.linear(dialog_latent, self.token_embedding.weight) + gen_logits = F.linear(dialog_latent, self.token_embedding.weight) # (bs, seq_len, vocab_size) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1) return sum_logits, preds @@ -343,19 +317,16 @@ def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e dialog_latent = dialog_latent[:, -1:, :] # (bs, 1, dim) db_latent = entity_emb_attn.unsqueeze(1) concept_latent = word_emb_attn.unsqueeze(1) - copy_latent = self.copy_norm( - torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) + copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) - copy_logits = self.copy_output( - copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) + copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1).long() logits.append(sum_logits) inputs = torch.cat((inputs, preds), dim=1) - finished = ((inputs == self.end_token_idx).sum( - dim=-1) > 0).sum().item() == batch_size + finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size if finished: break logits = torch.cat(logits, dim=1) @@ -386,8 +357,7 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at for j in range(batch_size): text = sequences[j][d][0] inputs.append(text) - inputs = torch.stack(inputs).reshape( - beam, batch_size, -1) # (beam, batch_size, _) + inputs = torch.stack(inputs).reshape(beam, batch_size, -1) # (beam, batch_size, _) with torch.no_grad(): dialog_latent, incr_state = self.conv_decoder( @@ -398,19 +368,15 @@ def 
_decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at dialog_latent = dialog_latent[:, -1:, :] # (bs, 1, dim) db_latent = entity_emb_attn.unsqueeze(1) concept_latent = word_emb_attn.unsqueeze(1) - copy_latent = self.copy_norm( - torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) + copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) - copy_logits = self.copy_output( - copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) - gen_logits = F.linear( - dialog_latent, self.token_embedding.weight) + copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) + gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits logits = sum_logits.reshape(len(sequences[0]), batch_size, 1, -1) # turn into probabilities,in case of negative numbers - probs, preds = torch.nn.functional.softmax( - logits).topk(beam, dim=-1) + probs, preds = torch.nn.functional.softmax(logits).topk(beam, dim=-1) # (candeidate, bs, 1 , beam) during first loop, candidate=1, otherwise candidate=beam @@ -423,20 +389,15 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at if logit == []: logit_tmp = logits[n][j][0].unsqueeze(0) else: - logit_tmp = torch.cat( - (logit, logits[n][j][0].unsqueeze(0)), dim=0) - seq_tmp = torch.cat( - (inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) - candidate = [seq_tmp, logit_tmp, - prob * probs[n][j][0][k]] + logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0) + seq_tmp = torch.cat((inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) + candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]] all_candidates.append(candidate) - ordered = sorted( - all_candidates, key=lambda tup: tup[2], reverse=True) + ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True) sequences[j] = ordered[:beam] # check if everyone has generated an end token - all_finished = ((inputs == self.end_token_idx).sum( - dim=1) > 0).sum().item() == batch_size + all_finished = ((inputs == self.end_token_idx).sum(dim=1) > 0).sum().item() == batch_size if all_finished: break logits = torch.stack([seq[0][1] for seq in sequences]) @@ -446,23 +407,17 @@ def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_at def converse(self, batch, mode): context_tokens, context_entities, context_words, response = batch - entity_graph_representations = self.entity_encoder( - None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder( - self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq( - self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq( - self.pad_word_idx) # (bs, seq_len) + entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, seq_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn( - entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn( - word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn(entity_representations, 
entity_padding_mask) + word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) # encoder-decoder tokens_encoding = self.conv_encoder(context_tokens) @@ -489,11 +444,11 @@ def converse(self, batch, mode): def forward(self, batch, stage, mode): if len(self.gpu) >= 2: # forward function operates on different gpus, the weight of graph network need to be copied to other gpu - self.entity_edge_idx = self.entity_edge_idx.cuda( - torch.cuda.current_device()) - self.entity_edge_type = self.entity_edge_type.cuda( - torch.cuda.current_device()) + self.entity_edge_idx = self.entity_edge_idx.cuda(torch.cuda.current_device()) + self.entity_edge_type = self.entity_edge_type.cuda(torch.cuda.current_device()) self.word_edges = self.word_edges.cuda(torch.cuda.current_device()) + self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), + ).cuda(torch.cuda.current_device()) if stage == "pretrain": loss = self.pretrain_infomax(batch) elif stage == "rec": diff --git a/crslab/model/crs/kgsf/resources.py b/crslab/model/crs/kgsf/resources.py new file mode 100644 index 0000000..c32dcd2 --- /dev/null +++ b/crslab/model/crs/kgsf/resources.py @@ -0,0 +1,62 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/13 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/15 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'ReDial': { + 'version': '0.2', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', + 'kgsf_redial.zip', + 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', + ), + }, + 'TGReDial': { + 'version': '0.2', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', + 'kgsf_tgredial.zip', + 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', + ), + }, + 'GoRecDial': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ER5u2yMmgDNFvHuW6lKZLEkBKZkOkxMtZGK0bBQ-jvfLNw?download=1', + 'kgsf_gorecdial.zip', + 'f2f57ebb8f688f38a98ee41fe3a87e9362aed945ec9078869407f799da322633', + ) + }, + 'OpenDialKG': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', + 'kgsf_opendialkg.zip', + '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' + ) + }, + 'Inspired': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', + 'kgsf_inspired.zip', + '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' + ) + }, + 'DuRecDial': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', + 'kgsf_durecdial.zip', + 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' + ) + } +} diff --git a/crslab/model/crs/ntrd/ntrd.py b/crslab/model/crs/ntrd/ntrd.py index 010dfd6..0f971b4 100644 --- a/crslab/model/crs/ntrd/ntrd.py +++ b/crslab/model/crs/ntrd/ntrd.py 
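The kgsf.py hunks and the new resources.py above restore KGSF's copy mask as a file shipped with the downloaded model resources (copy_mask.npy inside the kgsf_<dataset>.zip listed in resources) instead of taking it from vocab; the ntrd.py hunks below wire the same mask in for NTRD. As a rough, standalone illustration of what that mask does (not part of the patch: the vocabulary size, mask contents, and device here are made up, and only the np.load / torch.as_tensor pattern mirrors the code above):

    # Illustration only: build a boolean copy mask and apply it to copy logits,
    # the way KGSF does after loading copy_mask.npy from self.dpath.
    import numpy as np
    import torch

    vocab_size = 30004                          # made-up size for this sketch
    mask_np = np.zeros(vocab_size, dtype=bool)
    mask_np[:5000] = True                       # pretend the first 5k tokens are copyable
    # in the model: mask_np = np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool)

    copy_mask = torch.as_tensor(mask_np).to(torch.device("cpu"))

    copy_logits = torch.randn(2, 1, vocab_size)                       # (bs, seq_len, vocab_size)
    copy_logits = copy_logits * copy_mask.unsqueeze(0).unsqueeze(0)   # zero out non-copyable tokens

The masked copy logits are then summed with the generation logits over the token embedding, as in the _decode_*_with_kg methods above.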
@@ -3,10 +3,6 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com r""" NTRD @@ -32,10 +28,8 @@ from crslab.model.utils.functions import edge_to_pyg_format from crslab.model.utils.modules.attention import SelfAttentionSeq from crslab.model.utils.modules.transformer import TransformerEncoder - -from .modules import (GateLayer, TransformerDecoderKG, - TransformerDecoderSelection) - +from .modules import GateLayer, TransformerDecoderKG,TransformerDecoderSelection +from .resources import resources class NTRDModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -52,24 +46,22 @@ def __init__(self, opt, device, vocab, side_data): self.gpu = opt.get("gpu", [-1]) # vocab self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] self.token_emb_dim = opt['token_emb_dim'] self.pretrained_embedding = side_data.get('embedding', None) - self.replace_token = opt.get('replace_token', None) + self.replace_token = opt.get('replace_token',None) self.replace_token_idx = vocab[self.replace_token] - self.copy_mask = torch.as_tensor(vocab['copy_mask'], dtype=torch.bool, device=self.device) # kg self.n_word = vocab['n_word'] self.n_entity = vocab['n_entity'] - self.pad_word_idx = vocab['special_token_idx']['pad_word'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_word_idx = vocab['pad_word'] + self.pad_entity_idx = vocab['pad_entity'] entity_kg = side_data['entity_kg'] self.n_relation = entity_kg['n_relation'] entity_edges = entity_kg['edge'] - self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format( - entity_edges, 'RGCN') + self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN') self.entity_edge_idx = self.entity_edge_idx.to(device) self.entity_edge_type = self.entity_edge_type.to(device) word_edges = side_data['word_kg']['edge'] @@ -90,17 +82,18 @@ def __init__(self, opt, device, vocab, side_data): self.reduction = opt['reduction'] self.n_positions = opt['n_positions'] self.response_truncate = opt.get('response_truncate', 20) - # selector + # selector self.n_movies = opt['n_movies'] # self.n_movies_label = opt['n_movies_label'] - self.n_movies_label = 64362 # the number of entity2id + self.n_movies_label = 64362 # the number of entity2id # copy mask - self.dataset = opt['dataset'] - self.dpath = os.path.join(MODEL_PATH, "kgsf", self.dataset) + dataset = opt['dataset'] + dpath = os.path.join(MODEL_PATH, "kgsf", dataset) + resource = resources[dataset] # loss weight self.gen_loss_weight = opt['gen_loss_weight'] - super(NTRDModel, self).__init__(opt, device, self.dpath) - + super(NTRDModel, self).__init__(opt, device, dpath, resource) + def build_model(self): self._init_embeddings() self._build_kg_layer() @@ -108,39 +101,31 @@ def build_model(self): self._build_recommendation_layer() self._build_conversation_layer() self._build_movie_selector() - + def _init_embeddings(self): if self.pretrained_embedding is not None: self.token_embedding = nn.Embedding.from_pretrained( torch.as_tensor(self.pretrained_embedding, dtype=torch.float), freeze=False, padding_idx=self.pad_token_idx) else: - self.token_embedding = nn.Embedding( - self.vocab_size, self.token_emb_dim, 
self.pad_token_idx) - nn.init.normal_(self.token_embedding.weight, - mean=0, std=self.kg_emb_dim ** -0.5) - nn.init.constant_( - self.token_embedding.weight[self.pad_token_idx], 0) - - self.word_kg_embedding = nn.Embedding( - self.n_word, self.kg_emb_dim, self.pad_word_idx) - nn.init.normal_(self.word_kg_embedding.weight, - mean=0, std=self.kg_emb_dim ** -0.5) + self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) + + self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx) + nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0) logger.debug('[Finish init embeddings]') def _build_kg_layer(self): # db encoder - self.entity_encoder = RGCNConv( - self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) - self.entity_self_attn = SelfAttentionSeq( - self.kg_emb_dim, self.kg_emb_dim) + self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases) + self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) # concept encoder self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim) - self.word_self_attn = SelfAttentionSeq( - self.kg_emb_dim, self.kg_emb_dim) + self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim) # gate mechanism self.gate_layer = GateLayer(self.kg_emb_dim) @@ -161,8 +146,7 @@ def _build_recommendation_layer(self): logger.debug('[Finish build rec layer]') def _build_conversation_layer(self): - self.register_buffer('START', torch.tensor( - [self.start_token_idx], dtype=torch.long)) + self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) self.conv_encoder = TransformerEncoder( n_heads=self.n_heads, n_layers=self.n_layers, @@ -187,12 +171,14 @@ def _build_conversation_layer(self): self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim) self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size) + copy_mask = np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool) if self.replace_token: if self.replace_token_idx < len(copy_mask): copy_mask[self.replace_token_idx] = False else: - copy_mask = np.insert(copy_mask, len(copy_mask), False) - self.copy_mask = torch.as_tensor(copy_mask, device=self.device) + copy_mask = np.insert(copy_mask,len(copy_mask),False) + self.copy_mask = torch.as_tensor(copy_mask).to(self.device) + self.conv_decoder = TransformerDecoderKG( self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size, @@ -220,19 +206,15 @@ def pretrain_infomax(self, batch): if loss_mask.item() == 0: return None - entity_graph_representations = self.entity_encoder( - None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder( - self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) word_representations = word_graph_representations[words] word_padding_mask = words.eq(self.pad_word_idx) # (bs, seq_len) - word_attn_rep = self.word_self_attn( - word_representations, word_padding_mask) + word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) word_info_rep = 
self.infomax_norm(word_attn_rep) # (bs, dim) - info_predict = F.linear( - word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) + info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) loss = self.infomax_loss(info_predict, entity_labels) / loss_mask return loss @@ -252,8 +234,7 @@ def _build_movie_selector(self): embeddings_scale=self.embeddings_scale, n_positions=self.n_positions, ) - self.matching_linear = nn.Linear( - self.token_emb_dim, self.n_movies_label) + self.matching_linear = nn.Linear(self.token_emb_dim,self.n_movies_label) self.sel_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx) def recommend(self, batch, mode): @@ -264,27 +245,20 @@ def recommend(self, batch, mode): """ context_entities, context_words, entities, movie = batch - entity_graph_representations = self.entity_encoder( - None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder( - self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq( - self.pad_entity_idx) # (bs, entity_len) - word_padding_mask = context_words.eq( - self.pad_word_idx) # (bs, word_len) + entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, word_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn( - entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn( - word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) user_rep = self.gate_layer(entity_attn_rep, word_attn_rep) - rec_scores = F.linear( - user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) + rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias) # (bs, #entity) rec_loss = self.rec_loss(rec_scores, movie) @@ -295,8 +269,7 @@ def recommend(self, batch, mode): word_info_rep = self.infomax_norm(word_attn_rep) # (bs, dim) info_predict = F.linear(word_info_rep, entity_graph_representations, self.infomax_bias.bias) # (bs, #entity) - info_loss = self.infomax_loss( - info_predict, entities) / info_loss_mask + info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask return rec_loss, info_loss, rec_scores @@ -310,27 +283,21 @@ def freeze_parameters(self): def _starts(self, batch_size): """Return bsz start tokens.""" return self.START.detach().expand(batch_size, 1) - + def converse(self, batch, mode): context_tokens, context_entities, context_words, response, all_movies = batch - entity_graph_representations = self.entity_encoder( - None, self.entity_edge_idx, self.entity_edge_type) - word_graph_representations = self.word_encoder( - self.word_kg_embedding.weight, self.word_edges) + entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type) + word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges) - entity_padding_mask = context_entities.eq( - self.pad_entity_idx) # (bs, entity_len) - 
word_padding_mask = context_words.eq( - self.pad_word_idx) # (bs, seq_len) + entity_padding_mask = context_entities.eq(self.pad_entity_idx) # (bs, entity_len) + word_padding_mask = context_words.eq(self.pad_word_idx) # (bs, seq_len) entity_representations = entity_graph_representations[context_entities] word_representations = word_graph_representations[context_words] - entity_attn_rep = self.entity_self_attn( - entity_representations, entity_padding_mask) - word_attn_rep = self.word_self_attn( - word_representations, word_padding_mask) + entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask) + word_attn_rep = self.word_self_attn(word_representations, word_padding_mask) # encoder-decoder tokens_encoding = self.conv_encoder(context_tokens) @@ -340,55 +307,47 @@ def converse(self, batch, mode): conv_word_reps = self.conv_word_norm(word_representations) if mode != 'test': - logits, preds, latent = self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, - entity_padding_mask, - conv_word_reps, conv_word_emb, word_padding_mask, - response) + logits, preds,latent = self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, + entity_padding_mask, + conv_word_reps, conv_word_emb, word_padding_mask, + response) logits_ = logits.view(-1, logits.shape[-1]) response_ = response.view(-1) gen_loss = self.conv_loss(logits_, response_) - assert torch.sum(all_movies != 0, dim=(0, 1)) == torch.sum( - (response == 30000), dim=(0, 1)) # 30000 means the idx of [ITEM] - masked_for_selection_token = (response == self.replace_token_idx) + assert torch.sum(all_movies!=0, dim=(0,1)) == torch.sum((response == 30000), dim=(0,1)) #30000 means the idx of [ITEM] + masked_for_selection_token = (response == self.replace_token_idx) - matching_tensor, _ = self.movie_selector( - latent, tokens_encoding, conv_word_reps, word_padding_mask) + matching_tensor,_ = self.movie_selector(latent,tokens_encoding,conv_word_reps,word_padding_mask) matching_logits_ = self.matching_linear(matching_tensor) - matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze( - -1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) + matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) - all_movies = torch.masked_select(all_movies, (all_movies != 0)) - matching_logits = matching_logits.view(-1, - matching_logits.shape[-1]) + all_movies = torch.masked_select(all_movies,(all_movies != 0)) + matching_logits = matching_logits.view(-1,matching_logits.shape[-1]) all_movies = all_movies.view(-1) - selection_loss = self.sel_loss(matching_logits, all_movies) - return gen_loss, selection_loss, preds + selection_loss = self.sel_loss(matching_logits,all_movies) + return gen_loss,selection_loss, preds else: - logits, preds, latent = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, - entity_padding_mask, - conv_word_reps, conv_word_emb, word_padding_mask) - - preds_for_selection = preds[:, 1:] # skip the start_ind - masked_for_selection_token = ( - preds_for_selection == self.replace_token_idx) - - matching_tensor, _ = self.movie_selector( - latent, tokens_encoding, conv_word_reps, word_padding_mask) + logits, preds,latent = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb, + entity_padding_mask, + conv_word_reps, conv_word_emb, word_padding_mask) + + preds_for_selection = 
preds[:, 1:] # skip the start_ind + masked_for_selection_token = (preds_for_selection == self.replace_token_idx) + + matching_tensor,_ = self.movie_selector(latent,tokens_encoding,conv_word_reps,word_padding_mask) matching_logits_ = self.matching_linear(matching_tensor) - matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze( - -1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) + matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) if matching_logits.shape[0] is not 0: - #W1: greedy - _, matching_pred = matching_logits.max( - dim=-1) # [bsz * dynamic_movie_nums] + #W1: greedy + _, matching_pred = matching_logits.max(dim=-1) # [bsz * dynamic_movie_nums] else: matching_pred = None - return preds, matching_pred, matching_logits_ - + return preds,matching_pred,matching_logits_ + def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask, word_reps, word_emb_attn, word_mask): batch_size = token_encoding[0].shape[0] @@ -403,19 +362,16 @@ def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e latents.append(dialog_latent) db_latent = entity_emb_attn.unsqueeze(1) concept_latent = word_emb_attn.unsqueeze(1) - copy_latent = self.copy_norm( - torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) + copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) - copy_logits = self.copy_output( - copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) + copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) gen_logits = F.linear(dialog_latent, self.token_embedding.weight) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1).long() logits.append(sum_logits) inputs = torch.cat((inputs, preds), dim=1) - finished = ((inputs == self.end_token_idx).sum( - dim=-1) > 0).sum().item() == batch_size + finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size if finished: break logits = torch.cat(logits, dim=1) @@ -430,7 +386,7 @@ def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e dialog_latent, _ = self.conv_decoder(inputs, token_encoding, word_reps, word_mask, entity_reps, entity_mask) # (bs, seq_len, dim) - + entity_latent = entity_emb_attn.unsqueeze(1).expand(-1, seq_len, -1) word_latent = word_emb_attn.unsqueeze(1).expand(-1, seq_len, -1) copy_latent = self.copy_norm( @@ -438,25 +394,25 @@ def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, e copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze( 0) # (bs, seq_len, vocab_size) - # (bs, seq_len, vocab_size) - gen_logits = F.linear(dialog_latent, self.token_embedding.weight) + gen_logits = F.linear(dialog_latent, self.token_embedding.weight) # (bs, seq_len, vocab_size) sum_logits = copy_logits + gen_logits preds = sum_logits.argmax(dim=-1) return sum_logits, preds, dialog_latent + + def forward(self, batch, stage, mode): if len(self.gpu) >= 2: # forward function operates on different gpus, the weight of graph network need to be copied to other gpu - self.entity_edge_idx = self.entity_edge_idx.cuda( - torch.cuda.current_device()) - self.entity_edge_type = self.entity_edge_type.cuda( - torch.cuda.current_device()) + self.entity_edge_idx = self.entity_edge_idx.cuda(torch.cuda.current_device()) + self.entity_edge_type = 
self.entity_edge_type.cuda(torch.cuda.current_device()) self.word_edges = self.word_edges.cuda(torch.cuda.current_device()) - + self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool), + ).cuda(torch.cuda.current_device()) if stage == "pretrain": loss = self.pretrain_infomax(batch) elif stage == "rec": loss = self.recommend(batch, mode) elif stage == "conv": loss = self.converse(batch, mode) - return loss + return loss \ No newline at end of file diff --git a/crslab/model/crs/ntrd/resources.py b/crslab/model/crs/ntrd/resources.py new file mode 100644 index 0000000..c32dcd2 --- /dev/null +++ b/crslab/model/crs/ntrd/resources.py @@ -0,0 +1,62 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/13 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/15 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'ReDial': { + 'version': '0.2', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXl2bhU82O5Itp9K4Mh41mYB69BKPEvMcKwZRstfYZUB1g?download=1', + 'kgsf_redial.zip', + 'f627841644a184079acde1b0185e3a223945061c3a591f4bc0d7f62e7263f548', + ), + }, + 'TGReDial': { + 'version': '0.2', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETzJ0-QnguRKiKO_ktrTDZQBZHKom4-V5SJ9mhesfXzrWQ?download=1', + 'kgsf_tgredial.zip', + 'c9d054b653808795035f77cb783227e6e9a938e5bedca4d7f88c6dfb539be5d1', + ), + }, + 'GoRecDial': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ER5u2yMmgDNFvHuW6lKZLEkBKZkOkxMtZGK0bBQ-jvfLNw?download=1', + 'kgsf_gorecdial.zip', + 'f2f57ebb8f688f38a98ee41fe3a87e9362aed945ec9078869407f799da322633', + ) + }, + 'OpenDialKG': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQgebOKypMlPr18KJ6uGeDABtqTbMQYVYNWNR_DaAZ1Wvg?download=1', + 'kgsf_opendialkg.zip', + '89b785b23478b1d91d6ab4f34a3658e82b52dcbb73828713a9b369fa49db9e61' + ) + }, + 'Inspired': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXQGUxjGQ-ZKpzTnUYOMavABMUAxb0JwkiIMAPp5DIvsNw?download=1', + 'kgsf_inspired.zip', + '23dfc031a3c71f2a52e29fe0183e1a501771b8d431852102ba6fd83d971f928d' + ) + }, + 'DuRecDial': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ed9-qLkK0bNCk5AAvJpWU3cBC-cXks-6JlclYp08AFovyw?download=1', + 'kgsf_durecdial.zip', + 'f9a39c2382efe88d80ef14d7db8b4cbaf3a6eb92a33e018dfc9afba546ba08ef' + ) + } +} diff --git a/crslab/model/crs/redial/modules.py b/crslab/model/crs/redial/modules.py index 6f436e7..a726524 100644 --- a/crslab/model/crs/redial/modules.py +++ b/crslab/model/crs/redial/modules.py @@ -7,16 +7,12 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence + from crslab.model.utils.functions import sort_for_packed_sequence -from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence class HRNN(nn.Module): @@ -61,12 +57,10 @@ def 
get_utterance_encoding(self, context, utterance_lengths): """ batch_size, max_conv_length = context.shape[:2] utterance_lengths = utterance_lengths.reshape(-1) # (bs * conv_len) - sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence( - utterance_lengths) + sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(utterance_lengths) # reshape and reorder - sorted_utterances = context.view( - batch_size * max_conv_length, -1).index_select(0, sorted_idx) + sorted_utterances = context.view(batch_size * max_conv_length, -1).index_select(0, sorted_idx) # consider valid sequences only(length > 0) num_positive_lengths = torch.sum(utterance_lengths > 0) @@ -77,21 +71,18 @@ def get_utterance_encoding(self, context, utterance_lengths): if self.use_dropout: embedded = self.dropout(embedded) - packed_utterances = pack_padded_sequence( - embedded, sorted_lengths.cpu(), batch_first=True) + packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True) _, utterance_encoding = self.utterance_encoder(packed_utterances) # concat the hidden states of the last layer (two directions of the GRU) - utterance_encoding = torch.cat( - (utterance_encoding[-1], utterance_encoding[-2]), 1) + utterance_encoding = torch.cat((utterance_encoding[-1], utterance_encoding[-2]), 1) if self.use_dropout: utterance_encoding = self.dropout(utterance_encoding) # complete the missing sequences (of length 0) if num_positive_lengths < batch_size * max_conv_length: pad_tensor = utterance_encoding.new_full( - (batch_size * max_conv_length - num_positive_lengths, - 2 * self.utterance_encoder_hidden_size), + (batch_size * max_conv_length - num_positive_lengths, 2 * self.utterance_encoder_hidden_size), self.pad_token_idx) utterance_encoding = torch.cat((utterance_encoding, pad_tensor), 0) @@ -108,15 +99,12 @@ def forward(self, context, utterance_lengths, dialog_lengths): :param dialog_lengths: (batch_size) :return context_state: (batch_size, context_encoder_hidden_size) """ - utterance_encoding = self.get_utterance_encoding( - context, utterance_lengths) # (bs, conv_len, 2 * utt_dim) - sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence( - dialog_lengths) + utterance_encoding = self.get_utterance_encoding(context, utterance_lengths) # (bs, conv_len, 2 * utt_dim) + sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(dialog_lengths) # reorder in decreasing sequence length sorted_representations = utterance_encoding.index_select(0, sorted_idx) - packed_sequences = pack_padded_sequence( - sorted_representations, sorted_lengths.cpu(), batch_first=True) + packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True) _, context_state = self.dialog_encoder(packed_sequences) context_state = context_state.index_select(1, rev_idx) @@ -153,13 +141,10 @@ def forward(self, request, request_lengths, context_state): batch_size, max_utterance_length = request.shape # sort for pack - sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence( - request_lengths) + sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths) sorted_request = request.index_select(0, sorted_idx) - # (batch_size, max_utterance_length, embed_dim) - embedded_request = self.embedding(sorted_request) - packed_request = pack_padded_sequence( - embedded_request, sorted_lengths.cpu(), batch_first=True) + embedded_request = self.embedding(sorted_request) # (batch_size, max_utterance_length, embed_dim) + packed_request = pack_padded_sequence(embedded_request, 
sorted_lengths, batch_first=True) sorted_context_state = context_state.index_select(0, sorted_idx) h_0 = sorted_context_state.unsqueeze(0).expand( @@ -175,8 +160,7 @@ def forward(self, request, request_lengths, context_state): (batch_size, max_utterance_length - max_request_length, decoder_hidden_size), self.pad_token_idx) sorted_vocab_state = torch.cat((sorted_vocab_state, pad_tensor), dim=1) # (batch_size, max_utterance_length, decoder_hidden_size) - # (batch_size, max_utterance_length, vocab_size) - sorted_language_output = self.out(sorted_vocab_state) + sorted_language_output = self.out(sorted_vocab_state) # (batch_size, max_utterance_length, vocab_size) # expand context to each time step expanded_sorted_context_state = sorted_context_state.unsqueeze(1).expand( @@ -185,15 +169,12 @@ def forward(self, request, request_lengths, context_state): # compute switch switch_input = torch.cat((expanded_sorted_context_state, sorted_vocab_state), dim=2) # (batch_size, max_utterance_length, context_size + decoder_hidden_size) - # (batch_size, max_utterance_length, 1) - switch = self.switch(switch_input) + switch = self.switch(switch_input) # (batch_size, max_utterance_length, 1) sorted_output = torch.cat(( - F.logsigmoid(switch) + - F.log_softmax(sorted_language_output, dim=2), + F.logsigmoid(switch) + F.log_softmax(sorted_language_output, dim=2), F.logsigmoid(-switch) # for item ), dim=2) - # (batch_size, max_utterance_length, vocab_size + 1) - output = sorted_output.index_select(0, rev_idx) + output = sorted_output.index_select(0, rev_idx) # (batch_size, max_utterance_length, vocab_size + 1) return output diff --git a/crslab/model/crs/redial/redial_conv.py b/crslab/model/crs/redial/redial_conv.py index b7a9529..8e062fb 100644 --- a/crslab/model/crs/redial/redial_conv.py +++ b/crslab/model/crs/redial/redial_conv.py @@ -19,9 +19,9 @@ """ import torch -from crslab.model.base import BaseModel from torch import nn +from crslab.model.base import BaseModel from .modules import HRNN, SwitchingDecoder @@ -59,10 +59,10 @@ def __init__(self, opt, device, vocab, side_data): """ # dataset self.vocab_size = vocab['vocab_size'] - self.pad_token_idx = vocab['special_token_idx']['pad'] - self.start_token_idx = vocab['special_token_idx']['start'] - self.end_token_idx = vocab['special_token_idx']['end'] - self.unk_token_idx = vocab['special_token_idx']['unk'] + self.pad_token_idx = vocab['pad'] + self.start_token_idx = vocab['start'] + self.end_token_idx = vocab['end'] + self.unk_token_idx = vocab['unk'] self.pretrained_embedding = side_data.get('embedding', None) self.embedding_dim = opt.get('embedding_dim', None) if opt.get('embedding', None) and self.embedding_dim is None: diff --git a/crslab/model/crs/redial/redial_rec.py b/crslab/model/crs/redial/redial_rec.py index 3a0ad46..4bbc289 100644 --- a/crslab/model/crs/redial/redial_rec.py +++ b/crslab/model/crs/redial/redial_rec.py @@ -19,6 +19,7 @@ """ import torch.nn as nn + from crslab.model.base import BaseModel @@ -44,7 +45,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.n_entity = vocab['n_entity'] self.layer_sizes = opt['autorec_layer_sizes'] - self.pad_entity_idx = vocab['special_token_idx']['pad_entity'] + self.pad_entity_idx = vocab['pad_entity'] super(ReDialRecModel, self).__init__(opt, device) diff --git a/crslab/model/crs/tgredial/tg_conv.py b/crslab/model/crs/tgredial/tg_conv.py index 0fc85f4..9e505d5 100644 --- a/crslab/model/crs/tgredial/tg_conv.py +++ b/crslab/model/crs/tgredial/tg_conv.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, 
Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" TGReDial_Conv ============= @@ -23,16 +18,21 @@ """ +import os import torch -from crslab.model.base import BaseModel from torch.nn import CrossEntropyLoss from transformers import GPT2LMHeadModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources + class TGConvModel(BaseModel): """ - + Attributes: context_truncate: A integer indicating the length of dialogue context. response_truncate: A integer indicating the length of dialogue response. @@ -52,10 +52,12 @@ def __init__(self, opt, device, vocab, side_data): """ self.context_truncate = opt['context_truncate'] self.response_truncate = opt['response_truncate'] - self.pad_id = vocab['special_token_idx']['pad'] + self.pad_id = vocab['pad'] - self.dpath = opt['conv_pretrained_path'] - super(TGConvModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['gpt2'][language] + dpath = os.path.join(PRETRAIN_PATH, 'gpt2', language) + super(TGConvModel, self).__init__(opt, device, dpath, resource) def build_model(self): """build model""" @@ -94,8 +96,7 @@ def generate(self, context): context = context[..., -self.response_truncate + 1:] for i in range(self.response_truncate - 1): - outputs = self.model( - context, former_hidden_state) # (bs, c_t, v_s), + outputs = self.model(context, former_hidden_state) # (bs, c_t, v_s), last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values next_token_logits = last_hidden_state[:, -1, :] # (bs, v_s) @@ -128,10 +129,8 @@ def generate_bs(self, context, beam=4): next_token_logits = last_hidden_state[:, -1, :] next_token_probs = torch.nn.functional.softmax(next_token_logits) topk = torch.topk(next_token_probs, beam, dim=-1) - probs = topk.values.reshape( - [batch_size, -1, beam]) # (bs, candidate, beam) - preds = topk.indices.reshape( - [batch_size, -1, beam]) # (bs, candidate, beam) + probs = topk.values.reshape([batch_size, -1, beam]) # (bs, candidate, beam) + preds = topk.indices.reshape([batch_size, -1, beam]) # (bs, candidate, beam) for j in range(batch_size): all_candidates = [] @@ -143,8 +142,7 @@ def generate_bs(self, context, beam=4): seq_tmp.append(preds[j][n][k]) candidate = [seq_tmp, prob * probs[j][n][k]] all_candidates.append(candidate) - ordered = sorted( - all_candidates, key=lambda tup: tup[1], reverse=True) + ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True) sequences[j] = ordered[:beam] res = [] diff --git a/crslab/model/crs/tgredial/tg_policy.py b/crslab/model/crs/tgredial/tg_policy.py index 37c21e8..708b7f9 100644 --- a/crslab/model/crs/tgredial/tg_policy.py +++ b/crslab/model/crs/tgredial/tg_policy.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" TGReDial_Policy =============== @@ -23,12 +18,17 @@ """ +import os import torch -from crslab.model.base import BaseModel from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import 
resources + class TGPolicyModel(BaseModel): def __init__(self, opt, device, vocab, side_data): @@ -39,13 +39,15 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. - + """ self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - self.dpath = opt['policy_pretrained_path'] - super(TGPolicyModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['bert'][language] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + super(TGPolicyModel, self).__init__(opt, device, dpath, resource) def build_model(self, *args, **kwargs): """build model""" @@ -72,13 +74,11 @@ def forward(self, batch, mode): tp_mask).pooler_output # (bs, hidden_size) bs = user_profile.shape[0] // self.n_sent - profile_rep = self.profile_bert( - user_profile, profile_mask).pooler_output # (bs, word_num, hidden) + profile_rep = self.profile_bert(user_profile, profile_mask).pooler_output # (bs, word_num, hidden) profile_rep = profile_rep.view(bs, self.n_sent, -1) profile_rep = torch.mean(profile_rep, dim=1) # (bs, hidden) - # [bs, hidden_size*3] - state_rep = torch.cat((context_rep, topic_rep, profile_rep), dim=1) + state_rep = torch.cat((context_rep, topic_rep, profile_rep), dim=1) # [bs, hidden_size*3] topic_scores = self.state2topic_id(state_rep) topic_loss = self.loss(topic_scores, y) diff --git a/crslab/model/crs/tgredial/tg_rec.py b/crslab/model/crs/tgredial/tg_rec.py index 36ce1d4..a02ac5b 100644 --- a/crslab/model/crs/tgredial/tg_rec.py +++ b/crslab/model/crs/tgredial/tg_rec.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @Email : wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" TGReDial_Rec ============ @@ -23,18 +18,23 @@ """ +import os import torch -from crslab.model.base import BaseModel -from crslab.model.recommendation.sasrec.modules import SASRec from loguru import logger from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources +from crslab.model.recommendation.sasrec.modules import SASRec + class TGRecModel(BaseModel): """ - + Attributes: hidden_dropout_prob: A float indicating the dropout rate to dropout hidden state in SASRec. initializer_range: A float indicating the range of parameters initization in SASRec. 
@@ -68,8 +68,10 @@ def __init__(self, opt, device, vocab, side_data): self.hidden_act = opt['hidden_act'] self.num_hidden_layers = opt['num_hidden_layers'] - self.dpath = opt['rec_pretrained_path'] - super(TGRecModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['bert'][language] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + super(TGRecModel, self).__init__(opt, device, dpath, resource) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters @@ -94,8 +96,7 @@ def forward(self, batch, mode): bert_embed = self.bert(context, attention_mask=mask).pooler_output - # bs, max_len, hidden_size2 - sequence_output = self.SASREC(input_ids, input_mask) + sequence_output = self.SASREC(input_ids, input_mask) # bs, max_len, hidden_size2 sas_embed = sequence_output[:, -1, :] # bs, hidden_size2 embed = torch.cat((sas_embed, bert_embed), dim=1) diff --git a/crslab/model/policy/conv_bert/conv_bert.py b/crslab/model/policy/conv_bert/conv_bert.py index 374014f..76101cc 100644 --- a/crslab/model/policy/conv_bert/conv_bert.py +++ b/crslab/model/policy/conv_bert/conv_bert.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" Conv_BERT ========= @@ -23,11 +18,16 @@ """ +import os -from crslab.model.base import BaseModel from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from ...pretrained_models import resources + class ConvBERTModel(BaseModel): """ @@ -45,11 +45,13 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. 
- + """ self.topic_class_num = vocab['n_topic'] - self.dpath = opt['policy_pretrained_path'] - super(ConvBERTModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['bert'][language] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + super(ConvBERTModel, self).__init__(opt, device, dpath, resource) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/pmi/pmi.py b/crslab/model/policy/pmi/pmi.py index 61e198d..a406cb3 100644 --- a/crslab/model/policy/pmi/pmi.py +++ b/crslab/model/policy/pmi/pmi.py @@ -15,6 +15,7 @@ from collections import defaultdict import torch + from crslab.model.base import BaseModel @@ -38,7 +39,7 @@ def __init__(self, opt, device, vocab, side_data): """ self.topic_class_num = vocab['n_topic'] - self.pad_topic = vocab['special_token_idx']['pad_topic'] + self.pad_topic = vocab['pad_topic'] super(PMIModel, self).__init__(opt, device) def build_model(self, *args, **kwargs): diff --git a/crslab/model/policy/profile_bert/profile_bert.py b/crslab/model/policy/profile_bert/profile_bert.py index acce39a..65b400f 100644 --- a/crslab/model/policy/profile_bert/profile_bert.py +++ b/crslab/model/policy/profile_bert/profile_bert.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" Profile_BERT ============ @@ -23,12 +18,17 @@ """ +import os import torch -from crslab.model.base import BaseModel from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources + class ProfileBERTModel(BaseModel): """ @@ -47,13 +47,15 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. 
- + """ self.topic_class_num = vocab['n_topic'] self.n_sent = opt.get('n_sent', 10) - self.dpath = opt['policy_pretrained_path'] - super(ProfileBERTModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['bert'][language] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + super(ProfileBERTModel, self).__init__(opt, device, dpath, resource) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/policy/topic_bert/topic_bert.py b/crslab/model/policy/topic_bert/topic_bert.py index cb47bc0..400eaeb 100644 --- a/crslab/model/policy/topic_bert/topic_bert.py +++ b/crslab/model/policy/topic_bert/topic_bert.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" Topic_BERT ========== @@ -23,11 +18,16 @@ """ +import os -from crslab.model.base import BaseModel from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources + class TopicBERTModel(BaseModel): """ @@ -46,12 +46,14 @@ def __init__(self, opt, device, vocab, side_data): device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. - + """ self.topic_class_num = vocab['n_topic'] - self.dpath = opt['policy_pretrained_path'] - super(TopicBERTModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + resource = resources['bert'][language] + super(TopicBERTModel, self).__init__(opt, device, dpath, resource) def build_model(self, *args, **kwargs): """build model""" diff --git a/crslab/model/pretrained_models.py b/crslab/model/pretrained_models.py new file mode 100644 index 0000000..e254bdf --- /dev/null +++ b/crslab/model/pretrained_models.py @@ -0,0 +1,64 @@ +# -*- encoding: utf-8 -*- +# @Time : 2021/1/6 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2021/1/7 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +"""Download links of pretrain models. + +Now we provide the following models: + +- `BERT`_: zh, en +- `GPT2`_: zh, en + +.. _BERT: + https://www.aclweb.org/anthology/N19-1423/ +.. 
_GPT2: + https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf + +""" + +resources = { + 'bert': { + 'zh': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', + 'bert_zh.zip', + 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' + ) + }, + 'en': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', + 'bert_en.zip', + '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' + ) + }, + }, + 'gpt2': { + 'zh': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', + 'gpt2_zh.zip', + '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' + ) + }, + 'en': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', + 'gpt2_en.zip', + '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' + ) + } + } +} diff --git a/crslab/model/recommendation/bert/bert.py b/crslab/model/recommendation/bert/bert.py index 24b6670..cb78a7b 100644 --- a/crslab/model/recommendation/bert/bert.py +++ b/crslab/model/recommendation/bert/bert.py @@ -7,11 +7,6 @@ # @Author : Xiaolei Wang, Yuanhang Zhou # @email : wxl1999@foxmail.com, sdzyh002@gmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - r""" BERT ==== @@ -23,12 +18,17 @@ """ +import os -from crslab.model.base import BaseModel from loguru import logger from torch import nn from transformers import BertModel +from crslab.config import PRETRAIN_PATH +from crslab.data import dataset_language_map +from crslab.model.base import BaseModel +from crslab.model.pretrained_models import resources + class BERTModel(BaseModel): """ @@ -50,8 +50,10 @@ def __init__(self, opt, device, vocab, side_data): """ self.item_size = vocab['n_entity'] - self.dpath = opt['rec_pretrained_path'] - super(BERTModel, self).__init__(opt, device, self.dpath) + language = dataset_language_map[opt['dataset']] + resource = resources['bert'][language] + dpath = os.path.join(PRETRAIN_PATH, "bert", language) + super(BERTModel, self).__init__(opt, device, dpath, resource) def build_model(self): # build BERT layer, give the architecture, load pretrained parameters diff --git a/crslab/quick_start/__init__.py b/crslab/quick_start/__init__.py index 12b84a5..a03ec24 100644 --- a/crslab/quick_start/__init__.py +++ b/crslab/quick_start/__init__.py @@ -1 +1 @@ -from .quick_start import run_crslab +from .quick_start import quick_start \ No newline at end of file diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index b37f63b..fa31de2 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -8,11 +8,12 @@ # @Author : Xiaolei Wang # @email : wxl1999@foxmail.com -from crslab.data import get_dataloader, get_dataset, get_tokenizer +from crslab.config import Config +from crslab.data import get_dataset, get_dataloader from crslab.system import get_system -def run_crslab(config, save_data=False, restore_data=False, save_system=False, 
restore_system=False, +def quick_start(config, mode, save_data=False, restore_data=False, save_system=False, restore_system=False, interact=False, debug=False, tensorboard=False): """A fast running api, which includes the complete process of training and testing models on specified datasets. @@ -33,16 +34,12 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r """ # dataset & dataloader if isinstance(config['tokenize'], str): - CRS_tokenizer = get_tokenizer(config['tokenize'], path=None) - CRS_dataset = get_dataset( - config, config['tokenize'], CRS_tokenizer, restore_data, save_data) + CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data) side_data = CRS_dataset.side_data vocab = CRS_dataset.vocab - train_dataloader = get_dataloader( - config, CRS_dataset.train_data, vocab) - valid_dataloader = get_dataloader( - config, CRS_dataset.valid_data, vocab) + train_dataloader = get_dataloader(config, CRS_dataset.train_data, vocab) + valid_dataloader = get_dataloader(config, CRS_dataset.valid_data, vocab) test_dataloader = get_dataloader(config, CRS_dataset.test_data, vocab) else: tokenized_dataset = {} @@ -56,13 +53,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r if tokenize in tokenized_dataset: dataset = tokenized_dataset[tokenize] else: - task_tokenize_path = str(task) + '_tokenize_path' - tokenize_path = None - if task_tokenize_path in config: - tokenize_path = config[task_tokenize_path] - CRS_tokenizer = get_tokenizer(tokenize, tokenize_path) - dataset = get_dataset( - config, tokenize, CRS_tokenizer, restore_data, save_data) + dataset = get_dataset(config, tokenize, restore_data, save_data) tokenized_dataset[tokenize] = dataset train_data = dataset.train_data valid_data = dataset.valid_data @@ -70,18 +61,15 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r side_data[task] = dataset.side_data vocab[task] = dataset.vocab - train_dataloader[task] = get_dataloader( - config, train_data, vocab[task]) - valid_dataloader[task] = get_dataloader( - config, valid_data, vocab[task]) - test_dataloader[task] = get_dataloader( - config, test_data, vocab[task]) + train_dataloader[task] = get_dataloader(config, train_data, vocab[task]) + valid_dataloader[task] = get_dataloader(config, valid_data, vocab[task]) + test_dataloader[task] = get_dataloader(config, test_data, vocab[task]) # system CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) if interact: CRS.interact() else: - CRS.fit() + CRS.fit(mode) if save_system: CRS.save_model() diff --git a/crslab/system/__init__.py b/crslab/system/__init__.py index dbbd699..5c66c85 100644 --- a/crslab/system/__init__.py +++ b/crslab/system/__init__.py @@ -21,6 +21,7 @@ from .redial import ReDialSystem from .ntrd import NTRDSystem from .tgredial import TGReDialSystem +from .chatgpt import ChatGPTSystem system_register_table = { 'ReDialRec_ReDialConv': ReDialSystem, @@ -41,7 +42,8 @@ 'GRU4REC': TGReDialSystem, 'Popularity': TGReDialSystem, 'TextCNN': TGReDialSystem, - 'NTRD': NTRDSystem + 'NTRD': NTRDSystem, + 'ChatGPT': ChatGPTSystem } diff --git a/crslab/system/chatgpt.py b/crslab/system/chatgpt.py new file mode 100644 index 0000000..d9a06ef --- /dev/null +++ b/crslab/system/chatgpt.py @@ -0,0 +1,199 @@ +# @Time : 2023/6/14 +# @Author : Xinyu Tang +# @Email : txy20010310@163.com + +import os +import json +import random + +import openai +from tqdm 
import tqdm +from loguru import logger +from tenacity import _utils, Retrying, retry_if_not_exception_type +from tenacity.stop import stop_base +from tenacity.wait import wait_base + +from crslab.config import DATASET_PATH, SAVE_PATH +from crslab.system.base import BaseSystem +from crslab.data.dataset import BaseDataset +from crslab.data import dataset_register_table +from crslab.model import Model_register_table +from crslab.evaluator.chat import my_wait_exponential, my_stop_after_attempt, Chat +from crslab.evaluator.ask import Ask +from crslab.evaluator.rec import RecEvaluator +from crslab.system.utils.functions import get_exist_item_set + +def get_exist_dialog_set(save_dir): + exist_id_set = set() + for file in os.listdir(save_dir): + file_id = os.path.splitext(file)[0] + exist_id_set.add(file_id) + return exist_id_set + +class ChatGPTSystem(BaseSystem): + + def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False, interact=False, debug=False, tensorboard=False): + super(ChatGPTSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) + + openai.api_key = opt['api_key'] + self.dataset = opt['dataset'] + self.dpath = os.path.join(DATASET_PATH, opt['dataset']) + self.embed_save_dir = os.path.join(SAVE_PATH, self.dataset, 'embed') + os.makedirs(self.embed_save_dir, exist_ok=True) + self.chat_save_dir = os.path.join(SAVE_PATH, self.dataset, 'chat') + os.makedirs(self.chat_save_dir, exist_ok=True) + self.ask_save_dir = os.path.join(SAVE_PATH, self.dataset, 'ask') + os.makedirs(self.ask_save_dir, exist_ok=True) + with open(os.path.join(self.dpath, 'id2info.json'), 'r', encoding='utf-8') as f: + self.id2info = json.load(f) + self.test_dataloader = test_dataloader + self.rec_optim_opt = opt['rec'] + self.conv_optim_opt = opt['conv'] + self.cache_item_opt = opt['cache_item'] + self.rec_batch_size = self.rec_optim_opt['batch_size'] + self.conv_batch_size = self.conv_optim_opt['batch_size'] + self.cache_item_batch_size= self.cache_item_opt['batch_size'] + self.api_key = opt['api_key'] + self.turn_num = opt['turn_num'] + self.dataset_class = dataset_register_table[self.dataset](opt, opt['tokenize']) + crs_model_name = opt['model_name'] + self.crs_model = Model_register_table[crs_model_name](opt, opt['device']) + + def my_before_sleep(self, retry_state): + logger.debug(f'Retrying: attempt {retry_state.attempt_number} ended with: {retry_state.outcome}, spend {retry_state.seconds_since_start} in total') + + def annotate(self, item_text_list): + request_timeout = 6 + for attempt in Retrying( + reraise=True, retry=retry_if_not_exception_type((openai.error.InvalidRequestError, openai.error.AuthenticationError)), + wait=my_wait_exponential(min=1, max=60), stop=(my_stop_after_attempt(8)), before_sleep=self.my_before_sleep + ): + with attempt: + response = openai.Embedding.create( + model='text-embedding-ada-002', input=item_text_list, request_timeout=request_timeout + ) + request_timeout = min(30, request_timeout * 2) + + return response + + def cache_item(self): + attr_list = self.dataset_class.get_attr_list() + id2text = {} + for item_id, info_dict in self.id2info.items(): + attr_str_list = [f'Title: {info_dict["name"]}'] + for attr in attr_list: + if attr not in info_dict: + continue + if isinstance(info_dict[attr], list): + value_str = ', '.join(info_dict[attr]) + else: + value_str = info_dict[attr] + attr_str_list.append(f'{attr.capitalize()}: {value_str}') 
+ item_text = '; '.join(attr_str_list) + id2text[item_id] = item_text + + item_ids = set(self.id2info.keys()) - get_exist_item_set(self.embed_save_dir) + while len(item_ids) > 0: + logger.info(len(item_ids)) + batch_item_ids = random.sample(tuple(item_ids), min(self.cache_item_batch_size, len(item_ids))) + batch_texts = [id2text[item_id] for item_id in batch_item_ids] + + batch_embeds = self.annotate(batch_texts)['data'] + for embed in batch_embeds: + item_id = batch_item_ids[embed['index']] + with open(f'{self.embed_save_dir}/{item_id}.json', 'w', encoding='utf-8') as f: + json.dump(embed['embedding'], f, ensure_ascii=False) + + item_ids -= get_exist_item_set(self.embed_save_dir) + + + def iEvaLM_chat(self): + logger.info('[Test]') + iEvaLM_CHAT = Chat(self.turn_num, self.crs_model, self.dataset) + dataid2data = {} + for i, batch in enumerate(self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False)): + dialog_ids = batch['dialog_id'] + for dialog_id in dialog_ids: + dataid2data[dialog_id] = batch + + dialog_id_set = set(dataid2data.keys()) - get_exist_dialog_set(self.chat_save_dir) + while len(dialog_id_set) > 0: + logger.info(len(dialog_id_set)) + dialog_id = random.choice(tuple(dialog_id_set)) + batch_data = dataid2data[dialog_id] + returned_data = iEvaLM_CHAT.chat(batch_data, self.turn_num) + + with open(f'{self.chat_save_dir}/{dialog_id}.json', 'w', encoding='utf-8') as f: + json.dump(returned_data, f, ensure_ascii=False, indent=2) + + dialog_id_set -= get_exist_dialog_set(self.chat_save_dir) + + def iEvaLM_ask(self): + logger.info('[Test]') + ask_instruction_dict = self.dataset_class.get_ask_instruction() + iEvaLM_ASK = Ask(self.turn_num, self.crs_model, self.dataset, ask_instruction_dict) + dataid2data = {} + for i, batch in enumerate(self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False)): + dialog_ids = batch['dialog_id'] + for dialog_id in dialog_ids: + dataid2data[dialog_id] = batch + + dialog_id_set = set(dataid2data.keys()) - get_exist_dialog_set(self.ask_save_dir) + while len(dialog_id_set) > 0: + logger.info(len(dialog_id_set)) + dialog_id = random.choice(tuple(dialog_id_set)) + batch_data = dataid2data[dialog_id] + returned_data = iEvaLM_ASK.ask(batch_data, self.turn_num) + + with open(f'{self.ask_save_dir}/{dialog_id}.json', 'w', encoding='utf-8') as f: + json.dump(returned_data, f, ensure_ascii=False, indent=2) + + dialog_id_set -= get_exist_dialog_set(self.ask_save_dir) + + def step(self, batch, stage, mode): + return super().step(batch, stage, mode) + + def evaluate_iEvaLM(self, mode): + metric = RecEvaluator(k_list=[1, 10, 25, 50]) + persuatiness_list = [] + if mode == 'chat': + save_path = self.chat_save_dir + elif mode == 'ask': + save_path = self.ask_save_dir + if os.path.exists(save_path) and len(os.listdir(save_path)) > 0: + path_list = os.listdir(save_path) + print(save_path, len(path_list)) + + for path in tqdm(path_list): + with open(f"{save_path}/{path}", 'r', encoding="utf-8") as f: + context_list = json.load(f) + if mode == 'chat': + persuasiveness_score = context_list[-1]['persuasiveness_score'] + persuatiness_list.append(float(persuasiveness_score)) + # TODO: modify chatgpt evaluator + for context in context_list[::-1]: + if 'rec_items' in context: + rec_labels = context['rec_items'] + rec_items = context['pred_items'] + for rec_label in rec_labels: + metric.rec_evaluate(rec_items, rec_label) + break + + metric.report(mode='test') + if mode == 'chat': + avg_persuatiness_score = sum(persuatiness_list)
/ len(persuatiness_list) + logger.info(avg_persuatiness_score) + + def fit(self, mode): + self.cache_item() + if mode == 'chat': + self.iEvaLM_chat() + elif mode == 'ask': + self.iEvaLM_ask() + else: + raise ValueError(f'Invalid mode: {mode}') + self.evaluate_iEvaLM(mode) + + def interact(self): + pass + \ No newline at end of file diff --git a/crslab/system/inspired.py b/crslab/system/inspired.py index e827ee3..b9219d1 100644 --- a/crslab/system/inspired.py +++ b/crslab/system/inspired.py @@ -2,14 +2,15 @@ # @Author : Beichen Zhang # @Email : zhangbeichen724@gmail.com +import torch +from loguru import logger from math import floor -import torch +from crslab.data import dataset_language_map from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt -from loguru import logger class InspiredSystem(BaseSystem): @@ -38,7 +39,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc if hasattr(self, 'conv_model'): self.ind2tok = vocab['conv']['ind2tok'] - self.end_token_idx = vocab['conv']['special_token_idx']['end'] + self.end_token_idx = vocab['conv']['end'] if hasattr(self, 'rec_model'): self.item_ids = side_data['rec']['item_entity_ids'] self.id2entity = vocab['rec']['id2entity'] @@ -53,14 +54,15 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.conv_epoch = self.conv_optim_opt['epoch'] self.conv_batch_size = self.conv_optim_opt['batch_size'] if self.conv_optim_opt.get('lr_scheduler', None) and 'Transformers' in self.conv_optim_opt['lr_scheduler'][ - 'name']: + 'name']: batch_num = 0 for _ in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False): batch_num += 1 - conv_training_steps = self.conv_epoch * \ - floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) + conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps + self.language = dataset_language_map[self.opt['dataset']] + def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() rec_predict = rec_predict[:, self.item_ids] @@ -157,8 +159,7 @@ def train_recommender(self): self.step(batch, stage='rec', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + \ - self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test diff --git a/crslab/system/kbrd.py b/crslab/system/kbrd.py index 5579da8..085eda4 100644 --- a/crslab/system/kbrd.py +++ b/crslab/system/kbrd.py @@ -11,11 +11,12 @@ import os import torch +from loguru import logger + from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt -from loguru import logger class KBRDSystem(BaseSystem): @@ -42,7 +43,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc restore_system, interact, debug, tensorboard) self.ind2tok = vocab['ind2tok'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.end_token_idx = vocab['end'] self.item_ids = side_data['item_entity_ids'] self.rec_optim_opt = opt['rec'] diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py 
index 047ff4f..7f7b2a6 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -7,19 +7,15 @@ # @Author : Kun Zhou, Xiaolei Wang # @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - import os import torch +from loguru import logger + from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt -from loguru import logger class KGSFSystem(BaseSystem): @@ -46,7 +42,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc restore_system, interact, debug, tensorboard) self.ind2tok = vocab['ind2tok'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.end_token_idx = vocab['end'] self.item_ids = side_data['item_entity_ids'] self.pretrain_optim_opt = self.opt['pretrain'] @@ -84,11 +80,9 @@ def step(self, batch, stage, mode): if info_loss is not None: self.backward(info_loss.sum()) info_loss = info_loss.sum().item() - self.evaluator.optim_metrics.add( - "info_loss", AverageMetric(info_loss)) + self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss)) elif stage == 'rec': - rec_loss, info_loss, rec_predict = self.model.forward( - batch, stage, mode) + rec_loss, info_loss, rec_predict = self.model.forward(batch, stage, mode) if info_loss: loss = rec_loss + 0.025 * info_loss else: @@ -98,12 +92,10 @@ def step(self, batch, stage, mode): else: self.rec_evaluate(rec_predict, batch[-1]) rec_loss = rec_loss.sum().item() - self.evaluator.optim_metrics.add( - "rec_loss", AverageMetric(rec_loss)) + self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss)) if info_loss: info_loss = info_loss.sum().item() - self.evaluator.optim_metrics.add( - "info_loss", AverageMetric(info_loss)) + self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss)) elif stage == "conv": if mode != "test": gen_loss, pred = self.model.forward(batch, stage, mode) @@ -112,8 +104,7 @@ def step(self, batch, stage, mode): else: self.conv_evaluate(pred, batch[-1]) gen_loss = gen_loss.sum().item() - self.evaluator.optim_metrics.add( - "gen_loss", AverageMetric(gen_loss)) + self.evaluator.optim_metrics.add("gen_loss", AverageMetric(gen_loss)) self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss)) else: pred = self.model.forward(batch, stage, mode) @@ -149,8 +140,7 @@ def train_recommender(self): self.step(batch, stage='rec', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + \ - self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test @@ -164,8 +154,6 @@ def train_recommender(self): def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() - elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: - self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) diff --git a/crslab/system/ntrd.py b/crslab/system/ntrd.py index ffbc1dc..ab1e923 100644 --- a/crslab/system/ntrd.py +++ b/crslab/system/ntrd.py @@ -3,15 +3,16 @@ # @Email : oran_official@outlook.com import os +from crslab.evaluator.metrics import gen +from numpy.core.numeric import NaN import torch -from crslab.evaluator.metrics import gen +from loguru import logger + from 
crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem -from crslab.system.utils.functions import ind2slot, ind2txt_with_slots -from loguru import logger -from numpy.core.numeric import NaN +from crslab.system.utils.functions import ind2slot,ind2txt_with_slots class NTRDSystem(BaseSystem): @@ -24,7 +25,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.ind2tok = vocab['ind2tok'] self.ind2movie = vocab['id2entity'] - self.end_token_idx = vocab['special_token_idx']['end'] + self.end_token_idx = vocab['end'] self.item_ids = side_data['item_entity_ids'] self.pretrain_optim_opt = self.opt['pretrain'] diff --git a/crslab/system/redial.py b/crslab/system/redial.py index 2ce257e..0276e90 100644 --- a/crslab/system/redial.py +++ b/crslab/system/redial.py @@ -8,11 +8,13 @@ # @email : wxl1999@foxmail.com import torch +from loguru import logger + +from crslab.data import dataset_language_map from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt -from loguru import logger class ReDialSystem(BaseSystem): @@ -38,7 +40,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc super(ReDialSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) self.ind2tok = vocab['conv']['ind2tok'] - self.end_token_idx = vocab['conv']['special_token_idx']['end'] + self.end_token_idx = vocab['conv']['end'] self.item_ids = side_data['rec']['item_entity_ids'] self.id2entity = vocab['rec']['id2entity'] @@ -49,6 +51,8 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.rec_batch_size = self.rec_optim_opt['batch_size'] self.conv_batch_size = self.conv_optim_opt['batch_size'] + self.language = dataset_language_map[self.opt['dataset']] + def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() rec_predict = rec_predict[:, self.item_ids] @@ -83,8 +87,7 @@ def step(self, batch, stage, mode): else: self.rec_evaluate(rec_scores, batch['item']) rec_loss = rec_loss.item() - self.evaluator.optim_metrics.add( - "rec_loss", AverageMetric(rec_loss)) + self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss)) else: gen_loss, preds = self.conv_model.forward(batch, mode=mode) gen_loss = gen_loss.sum() @@ -93,8 +96,7 @@ def step(self, batch, stage, mode): else: self.conv_evaluate(preds, batch['response']) gen_loss = gen_loss.item() - self.evaluator.optim_metrics.add( - 'gen_loss', AverageMetric(gen_loss)) + self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss)) self.evaluator.gen_metrics.add('ppl', PPLMetric(gen_loss)) def train_recommender(self): @@ -106,16 +108,14 @@ def train_recommender(self): logger.info('[Train]') for batch in self.train_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size): self.step(batch, stage='rec', mode='train') - # report train loss - self.evaluator.report(epoch=epoch, mode='train') + self.evaluator.report(epoch=epoch, mode='train') # report train loss # val logger.info('[Valid]') with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.valid_dataloader['rec'].get_rec_data(batch_size=self.rec_batch_size, shuffle=False): self.step(batch, stage='rec', mode='valid') - # report valid loss - self.evaluator.report(epoch=epoch, 
mode='valid') + self.evaluator.report(epoch=epoch, mode='valid') # report valid loss # early stop metric = self.evaluator.optim_metrics['rec_loss'] if self.early_stop(metric): diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 9d5d390..3aaaa7b 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -7,22 +7,18 @@ # @Author : Xiaolei Wang # @Email : wxl1999@foxmail.com -# UPDATE: -# @Time : 2022/9/28 -# @Author : Xinyu Tang -# @Email : txy20010310@163.com - import os -from math import floor import torch +from loguru import logger +from math import floor + from crslab.config import PRETRAIN_PATH -from crslab.data import dataset_language_map, get_dataloader +from crslab.data import get_dataloader, dataset_language_map from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem from crslab.system.utils.functions import ind2txt -from loguru import logger class TGReDialSystem(BaseSystem): @@ -51,7 +47,7 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc if hasattr(self, 'conv_model'): self.ind2tok = vocab['conv']['ind2tok'] - self.end_token_idx = vocab['conv']['special_token_idx']['end'] + self.end_token_idx = vocab['conv']['end'] if hasattr(self, 'rec_model'): self.item_ids = side_data['rec']['item_entity_ids'] self.id2entity = vocab['rec']['id2entity'] @@ -66,12 +62,11 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.conv_epoch = self.conv_optim_opt['epoch'] self.conv_batch_size = self.conv_optim_opt['batch_size'] if self.conv_optim_opt.get('lr_scheduler', None) and 'Transformers' in self.conv_optim_opt['lr_scheduler'][ - 'name']: + 'name']: batch_num = 0 for _ in self.train_dataloader['conv'].get_conv_data(batch_size=self.conv_batch_size, shuffle=False): batch_num += 1 - conv_training_steps = self.conv_epoch * \ - floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) + conv_training_steps = self.conv_epoch * floor(batch_num / self.conv_optim_opt.get('update_freq', 1)) self.conv_optim_opt['lr_scheduler']['training_steps'] = conv_training_steps if hasattr(self, 'policy_model'): @@ -126,8 +121,7 @@ def step(self, batch, stage, mode): else: self.policy_model.eval() - policy_loss, policy_predict = self.policy_model.forward( - batch, mode) + policy_loss, policy_predict = self.policy_model.forward(batch, mode) if mode == "train" and policy_loss is not None: policy_loss = policy_loss.sum() self.backward(policy_loss) @@ -175,11 +169,8 @@ def train_recommender(self): if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) - elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: - bert_param = list(self.rec_model.bert.named_parameters()) else: - bert_param = list( - self.rec_model.module.bert.named_parameters()) + bert_param = list(self.rec_model.module.bert.named_parameters()) bert_param_name = ['bert.' 
+ n for n, p in bert_param] else: bert_param = [] @@ -207,8 +198,7 @@ def train_recommender(self): self.step(batch, stage='rec', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + \ - self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test @@ -282,8 +272,7 @@ def train_policy(self): self.step(batch, stage='policy', mode='val') self.evaluator.report(epoch=epoch, mode='val') # early stop - metric = self.evaluator.rec_metrics['hit@1'] + \ - self.evaluator.rec_metrics['hit@50'] + metric = self.evaluator.rec_metrics['hit@1'] + self.evaluator.rec_metrics['hit@50'] if self.early_stop(metric): break # test @@ -318,8 +307,7 @@ def interact(self): for r in rank.tolist(): item_ids.append(self.item_ids[r]) first_item_id = item_ids[:1] - self.update_context( - 'rec', entity_ids=first_item_id, item_ids=first_item_id) + self.update_context('rec', entity_ids=first_item_id, item_ids=first_item_id) print(f"[Recommend]:") for item_id in item_ids: @@ -328,22 +316,18 @@ def interact(self): # conv if hasattr(self, 'conv_model'): conv_input = self.process_input(input_text, 'conv') - preds = self.conv_model.forward( - conv_input, 'infer').tolist()[0] + preds = self.conv_model.forward(conv_input, 'infer').tolist()[0] p_str = ind2txt(preds, self.ind2tok, self.end_token_idx) - token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id( - p_str, 'conv') - self.update_context('conv', token_ids, - entity_ids, movie_ids, word_ids) + token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(p_str, 'conv') + self.update_context('conv', token_ids, entity_ids, movie_ids, word_ids) print(f"[Response]:\n{p_str}") # input input_text = self.get_input(self.language) def process_input(self, input_text, stage): - token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id( - input_text, stage) + token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(input_text, stage) self.update_context(stage, token_ids, entity_ids, movie_ids, word_ids) data = {'role': 'Seeker', 'context_tokens': self.context[stage]['context_tokens'], @@ -358,8 +342,7 @@ def process_input(self, input_text, stage): elif stage == 'conv': data = dataloader.conv_interact(data) - data = [ele.to(self.device) if isinstance( - ele, torch.Tensor) else ele for ele in data] + data = [ele.to(self.device) if isinstance(ele, torch.Tensor) else ele for ele in data] return data def convert_to_id(self, text, stage): @@ -370,23 +353,18 @@ def convert_to_id(self, text, stage): else: raise - entities = self.link( - tokens, self.side_data[stage]['entity_kg']['entity']) + entities = self.link(tokens, self.side_data[stage]['entity_kg']['entity']) words = self.link(tokens, self.side_data[stage]['word_kg']['entity']) if self.opt['tokenize'][stage] in ('gpt2', 'bert'): language = dataset_language_map[self.opt['dataset']] - path = os.path.join( - PRETRAIN_PATH, self.opt['tokenize'][stage], language) + path = os.path.join(PRETRAIN_PATH, self.opt['tokenize'][stage], language) tokens = self.tokenize(text, 'bert', path) - token_ids = [self.vocab[stage]['tok2ind'].get( - token, self.vocab[stage]['unk']) for token in tokens] + token_ids = [self.vocab[stage]['tok2ind'].get(token, self.vocab[stage]['unk']) for token in tokens] entity_ids = [self.vocab[stage]['entity2id'][entity] for entity in entities if entity in self.vocab[stage]['entity2id']] - movie_ids = [ - entity_id for entity_id in 
entity_ids if entity_id in self.item_ids] - word_ids = [self.vocab[stage]['word2id'][word] - for word in words if word in self.vocab[stage]['word2id']] + movie_ids = [entity_id for entity_id in entity_ids if entity_id in self.item_ids] + word_ids = [self.vocab[stage]['word2id'][word] for word in words if word in self.vocab[stage]['word2id']] return token_ids, entity_ids, movie_ids, word_ids diff --git a/crslab/system/utils/functions.py b/crslab/system/utils/functions.py index a622f36..630797f 100644 --- a/crslab/system/utils/functions.py +++ b/crslab/system/utils/functions.py @@ -12,6 +12,7 @@ # @Author : Zhipeng Zhao # @email : oran_official@outlook.com +import os import torch @@ -64,3 +65,10 @@ def ind2txt_with_slots(inds,slots,ind2tok, end_token_idx=None, unk_token='unk',s def ind2slot(inds,ind2slot): return [ ind2slot[ind] for ind in inds] + +def get_exist_item_set(save_dir): + exist_item_set = set() + for file in os.listdir(save_dir): + item_id = os.path.splitext(file)[0] + exist_item_set.add(item_id) + return exist_item_set \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c133145..f7fba73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,3 @@ requests~=2.25.1 scikit-learn~=0.24.0 fuzzywuzzy~=0.18.0 tensorboard~=2.4.1 -gensim~=4.2.0 diff --git a/run_crslab.py b/run_crslab.py index b0f7f73..4dd4d40 100644 --- a/run_crslab.py +++ b/run_crslab.py @@ -35,10 +35,12 @@ help='interact with your system instead of training') parser.add_argument('-tb', '--tensorboard', action='store_true', help='enable tensorboard to monitor train performance') + parser.add_argument('--mode', choices=['train', 'chat', 'ask'], + help='train the CRS model, or evaluate it with iEvaLM in free-form chit-chat mode or attribute-based question answering mode') args, _ = parser.parse_known_args() config = Config(args.config, args.gpu, args.debug) - from crslab.quick_start import run_crslab - - run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, + from crslab.quick_start import quick_start + + quick_start(config, args.mode, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, args.debug, args.tensorboard)
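
For reference, a minimal usage sketch of the reworked entry point, mirroring what the updated run_crslab.py does; it is not part of the patch, and the config path, GPU value, and mode below are placeholders to be replaced with your own settings.

# rough sketch, assuming this patch is applied and a config file exists at the given path
from crslab.config import Config
from crslab.quick_start import quick_start

# positional arguments mirror run_crslab.py: (config file, gpu, debug); '-1' is assumed to mean CPU
config = Config('config/crs/kgsf/redial.yaml', '-1', False)
# mode is one of 'train', 'chat', 'ask'; 'chat' and 'ask' drive the iEvaLM-style
# evaluation handled by ChatGPTSystem.fit(), while 'train' runs the usual loop
quick_start(config, 'train')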