Hierarchical Graph Network for Multi-hop Question Answering

HGN

由4个模块组成:

  1. Graph Construction Module
  2. Context Encoding Module
  3. Graph Reasoning Module
  4. Multi-task Prediction Module

Graph Construction

  1. 找到最相关的段落

    1. 训练了一个Roberta + binary_cls 检索相关的段落(匹配段落标题和问题)
    2. 如果多个都相关,则选择得分最高的两个
    3. 如果title matching没有结果,那么找包含问题实体的段落
    4. 如果实体也没有结果,那么只用段落匹配的烦最高的两个(first hop)

    接下来是第二跳,寻找能够链接到其他相关段落的facts and entities

    1. entity linking 可能引入噪音, 直接使用wikipedia中的超链接找第二跳的段落
    2. 在含有超链接的句子和链接到的段落之间添加双向边
  2. 为上一步中相关的段落和其中的实体之间添加边

    paragraphs -> sentences -> entities

    1. 每个段落节点,添加到包含的句子的节点的边
    2. 每个句子节点,抽取所有实体,并添加边
    3. 段落之间,句子之间也可以有边

    不同类型的节点从不同的信息源捕获语义, 能比只有同类节点的图更精确地定位证据和找到答案

image-20200708151430988.png

定义了七种边:1. question-paragraph 2. question - its inner entities 3. paragraph - its inner entities 4. sentence - their linked paragraph 5.sentence - its inner entities 6. paragraph-paragraph 7. sentence -sentence (appear in the same paragraph)

Context Encoding

​ 把所有段落拼接起来得到 context C 再拼接上question Q 输入RoBERTa, 后面再接一个Bi-attention layer。 Q = \{q_0, q_1,...,q_{m-1}\} \in \mathcal{R}^{m \times d}C = \{ c_0, c_1,...,c_{n-1}\} \in \mathcal{R}^{n \times d} , m,n 分别是 question 和 context的长度。

​ 在C后用BILSTM, 从BiLSTM的输出 M \in \mathcal{R}^{n \times 2d} 中得到不同节点的表示。entity/sentence/paragraph是context中的一个区域,表示的计算方式:1. 反向LSTM 在开始处的隐藏状态 2. 前向LSTM在结束位置的隐藏状态。 而对于question节点,使用max-pooling获得其表示。
p_i = MLP_1([M[P_{start}^i][d:];M[P_{end}^i][:d]])\\ s_i = MLP_2([M[S_{start}^i][d:];M[S_{end}^i][:d]]) \\ e_i = MLP_3([M[E_{start}^i][d:];M[E_{end}^i][:d]]) \\ q = max-pooling(Q)

Graph Reasoning

contextualized representations of all the graph nodes are transformed into higher-level features via a graph neural network.

P = \{p_i\}_{i=1}^{n_p}S = \{s_i\}_{i=1}^{n_s}E = \{e_i\}_{i=1}^{n_e}n_p, n_s, n_e 分别是图中该类型节点的个数, 实验中分别设置为4, 40, 60 (padded where necessary). H = \{q,P,S,E\} \in \mathcal{R}^{g \times d} , 其中g = n_p +n_s + n_e +1

​ 使用GAT进行图中的信息传播。GAT 将所有节点作为输入,通过对应的邻居节点\mathcal{N}_i更新每个节点的表示 h_i'
h_i' = LeakyRelu(\sum_{j \in \mathcal{N}_i} \alpha_{ij}h_jW)\\ \alpha_{ij} = \frac{exp(f([h_i;h_j]w_{e_{ij}}))}{\sum_{k\in\mathcal{N}_i}exp(f([h_i;h_k]w_{e_{ik}}))}
W_{e_{ij}} \in \mathcal{R}^{2d} 是 实体类型节点中第i和第j个节点之间边的权重。f(\cdot) 是LeakyRelu。

在图推理之后,我们得到H' = {h_0',h_i',...,h_g'} \in \mathcal{R}^{g \times d} = \{q',P',S',E'\}

门注意力 图的信息将进一步交互上下文信息来预测答案区域:
C= Relu(MW_m) \cdot Relu(H'W'_m)^T \\ \bar{H} = Softmax(C)\cdot H'\\ G = \sigma([M;\bar{H}]W_s \cdot Tanh([M;\bar{H}]W_t))
其中W_m \in \mathcal{R}^{2d \times 2d}, W'_m \in \mathcal{R}^{2d \times 2d}, W_s \in \mathcal{R}^{4d \times 4d}, W_t \in \mathcal{R}^{4d \times 4d} 都是需要学习的权重, G \in \mathcal{R}^{n \times 4d} 是gated representation ,可以用来做答案区域抽取。

多任务预测

​ 图推理之后,更新过的节点表示被用于不同的子任务:1.paragraph selection based on paragraph nodes; 2.supporting facts prediction based on sentence nodes; 3.answer prediction based on entity nodes and context representation G

​ 由于答案可能不太实体节点中, 实体节点的损失只用作正则项。最终的目标函数:
\mathcal{L}_{joint} = \mathcal{L}_{start} + \mathcal{L}_{end}+\lambda_1\mathcal{L}_{para} + \lambda_2\mathcal{L}_{sent} + \lambda_3\mathcal{L}_{entity} + \lambda_4\mathcal{L}_{type}
对于段落选择 (\mathcal{L}_{para})和支持证据预测(\mathcal{L}_{sent})使用两层MLP作为二分类器。
o_{sent} = MLP_4(S') \in \mathcal{R}^{n_s}, o_{para} = MLP_5(P') \in \mathcal{R}^{n_p}
实体预测(\mathcal{L}_{entity})被视作多标签分类任务,候选实体包括问题中的所有实体和those that match the titles in the context。 如果正确答案不在这些实体节点中,则实体损失函数为0.
\mathbf{o}_{e n t i t y}=\operatorname{MLP}_{6}\left(\mathbf{E}^{\prime}\right)
实体损失值只作为正则项,最终的答案预测只依赖下面的答案区域抽取模块。每个位置上是答案开始或结束的概率也用两层MLP计算:
\mathbf{o}_{\text {start}}=\operatorname{MLP}_{7}(\mathbf{G}), \mathbf{o}_{\text {end}}=\operatorname{MLP}_{8}(\mathbf{G})
对于答案类型预测,span, entity, yes, no
\mathbf{o}_{type}=\operatorname{MLP}_{9}\left(\mathbf{G}[0]\right)
最终的交叉熵损失使用以上的logits计算:\mathbf{o}_{\text {sent}}, \mathbf{o}_{\text {para}}, \mathbf{o}_{\text {entity}}, \mathbf{o}_{\text {start}}, \mathbf{o}_{\text {end}}, \mathbf{o}_{\text {type}}

实验

image-20200708184940036.png

Hotpot QA

image-20200708185014809.png

Fullwiki setting.

image-20200708185302850.png

表3:PS Graph 只有问题到段落,段落到句子的边(边的类型:1,3,4); PSE 加入了Entity相关的,边的类型增加了 2,5 ; 最终的HGN 再多了6,7两种类型的边。

表4:验证了损失函数设计的有效性

表5:不同预训练模型的效果

image-20200708185937540.png

错误分析:pass

结论

本文提出了层叠图网络 HGN, 为了从不同粒度级别获取线索,HGN模型将异质节点放进一个图中。实验在HotpotQA 取得了最好成绩(2019 Dec)。现在fullwiki setting, 现成的段落检索器被用来从大规模语料文本中选择相关的context。未来将探索HGN和段落检索器之间的交互和联合训练。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容