Skip to content

“Key-Value Memory Networks for Directly Reading Documents”的tensorflow实现方案,使用的数据集是MovieQA

Notifications You must be signed in to change notification settings

lc222/key-value-MemNN

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

9 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

key-value-MemNN

这是论文“Key-Value Memory Networks for Directly Reading Documents”的tensorflow实现方案,使用的数据集是MovieQA,基于KB知识库作为知识源进行建模。

其结构图如下所示:

首先我们看一下模型仿真出来的网络架构,如下图所示:

但是模型的仿真效果比较差,如下图所示,acc一直上不去,而且loss维持在10附近居高不下。到了训练后期

到了训练后期loss反而逐渐升高,模型无法收敛,如下图所示==

观察了模型运行过程中的一些参数,发现原因可能出现在B矩阵的身上,如下图,在训练后期也初夏大规模的震荡。所以对模型做出了改进:

改进的方案是添加一个bias,来减少B矩阵的偏置。此外,去看了FaceBook官网给出的模型训练的参数,发现embedding-size取得500,

而且,max_slots取得1000还是多少,所以尝试按照其参数进行训练(发现训练速度极慢无比)。

所以最后作出的调整是:

1,给B矩阵添加bais

2,进行梯度截断,截断值去10

3,将max_slots增大至300

4,将embedding_size增大至200

按照上述方案进行训练,速度还可以接受,而且在模型收脸上也收到了很好的效果,如下图所示:

而且B矩阵虽然仍有尖峰存在,但是相对而言已经减少了很多,如下

训练过程的输出如下所示,结合上图可以发现,最起码训练集上的acc可以到达0.8~0.9,loss也维持在1附近。 但是仍然存在缺点就是测试集上的效果还不够好,准确度只有0.3左右。这是接下来要进行解决的问题:

========================================================更新=========================================================== 这几天又做了几个对比试验,效果有提升但是还没有达到理想状态,先来更新一下。

首先根据上面的结果呢,我就在思考,为什么准确度和loss波动这么大,而且验证集效果这么差呢。后来想到一个致命的问题,那就是

我们在处理数据的时候,如果一个QA对有好几个答案,我们就会把它作为好几条训练样本,这样就导致了,相同的特征和问题,但是却有

好几个不一样的标签,也就是样本本身就有冲突,这是模型解决不了的。所以改进的第一个问题就是智取第一个答案作为训练样本。结果如下所示:

从上那个图可以发现,训练及准确度提升到了100%,loss降到了0,而且震荡现象也改善了。说明我们的模型取得了巨大的进步。

但是,仍然存在一个问题就是测试集效果依然不好,准确度只有40%多。于是我加了dropout功能,在dropout之分别取0.5和0.8的时候结果如下图所示:

从上图可以发现,dropout程度越厉害,训练集效果越差,测试集的改善也并不明显,而且又重新出现了B矩阵的尖峰现象。很是诡异,等我再挑挑参数看看效果吧。

About

“Key-Value Memory Networks for Directly Reading Documents”的tensorflow实现方案,使用的数据集是MovieQA

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages