Deep Residual Shrinkage Networks (with TFLearn Code)

The deep residual shrinkage network is a feature-learning method and a variant of the deep residual network (ResNet). This article gives a brief introduction.

1. Basics of Deep Residual Networks

The backbone of a deep residual network is built by stacking many basic blocks. Each basic block contains two paths: a residual path and an identity path. Compared with an ordinary convolutional neural network, the shortcut structure formed by these two paths is the key to the excellent performance of deep residual networks (a minimal sketch of the two-path structure follows Figure 1). The basic block of a deep residual network is shown below:

Figure 1: Basic block of a deep residual network
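To make the two-path structure concrete, here is a minimal sketch; the names basic_residual_block and residual_path are illustrative, not part of the script later in this article:

import numpy as np

def basic_residual_block(x, residual_path):
    # x passes through the identity path unchanged; residual_path applies
    # the stacked layers (e.g., BN -> ReLU -> conv, twice). The block's
    # output is the element-wise sum of the two paths.
    return x + residual_path(x)

# Toy usage with a stand-in residual path:
x = np.ones((2, 4))
print(basic_residual_block(x, lambda t: 0.5 * t))   # every entry becomes 1.5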

2. Basic Block of the Deep Residual Shrinkage Network

As its name suggests, the deep residual shrinkage network applies "shrinkage" to the residual path. Here, "shrinkage" refers to soft thresholding, a key step in many signal-denoising algorithms. The deep residual shrinkage network builds soft thresholding into the network itself, so that redundant information associated with noise can be eliminated inside the deep neural network. The basic block is shown below, and a short numerical sketch of soft thresholding follows Figure 2:

Figure 2: Basic block of the deep residual shrinkage network
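Soft thresholding shrinks an input x toward zero by a threshold tau: y = sign(x) * max(|x| - tau, 0), so entries with |x| <= tau become exactly zero. A minimal NumPy sketch (the function name soft_threshold and the sample values are illustrative):

import numpy as np

def soft_threshold(x, tau):
    # Shrink toward zero: sign(x) * max(|x| - tau, 0).
    # Small, noise-like entries with |x| <= tau are zeroed out.
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-2.0, -0.3, 0.1, 0.8, 1.5])
print(soft_threshold(x, tau=0.5))   # [-1.5 -0.   0.   0.3  1. ]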

3. Overall Architecture of the Deep Residual Shrinkage Network

The remaining parts of a deep residual shrinkage network are identical to those of a deep residual network; its main body is simply a stack of the basic blocks from Figure 2. The overall architecture is shown below:

Figure 3: Overall architecture
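In the script below, this architecture corresponds to an input convolution, five stacked residual shrinkage blocks (two of which downsample and widen the channels), batch normalization and ReLU, global average pooling, and a fully connected softmax classifier.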

4. TFLearn Code for Image Classification
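The following script builds a small deep residual shrinkage network with TensorFlow 1.x and TFLearn and trains it on CIFAR-10, to which a small amount of uniform noise is first added.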

# -*- coding: utf-8 -*-
"""
Created on Tue Jan  7 11:55:14 2020

Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: Thinkpad
"""

import tflearn
import tensorflow as tf
from tflearn.layers.conv import conv_2d
import numpy as np

# Load data
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
 
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
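# (the uniform noise in [0, 0.1) gives the learned soft thresholding inside
#  the network something to suppress)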
 
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)

with tf.Graph().as_default():

    def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                       downsample_strides=2, activation='relu', batch_norm=True,
                       bias=True, weights_init='variance_scaling',
                       bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                       trainable=True, restore=True, reuse=False, scope=None,
                       name="ResidualBlock"):
       
        """Residual shrinkage block(s) with channel-wise soft thresholding."""
   
        residual = incoming
        in_channels = incoming.get_shape().as_list()[-1]
   
        # Variable Scope fix for older TF
        try:
            vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                       reuse=reuse)
        except Exception:
            vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
   
        with vscope as scope:
            name = scope.name #TODO
   
            for i in range(nb_blocks):
   
                identity = residual
   
                if not downsample:
                    downsample_strides = 1
   
                if batch_norm:
                    residual = tflearn.batch_normalization(residual)
                residual = tflearn.activation(residual, activation)
                residual = conv_2d(residual, out_channels, 3,
                                 downsample_strides, 'same', 'linear',
                                 bias, weights_init, bias_init,
                                 regularizer, weight_decay, trainable,
                                 restore)
   
                if batch_norm:
                    residual = tflearn.batch_normalization(residual)
                residual = tflearn.activation(residual, activation)
                residual = conv_2d(residual, out_channels, 3, 1, 'same',
                                 'linear', bias, weights_init,
                                 bias_init, regularizer, weight_decay,
                                 trainable, restore)
               
                # get thresholds and apply thresholding
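                # The threshold is learned per channel: average |features| over the
                # spatial dimensions, feed the result through two small FC layers,
                # and squash with a sigmoid so each threshold lies strictly between
                # 0 and that channel's mean absolute activation.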
                abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True),
                                          axis=1, keep_dims=True)
                scales = tflearn.fully_connected(abs_mean, out_channels, activation='linear',
                                                 regularizer='L2', weight_decay=0.0001,
                                                 weights_init='variance_scaling')
                scales = tflearn.batch_normalization(scales)
                scales = tflearn.activation(scales, 'relu')
                scales = tflearn.fully_connected(scales, out_channels, activation='linear',
                                                 regularizer='L2', weight_decay=0.0001,
                                                 weights_init='variance_scaling')
                scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
                thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
                residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))
               
   
                # Downsampling
                if downsample_strides > 1:
                    identity = tflearn.avg_pool_2d(identity, 1, downsample_strides)
   
                # Zero-pad channels when the dimensionality increases (identity shortcut)
                if in_channels != out_channels:
                    if (out_channels - in_channels) % 2 == 0:
                        ch = (out_channels - in_channels)//2
                        identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]])
                    else:
                        ch = (out_channels - in_channels)//2
                        identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                    in_channels = out_channels
   
                residual = residual + identity
   
        return residual
   
    # Real-time data preprocessing
    img_prep = tflearn.ImagePreprocessing()
    img_prep.add_featurewise_zero_center(per_channel=True)
     
    # Real-time data augmentation
    img_aug = tflearn.ImageAugmentation()
    img_aug.add_random_flip_leftright()
    img_aug.add_random_crop([32, 32], padding=4)
    img_aug.add_random_rotation(max_angle=20.)
   
    # Building A Deep Residual Shrinkage Network
    net = tflearn.input_data(shape=[None, 32, 32, 3],
                             data_preprocessing=img_prep,
                             data_augmentation=img_aug)
    net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
    net = residual_shrinkage_block(net, 1, 8, downsample=False)
    net = residual_shrinkage_block(net, 1, 16, downsample=True)
    net = residual_shrinkage_block(net, 1, 16, downsample=False)
    net = residual_shrinkage_block(net, 1, 32, downsample=True)
    net = residual_shrinkage_block(net, 1, 32, downsample=False)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)
    # Regression
    net = tflearn.fully_connected(net, 10, activation='softmax')
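    # SGD with momentum; the learning rate starts at 0.1 and is divided by 10
    # every 100,000 steps (staircase schedule)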
    mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=100000, staircase=True)
    net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
    # Training
    model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                        max_checkpoints=10, tensorboard_verbose=0,
                        clip_gradients=0.)
   
    model.fit(X, Y, n_epoch=500, snapshot_epoch=False, snapshot_step=500,
              show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
   
    DRSN_training_acc = model.evaluate(X, Y)[0]
    DRSN_validation_acc = model.evaluate(testX, testY)[0]
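    # Report the final accuracies
    print('DRSN training accuracy:   %.4f' % DRSN_training_acc)
    print('DRSN validation accuracy: %.4f' % DRSN_validation_acc)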

Reference: M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096
