使用pytorch实现Image Captioning encoder-decoder框架

第一篇简书,记录下那些走过的路和踩过的坑。


轮子来源:GitHub - ruotianluo/ImageCaptioning.pytorch: Image captioning codebase in pytorch
本文包括:
1.Image Captioning 的简介与学习参考文献
2.使用上面代码进行对模型的训练和评估
3.修改多GPU训练的BUG
4.多GPU预训练模型读取
5.自定义模型


框架简介

Image Captioning是计算机视觉的研究方向之一,其中文翻译一般为图像的文本描述。其任务大概可以描述为输入一张图片,生成一句对此图片的描述句子。作为一种结合了计算机视觉和自然语言翻译的多模态任务,其方法随着深度学习的兴起,也能大概有个推测。视觉方面一般使用CNN对图像进行编码(encoder),再输入到NLP中常用的RNN网络中进行句子的生成(decoder)。两者结合,形成一个端对端的网络结构,这样当使用训练好的模型时,输入图片就可以直接输出预测的句子啦。

学习参考

因为是两个领域(CV、NLP)的结合,所以想要研究透彻取得新成果,那么两个方向的研究动态都需要着重关注。如果只是用一用试一试那么有一点神经网络的基础知识就可以了。详细原理懒得叙述了,网上很多已经写的足够清楚明白,但我还是推荐读一读相关的论文。

相关论文如下:

大名鼎鼎的开山之作Show and Tell: A Neural Image Caption Generator 四个谷歌老哥列在一起就问你怕不怕。encoder使用的自家GoogleNet,decoder使用的LSTM,这个方向的很多论文参考必有这篇,虽然性能在现在看来并不算太好(但可以看下里面和当年的那些方法的效果对比),但是有了这篇,这个方向才有了现在这么多的关注。不得不说,大公司还是牛啊。

只有encoder-decoder可能还不够,加个attention感觉效果貌似更好 Show,Attend and Tell: Neural Image Caption Generation with Visual Attention,在原有框架下加入了attention机制cnn也换成了时下流行的ResNet(P.S. 自家电脑硬盘小的不建议用attention,因为要提取 MSCOCO数据集的feature map,200多G的大小)

之后的研究进展就是在这个框架下改进了,改改encoder(图像预处理+tricks),改改decoder(改良版LSTM),改改attention(Self-criticalAdaptive Attention)。

到了最近,看到一篇今年NIPS的文章,直接放弃了这个框架使用了新的框架,貌似取得了更牛逼的效果。还没看完,后面再找时间介绍。


回归正题,开始动手用轮子吧。

框架使用

P.S.大部分链接需要梯子,请准备好工具

代码使用的是python 2.7编写的,pytorch版本为0.4.1(我使用1.0.0跑也没多大问题),环境是ubuntu 16CUDA 用7以上吧和对应的CUDNN。这个老哥写得很好,操作指南可以就读他的README就行了,翻译就找Chrome,一键全搞定。

大概流程如下:

1.下载预训练好的ResNet模型并放入data/imagenet_weights文件夹中。50、101、152分别对应网络的层数,提取的feature map大小依次递增。

101提取出来大概200多个G,152提取的话是368.76G。

2.如果不想自己训练,作者提供了预训练好的文件,使用eval.py进行评估就可以了。

3.这里下载MSCOCO数据集,请用2014的。将解压出来的train2014/val2014/放在同一目录中。从这里下载预处理的COCO字幕。dataset_coco.json从zip文件中提取并将其复制到data/文件夹中。此文件提供预处理的标题以及标准的train-val-test拆分。
注意! 请从这里下载此图片,并放入train2014文件夹替换原始图像。

4.使用prepro_*.py来读取上面的data并创建一个数据集。
打开你的命令行,cd到对应文件夹里:

处理数据集,得到标签等文件:

python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk

使用ResNet处理MSCOCO,得到全连接层的数据,和最后一层的Feature Map:
其中$IMAGE_ROOT就是放train/val文件夹的路径

prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT

经过1天左右的特征提取,我们现在得到了两个文件夹:
data/cocotalk_fc是全连接层的特征,data/cocotalk_att就是ResNet最后一层的feature map(这东西200多个G,请不要用记事本之类的打开)。
两个文件都为*.h5文件,载入的时候是词典类型。其结构为:

#fc    
{'img_id':[2048]#一个2048维array}
#att
{'img_id':[14][14][2048]#14,14,2048维array}

如果想用这两个东西干其他的事可以直接调用,里面的数据类型为numpy.arraydtype=float32。若使用pytorch,请将其转换为tensor形式。

4.开始训练,命令行输入:

python train.py --id st --caption_model show_tell --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path log_st --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 25

这些参数的作用,可以在opts.py文件中得到详细的解释。更加详细的训练和评估都在README文件中,可以仔细阅读。


好了,到此就算是把这个程序跑通了,下面开始加东西。

多GPU训练BUG修复

当电脑有多个GPU时候,可以调用pytorch的torch.nnDataParallel将model送入其中,就可以让GPU并行训练。
目前1.0版本的pytorch可以用更好的torch.nn.parallel.DistributedDataParallel中的多线程单GPU方法,具体可以参考pytorch官方文档

但是此框架在调用DataParallel时,训练几个batch后便会报错:
RuntimeError: Gather got an input of invalid size: got [80, 15, 9488], but expected [80, 17, 9488] (gather at torch/csrc/cuda/comm.cpp:183)
这个错误的原因是因为在所使用的CaptionModel.py里。在一个batch的输入后,在主GPU里面进行拆分,然后分给多个GPU并行处理,在最终RNN生成单词概率的时候,是需要相同的长度来使多个输出拼接在一起,从而进行loss的计算。
由于在每个model文件夹时:*.py文件中的forward函数末尾,有一个判断语句:
大概长这样:

#break if all the sequences end
if i >= 1 and seq[:, i].data.sum() == 0:
    break

由于这个判断语句的存在,在生成最大长度句子(默认为17)之前就会中断生成词的过程。
(再解释下原理在训练数据集中一共有可以提取出9488个单词,在加上一个为的组成长为9489的向量,这就是RNN的输出,经过一次softmax后就是每一个单词的概率,bug中的维度就是[batch,17,9488]
在单GPU下,每一个batch是合并的,可能都生成长度小于17的句子的矩阵,再进行loss的计算,这是没问题的。但是多卡时,不同的GPU如果生成不同长度的句子矩阵,那就无法聚合在一起计算loss,所以会出现此BUG。

修改方案

#break if all the sequences end
if i >= 1 and seq[:, i].data.sum() == 0 and len(outputs)==17:#加一个条件,保证长度一致
    break

预训练模型的读取

当使用DataParallel后,我们保存完model的参数,其格式为.pth,在pytorch中使用torch.load()读取为一个字典型dict。

不使用DataParallel训练时,保存下来的词典的每一个键值名都为Model中init中定义的名字。
但当使用DataParallel,每一个键值名前面都会加上一个 model. 所以需要在读取的时候处理下键值:

from collections import OrderedDict

netdata=torch.load(PRETRAIN_MODEL_PATH)#load dict
new_state_dict = OrderedDict()
for k,v in netdata.items():
    name = k [7:] #remove module. 
    new_state_dict [name] = v

自定义model

在model文件夹下,我们如果要定义自己的网络,可以从这几个地方参考:
里面的CapptionModel.py为流程,*core为每个方法的核心方法,一般就是attention的生成过程,想修改新的attention模型可以自己写一个*core的类,再在opts.py中进行参数的添加。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,080评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,422评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,630评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,554评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,662评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,856评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,014评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,752评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,212评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,541评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,687评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,347评论 4 331
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,973评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,777评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,006评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,406评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,576评论 2 349

推荐阅读更多精彩内容