Skip to content

Commit

Permalink
enrich(rag): add training procedure for rag
Browse files Browse the repository at this point in the history
  • Loading branch information
GaoangLiu committed Nov 18, 2023
1 parent a2f1b4a commit 7ff987f
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 5 deletions.
4 changes: 3 additions & 1 deletion _drafts/2022/bm25.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ Dual Encoder 结构(也称 biEncoder),每个 encoder 都是一个 BERT,

训练方式:给定一个样本 $a = \langle q_i, p_i^+, p_1^-, p_2^-, ..., p_n^- \rangle$,其中包含 query $q_i$,正相关答案 $p_i^+$ 及若干个负相关答案 $p_i^-$。训练目标是优化相似度的对数似然,这本质上就是对比学习,在 SimCSE 中也有用到,详情可以参考之前的文章 [《Semtatic Similarity》]({{site.baseurl}}/2022/10/18/Semantic-Similarity/#simcse)

$$ \mathcal{L}(a) = - \log \frac{e^{sim(q_i, p_i^+) / \tau}}{(\sum_{j=1}^{n+1} e^{sim(q_i, p_j^-) / \tau}) + e^{sim(q_i, p_i^+) / \tau}}$$
$$ \mathcal{L}(a) = - \log \frac{e^{sim(q_i, p_i^+) / \tau}}{(\sum_{j=1}^{n} e^{sim(q_i, p_j^-) / \tau}) + e^{sim(q_i, p_i^+) / \tau}}$$

其中 $\tau$ 是温度参数,用于控制相似度的分布,在 SimCSE 中的损失函数中有定义,DPR 中没有设置这个超参,相当于对应的 $\tau$ 为 1。$\tau$ 参数的值越小,相似度分布的“温度”越低,表示对相似度的判别更加严格。意味着在训练时,相似的句子会更有可能被赋予较高的概率,而相似度较低的句子则会有较低的概率。通过调节 $\tau$ 参数,可以影响相似度分布的平滑程度,从而对模型的训练产生影响。

题外话,这个工作 Danqi 大佬也有参与,Danqi 大佬在同年还指导了另一篇类似思想的工作 [DensePhrases](https://arxiv.org/pdf/2012.12624.pdf),主要侧重于 phrase 的检索。

## 负样本采样
训练数据中,正样本由 (Query, Answer) 构成,但负样本的采集就需要人工构造了。DPR 考虑了三种方式:
1. 从语料中随机采取段落;
Expand Down
11 changes: 10 additions & 1 deletion _drafts/2023/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ categories:
ChatGPT 爆火之后,有一段时间内很多公司都在竞相做向量数据库,一些数据库厂商也在竞相在传统数据库上增加向量存储功能。常见的做法是通过预训练获取一个大模型,然后将数据向量化并存储。


关于检索,一个比较引入注目的技术是 RAG (Retrieval-Augmented Generation)。这个技术与 meta 于 2020 年在论文 [《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》](https://arxiv.org/pdf/2005.11401.pdf) 中提出,它是一个检索增强的生成模型,通过检索得到的上下文信息来指导生成,从而提高生成的质量。这里面有两个主要的模块,一个是检索,使用的技术是 DPR(Dense Passage Retrieval),即是 20 年出的暴打前浪 BM25 的技术,同样也是 meta 的工作,这个工作丹琦大佬也有参与。关于 DPR 的结构,我们在之前的文章[《Okapi-BM25》]({{site.baseur}}/2022/11/17/Okapi-BM25/)里稍有提过。另一个模块是 seq2seq 生成器,模型使用的 [BART](https://arxiv.org/abs/1910.13461)(也是 meta 的工作)。
关于检索,一个比较引入注目的技术是 RAG (Retrieval-Augmented Generation)。这个技术与 meta 于 2020 年在论文 [《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》](https://arxiv.org/pdf/2005.11401.pdf) 中被提出,它是一个检索增强的生成模型,通过检索得到的上下文信息来指导生成,从而提高生成的质量。这里面有两个主要的模块,一个是检索,使用的技术是 DPR(Dense Passage Retrieval),即是 20 年推出的“暴打前浪 BM25 的技术,同样也是 meta 的工作。DPR 整体结构是一个 dual encoder: document encoder 和 query encoder,两个 encoder 使用的模型都是 $\text{BERT}_\text{BASE}$。关于 DPR 的机制,我们在之前的文章[《Okapi-BM25》]({{site.baseur}}/2022/11/17/Okapi-BM25/)里稍有提过。另一个模块是 seq2seq 生成器,模型使用的 [BART-large](https://arxiv.org/abs/1910.13461)(也是 meta 的工作)。


# RAG 结构
Expand All @@ -43,6 +43,15 @@ $$\begin{aligned}
p_\text{RAG-token}(y\lvert x) &\approx \prod_{i=1}^N \sum_{z\in \mathcal{Z}} p_\eta(z\lvert x)p_\theta(y_i\lvert x, z, y_{1,...,i-1})
\end{aligned}$$

# 训练
RAG 联合训练检索模块跟生成模块,不需要关于检索文档的监督信息(这也是为什么文中说将 document 视为潜变量),在给定数据集 $\mathcal{D} = \{(x^{(i)}, y^{(i)})\}_{i=1}^N$ 的情况下,最小下面的负对数似然:

$$\begin{aligned}
\mathcal{L}(\theta, \eta) &= - \sum_{(x, y)\in \mathcal{D}} \log p_\text{RAG}(y\lvert x) \\\
&= - \sum_{(x, y)\in \mathcal{D}} \log \sum_{z\in \mathcal{Z}} p_\eta(z\lvert x)p_\theta(y\lvert x, z)
\end{aligned}$$

在训练过程,由于更新检索库的编码器消耗巨大,因为每更新一次文档编码器就需要对所有文档重新编码,所以在训练过程 RAG 选择固定文档编码器 $\text{BERT}_d$ 的参数,只训练 $\text{BERT}_q$ 编码器与 BART 生成器。

# 效果

Expand Down
4 changes: 3 additions & 1 deletion _posts/2022/2022-11-17-Okapi-BM25.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,12 @@ Dual Encoder 结构(也称 biEncoder),每个 encoder 都是一个 BERT,

训练方式:给定一个样本 $$a = \langle q_i, p_i^+, p_1^-, p_2^-, ..., p_n^- \rangle$$,其中包含 query $$q_i$$,正相关答案 $$p_i^+$$ 及若干个负相关答案 $$p_i^-$$。训练目标是优化相似度的对数似然,这本质上就是对比学习,在 SimCSE 中也有用到,详情可以参考之前的文章 [《Semtatic Similarity》]({{site.baseurl}}/2022/10/18/Semantic-Similarity/#simcse)

$$ \mathcal{L}(a) = - \log \frac{e^{sim(q_i, p_i^+) / \tau}}{(\sum_{j=1}^{n+1} e^{sim(q_i, p_j^-) / \tau}) + e^{sim(q_i, p_i^+) / \tau}}$$
$$ \mathcal{L}(a) = - \log \frac{e^{sim(q_i, p_i^+) / \tau}}{(\sum_{j=1}^{n} e^{sim(q_i, p_j^-) / \tau}) + e^{sim(q_i, p_i^+) / \tau}}$$

其中 $$\tau$$ 是温度参数,用于控制相似度的分布,在 SimCSE 中的损失函数中有定义,DPR 中没有设置这个超参,相当于对应的 $$\tau$$ 为 1。$$\tau$$ 参数的值越小,相似度分布的“温度”越低,表示对相似度的判别更加严格。意味着在训练时,相似的句子会更有可能被赋予较高的概率,而相似度较低的句子则会有较低的概率。通过调节 $$\tau$$ 参数,可以影响相似度分布的平滑程度,从而对模型的训练产生影响。

题外话,这个工作 Danqi 大佬也有参与,Danqi 大佬在同年还指导了另一篇类似思想的工作 [DensePhrases](https://arxiv.org/pdf/2012.12624.pdf),主要侧重于 phrase 的检索。

## 负样本采样
训练数据中,正样本由 (Query, Answer) 构成,但负样本的采集就需要人工构造了。DPR 考虑了三种方式:
1. 从语料中随机采取段落;
Expand Down
11 changes: 10 additions & 1 deletion _posts/2023/2023-11-16-Retrivial-augmented-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ author: berrysleaf
ChatGPT 爆火之后,有一段时间内很多公司都在竞相做向量数据库,一些数据库厂商也在竞相在传统数据库上增加向量存储功能。常见的做法是通过预训练获取一个大模型,然后将数据向量化并存储。


关于检索,一个比较引入注目的技术是 RAG (Retrieval-Augmented Generation)。这个技术与 meta 于 2020 年在论文 [《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》](https://arxiv.org/pdf/2005.11401.pdf) 中提出,它是一个检索增强的生成模型,通过检索得到的上下文信息来指导生成,从而提高生成的质量。这里面有两个主要的模块,一个是检索,使用的技术是 DPR(Dense Passage Retrieval),即是 20 年出的暴打前浪 BM25 的技术,同样也是 meta 的工作,这个工作丹琦大佬也有参与。关于 DPR 的结构,我们在之前的文章[《Okapi-BM25》]({{site.baseur}}/2022/11/17/Okapi-BM25/)里稍有提过。另一个模块是 seq2seq 生成器,模型使用的 [BART](https://arxiv.org/abs/1910.13461)(也是 meta 的工作)。
关于检索,一个比较引入注目的技术是 RAG (Retrieval-Augmented Generation)。这个技术与 meta 于 2020 年在论文 [《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》](https://arxiv.org/pdf/2005.11401.pdf) 中被提出,它是一个检索增强的生成模型,通过检索得到的上下文信息来指导生成,从而提高生成的质量。这里面有两个主要的模块,一个是检索,使用的技术是 DPR(Dense Passage Retrieval),即是 20 年推出的“暴打前浪 BM25 的技术,同样也是 meta 的工作。DPR 整体结构是一个 dual encoder: document encoder 和 query encoder,两个 encoder 使用的模型都是 $$\text{BERT}_\text{BASE}$$。关于 DPR 的机制,我们在之前的文章[《Okapi-BM25》]({{site.baseur}}/2022/11/17/Okapi-BM25/)里稍有提过。另一个模块是 seq2seq 生成器,模型使用的 [BART-large](https://arxiv.org/abs/1910.13461)(也是 meta 的工作)。


# RAG 结构
Expand All @@ -49,6 +49,15 @@ $$\begin{aligned}
p_\text{RAG-token}(y\lvert x) &\approx \prod_{i=1}^N \sum_{z\in \mathcal{Z}} p_\eta(z\lvert x)p_\theta(y_i\lvert x, z, y_{1,...,i-1})
\end{aligned}$$

# 训练
RAG 联合训练检索模块跟生成模块,不需要关于检索文档的监督信息(这也是为什么文中说将 document 视为潜变量),在给定数据集 $$\mathcal{D} = \{(x^{(i)}, y^{(i)})\}_{i=1}^N$$ 的情况下,最小下面的负对数似然:

$$\begin{aligned}
\mathcal{L}(\theta, \eta) &= - \sum_{(x, y)\in \mathcal{D}} \log p_\text{RAG}(y\lvert x) \\\
&= - \sum_{(x, y)\in \mathcal{D}} \log \sum_{z\in \mathcal{Z}} p_\eta(z\lvert x)p_\theta(y\lvert x, z)
\end{aligned}$$

在训练过程,由于更新检索库的编码器消耗巨大,因为每更新一次文档编码器就需要对所有文档重新编码,所以在训练过程 RAG 选择固定文档编码器 $$\text{BERT}_d$$ 的参数,只训练 $$\text{BERT}_q$$ 编码器与 BART 生成器。

# 效果

Expand Down
2 changes: 1 addition & 1 deletion assets/progress.json
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,6 @@
],
[
"2023-11-18",
323
507
]
]

0 comments on commit 7ff987f

Please sign in to comment.