Chen Z, Song Y, Chang T H, et al. Generating Radiology Reports via Memory-driven Transformer[C]//Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP). 2020: 1439-1449.
代码仓:R2Gen code
任务目标
输入一张医学影像,生成相应的报告。
难点
- 医学报告句子很多,使用常用的只生成一句话的Image Captioning models可能不足以生成医学报告。
- 要求的精度也比较高。
医学报告也有特有的特征,图片和报告的格式都高度模式化。目前的解决方案:
- retrieval-based. 大数据集的准备
- retrieval-based + generation-based + manually extracted templates. 模板的准备
- 本文使用的是 generation-based model
模型简介
本文使用memory-dirven Transformer生成医学报告。主要工作:
- 提出了relational memory (RM) 模块记录之前生成过程的信息;
- 提出了memory-driven conditional layer normalization (MCLN) 把RM和Transformer结合起来。
模型结构:Visual Extractor + Encoder + Decoder + Relational Memory
1. Visual Extractor
这一部分的主要任务就是把图像转化为序列数据,从而可以输入到Encoder中。使用常用的卷积神经网络就可以,把最后的Linear去掉,留有最后的patch feature以及fc_feature就可以。
例如本文使用ResNet101预训练模型,每一组数据输入的图像为两张彩色图像。
- 输入shape为(b, 2, c, h, w)
- 视觉提取器分别对两张图进行特征提取。
- ResNet101去除掉最后一层Linear与Pooling层,输入(b, c, h, w),输出(b, 2048, 7, 7)
- 最后经过resize,permutation,shape=(b, 49, 2048)
- 两张图像的特征在axis=1上拼接,得到patch_feature(b, 98, 2048)。这个维度可以看作是batch * seq_len * embedding。
- 第二组特征fc_feature是在patch_feature的基础上再次Pooling生成的(b, 2048),在axis=1上拼接后,得到(b, 4096)。
2. Encoder
编码器把视觉特征处理,使用attention机制,得到最终的特征,作为K,V输入到decoder中。
- 首先有一个src_embedding,把视觉特征维度转换为d_model=512,方便输入Transformer。(按理说这部分是在transformer里实现的,但作者的代码在CaptionModel里实现)
- 数据x.shape=(b, seq_len, d_model)输入后,作为query,key,value输入到attention中(这里d_model=head * d_k),最后又得到相同shape的输出。
3. Relational Memory
这块的设计是为了使模型可以学到更好的report patterns,和retrieval-based 里面模板的准备差不多。RM使用矩阵存储pattern information with each row,称作memory slot。每步生成的过程,矩阵都会更新。在第t步,矩阵用作Q,和前一步输出的embedding拼接起来作为K,V进入到MultiHeadAttention。
Attention
这里的K Q V计算机制与Encoder里的稍有不同
最终attention计算得到的结果记为。因为M是循环计算的,可能梯度消失或者爆炸,因此引入了residual connections 和 gate mechanism。
Residual connection
M的中间值为
Gate Mechanism
gate mechanism的结构如图所示:
输入门和遗忘门用来平衡和,为了方便计算,被broadcast为矩阵, shape和一样。两个门的表达式为
其中的U和W都是可训练参数。最终gate的输出为:
其中是sigmoid activation function, 代表 Hardmard product,也就是pointwise product。
Memory-driven Conditional Layer Normalization
常见的模型memory都在encoder部分,本文单独设计并与decoder紧密联系。与Attention中 LayerNorm对比,提出了MCLN。把用到了Norm里的计算上。主要思路是,把拉成一个向量,再用MLP去预测的变化量,最后再更新。
4. Decoder
共有三个结构,self_attention + src_attention + FFN
输入参数有:
- 输出序列的embedding:tgt,经过了embedding + positionalEncoding,输入到self_attention中,得到的结果作为src_attention的Q,这里使用了tgt_mask
- encoder的输出内容:src,要用到src_attention中的K和V,这里使用了src_mask
- src_mask, tgt_mask 在attention 计算过程中用到
- RM的输出结果memory,每一个t时刻的memory都被拉成了向量,最后拼接在一起,在每一个MCLN中用到。
代码理解
作者的R2Gen模型里EncoderDecoder模块是最复杂的。
- 首先实现了CaptionModel类,可以调用函数,分别执行_forward()和_sample(),实现了beam_search()。
- 然后AttModel继承于CaptionModel类,实现了_sample()函数,在测试过程中用到。
- 最后EncoderDecoder又继承于AttModel,实现了_forward(),在训练过程中用到。
- 最后的搜索过程,也就是_sample() 函数根据不同的策略会有不同的实现。
本文的创新之处在于设计了Relational Memory 模块,并使用到MCLN中。