前些天从实验室了解到天池的FashionAI全球挑战赛,题目和数据都挺有意思,于是花了点时间稍微尝试了下。目前比赛还在初赛阶段,题目有两个,分别是服装属性标签识别和服饰关键点定位。
服装属性标签识别是指识别出领、袖、衣、裙、裤等部位的设计属性,对应多个多分类问题,例如以下的例子。
服饰关键点定位是指定位出服饰中关键点的位置,对应多个回归问题,例如以下的例子。
这和CelebA(http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)人脸数据集有点像,每张图片都是一张人脸,对应的标注包括5个关键点位置和40个属性的01标注,例如是否有眼镜、帽子、胡子等。
选题
我选了第二道题目,一方面感觉有人脸关键点检测、人体骨骼关键点检测等相关问题可供参考,去年的AI Challenger也举办过人体骨骼关键点检测的比赛(https://challenger.ai/competition/keypoint),另一方面自己还没有做过这块内容,比较感兴趣。
比赛官方提供的训练集共包括4W多张图片,测试集共包括将近1W张图片,神奇的是训练集和测试集有366张图片是完全重合的。每张图片都指定了对应的服饰,共5类:上衣(blouse)、外套(outwear)、连身裙(dress)、半身裙(skirt)、裤子(trousers)。
一共有24个关键点,但每类服饰对应的关键点数量并不一样。关键点标注分为三类,-1表示不存在,0表示存在但不可见,1表示存在且可见,后两种情况都会提供对应关键点的xy坐标。
初步探索
进行了相关的预处理之后,先尝试下最基础的结构,即卷积、池化加全连接。由于之前完全没做过这一块,所以网络的一些细节都只能慢慢尝试,包括卷积用几层、卷积核大小设多少、使用哪个激活函数、使用哪个损失函数等等。
进行了长时间的摸索,终于折腾出第一个全部打通的版本,提交了一版结果,成绩大概30%左右。由于最后一层使用全连接层直接输出每个关键点的xy坐标,因此误差比较大。
目前排行榜第一名是4.49%,占榜很多天雷打不动,可见实力之强劲。
进一步探索
后来一想,像这类比较经典的问题,肯定已经有大量的相关研究和模型,完全靠自己凭空搭一个网络显然不靠谱。于是进行了一些调研,找到两个模型:Convolutional Pose Machine、Stacked Hourglass。
精力有限,重点研究了一下CPM。阅读了对应的论文,2016年的CVPR,模型结构长这样,简单来说就是反复使用多个Stage,不断抽取每个关键点对应的越来越准确的响应图。
在Github上找到了CPM的一个开源实现(https://github.com/timctho/convolutional-pose-machines-tensorflow),阅读代码并进行修改后应用到比赛的数据上,在P100上训练共花费30个小时左右。使用6个Stage的CPM,为每个关键点生成一张响应图。
以下是一张dress对应的结果,第一行的三张依次是第1个、第2个、第3个Stage的响应图合成结果,第二行的三张分别对应第6个Stage的响应图合成结果、正确答案、正确答案和原图的合成,看起来还不错。
再来看个outwear,响应图也很准。
最后再看个trousers,关键点比较少,也很准。
又交了一版结果,拿到了17%的成绩。由于CPM输出关键点的响应图而不是直接输出关键点的xy坐标,同时使用多个Stage级联以逐步获取越来越准确的响应图,因此可以取得更好的结果。
可能的改进
可能的改进包括多个方面:
- 使用其他更新更好的模型。由于我是这方面的外行,一方面不清楚哪个模型目前最好,另一方面再调研和实现一个模型也需要耗费大量时间;
- 调参。每个模型都涉及很多参数,除此之外还有很多和模型无关的参数,例如学习率、批大小、正则项系数等,由于我是这方面的外行,暂时不清楚对关键点检测这类问题该如何选择参数;
- 使用数据增强等技巧。由于我是这方面的外行,暂时不清楚除了数据增强之外还有什么适合关键点检测这类问题的技巧。
虽然有很多可能的改进方向,不过由于自己之前没有做过关键点检测这类问题,所以继续折腾下去只能靠运气各种尝试,而且每次尝试都需要等待很久的模型训练时间。
相比之下,对于一些在关键点检测领域有相当积累的团队和个人,他们有着丰富的经验和现成的代码,和他们竞争还是相当有难度的。看一下排行榜,第一名的4.49%至今无人能超越,前三十名也都在12%以下。
而且个人事情也比较多,时间和精力都十分有限,所以决定这个比赛不再继续尝试,感觉做到这一步就可以了。
总结
通过这次比赛,了解了关键点检测这类问题的一些解决方法,并尝试用CPM进行了一些实践,对自己而言已经满足了。
等比赛结束后,再关注一下冠军团队的解决方案,好好学习一波。