论文PePr:Improved Training of Convolutional Filters复现

CSDN上博客地址:https://blog.csdn.net/SweetWind1996/article/details/102859980

GitHub代码地址:https://github.com/SweetWind1996/implementation-of-RePr

参考代码地址:https://github.com/siahuat0727/RePr/blob/master/main.py

论坛讨论地址:https://www.reddit.com/r/MachineLearning/comments/ayh2hf/r_repr_improved_training_of_convolutional_filters/eozi40e/

基本思路:在CNN训练的过程中会产生很多“重叠”的filters,所谓的“重叠”就是指这些filters具有较强的相关性,因此很多filters捕获了重复的特征同时还会造成过拟合的现象。在本文中作者提出了一种正交的方法来让filters从不同的“方向”捕获特征。首先,我们训练一个CNN网络,在训练一段时间后,我们对其卷积层的filters进行相关性排序。在第二步,我们裁减掉相关性较大的前p%的filters,并进行训练。在第二步训练一段时间后,恢复被裁减掉的filters,重复前两步的训练方案。如下算法所示:

image

本文主要对论文进行复现,默认读者对文章已经有一定的了解。本文的代码不一定正确,作以下说明:

1.本文在对filters裁剪时,有两种思路:一是通过阻止被裁减位置权重的梯度更新,二是利用论文DSD的裁剪方案,本文采取的是第二种方案。

2.在对权重进行重新初始化时,与原文相同,利用的是QR分解的方式,先找到null space,在QR分解时先进行了输入矩阵的转置操作。

A^T = QR ->A = RT*QT->AQ = R*T

3.实验的效果不没有论文中的好,实验数据集是CIFAR10,实验模型是我自己写的(共三层卷积层),大家可以根据需要自己写模型,毕竟是验证方法的有效性,保持对比时的模型及参数一致即可。

4.如有任何疑问、建议和意见,欢迎在评论中回复!


以下为代码部分:

def get_convlayername(model):
    '''
    获取卷积层的名称
    # 参数
        model: 神经网络模型
    '''
    layername = []
    for i in range(len(model.layers)):
        # 将模型中所有层的名称存入列表
        layername.append(model.layers[i].name) 
        # 将卷积层分离出来
    convlayername = [layername[name] for name in range(len(layername)) if 'conv2d' in layername[name]] 
    return convlayername[1:] # 不包括第一层
 
def prunefilters(model, convlayername, count=0):
    '''
    裁剪filters
    # 参数
        model: 神经网络模型
        convlayername: 保存所有卷积层(2D)的名称
        count: 用于存储每层filters的起始index
    '''
    convnum = len(convlayername) # 卷积层的个数
    params = [i for i in range(convnum)]
    weight = [i for i in range(convnum)]
    MASK = [i for i in range(convnum)]
    rank = dict() # 初始化存储rank的字典
    drop = []
    index1 = 0
    index2 = 0
    for j in range(convnum):
        # 保存卷积层的权重到一个列表,列表的每个元素是一个数组
        params[j] = model.get_layer(convlayername[j]).get_weights() # 将权重转置后才是正常的数组排列(32,32,3,3)
        weight[j] = params[j][0].T
        filternum = weight[j].shape[0] # 获取每一层filter的个数
        # 初始化一个用于判断正交性的矩阵
        W = np.zeros((weight[j].shape[0], weight[j].shape[2]*weight[j].shape[3]*weight[j].shape[1]), dtype='float32')
        for x in range(filternum):
            # filters是一个列表,它的每一个元素是包含一个卷积层所有filter(1D)的列表
            filter = weight[j][x,:,:,:].flatten()
            filter_length = np.linalg.norm(filter) 
            eps = np.finfo(filter_length.dtype).eps
            filter_length = max([filter_length, eps])
            filter_norm = filter / filter_length # 归一化
            # 将每一层的filters放到矩阵的每一行
            W[x,:] = filter_norm
        # 计算层内正交性
        I = np.identity(filternum)
        P = abs(np.dot(W, W.T) - I)
        O = P.sum(axis=1) / 32 # 计算每行元素之和
        for index, o in enumerate(O):
            rank.update({index+count: o})
        count = filternum + count
    # 对字典进行排序,在所有filters上进行ranking
    ranking = sorted(rank.items(), key=lambda x: x[1]) # ranking为一个列表,其元素是存放键值的元组
    for t in range(int(len(ranking)*0.8), len(ranking)):
        drop.append(ranking[t][0])
    for j in range(convnum):
        MASK[j] = np.ones((weight[j].shape), dtype='float32')
        index2 = weight[j].shape[0] + index1
        for a in drop:
            if a >= index1 and a < index2:
                MASK[j][a-index1,:,:,:] = 0
        index1 = index2
    #     weight[j] = (weight[j] * MASK[j]).T
    # for j in range(convnum):
    #     params[j][0] = weight[j]
    #     model.get_layer(convlayername[j]).set_weights(params[j])
    return MASK, weight, drop, convnum, convlayername
 
 
def Mask(model, mask):
    convlayername = get_convlayername(model)
    for i in range(len(convlayername)):
        Params = [i for i in range(len(convlayername))]
        Weight = [i for i in range(len(convlayername))]
        Params[i] = model.get_layer(convlayername[i]).get_weights() 
        Weight[i] = (Params[i][0].T*mask[i]).T
        Params[i][0] = Weight[i]
        model.get_layer(convlayername[i]).set_weights(Params[i])
 
# 回调函数,每个batch后自动调用
prune_callback = LambdaCallback(
    on_batch_end=lambda batch,logs: Mask(model, mask))
 
def reinit(model, weight, drop, convnum, convlayername):
 
    index1 = 0
    index2 = 0
    new_params = [i for i in range(convnum)]
    new_weight = [i for i in range(convnum)]
    for j in range(convnum):
        new_params[j] = model.get_layer(convlayername[j]).get_weights() 
        new_weight[j] = new_params[j][0].T
    stack_new_filters = new_weight[0]
    stack_filters = weight[0]
    filter_index1 = 0
    filter_index2 = 0
    for i in range(len(new_weight)-1):
        next_new_filter = new_weight[i+1]
        next_filter = weight[i+1]
        stack_new_filters = np.vstack((stack_new_filters, next_new_filter))
        stack_filters = np.vstack((stack_filters, next_filter))
    stack_new_filters_flat = np.zeros((stack_new_filters.shape[0], 
        stack_new_filters.shape[1]*stack_new_filters.shape[2]*stack_new_filters.shape[3]), dtype='float32')
    stack_filters_flat = np.zeros((stack_filters.shape[0], 
        stack_filters.shape[1]*stack_filters.shape[2]*stack_filters.shape[3]), dtype='float32')
    for p in range(stack_new_filters.shape[0]):
        stack_new_filters_flat[p] = stack_new_filters[p].flatten()
        stack_filters_flat[p] = stack_filters[p].flatten()
    q = np.zeros((stack_new_filters_flat.shape[0]), dtype='float32')
    tol = None
    reinit = None
    solve = None
    for b in drop:
        Q, R= qr(stack_new_filters_flat.T)
        for k in range(R.shape[0]):
            if np.abs(np.diag(R)[k])==0:
                # print(k)
                reinit = Q.T[k]
                break
        null_space = reinit
        stack_new_filters_flat[b] = null_space
    for filter_in_stack in range(stack_new_filters_flat.shape[0]):
        stack_new_filters[filter_in_stack] = stack_new_filters_flat[filter_in_stack].reshape(
            (stack_new_filters.shape[1], stack_new_filters.shape[2], stack_new_filters.shape[3]))
    for f in range(len(new_weight)):
        filter_index2 = new_weight[f].shape[0] + filter_index1
        new_weight[f] = stack_new_filters[filter_index1:filter_index2,:,:,:]
        filter_index1 = new_weight[f].shape[0]
        new_params[f][0] = new_weight[f].T
        model.get_layer(convlayername[f]).set_weights(new_params[f]) 

实验效果图:

训练.png

训练+测试.png

我训练了二十多次,实了不同的裁剪率和学习率,效果最好的一次是RePr测试集的正确率比Standard下的训练方式提高了3%。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,012评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,628评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,653评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,485评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,574评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,590评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,596评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,340评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,794评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,102评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,276评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,940评论 5 339
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,583评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,201评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,441评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,173评论 2 366
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,136评论 2 352