【转载】TF-Slim简介

原文链接:blog.csdn.net/zchang81/article/details/77770880

slim作为一种轻量级的tensorflow库,使得模型的构建,训练,测试都变得更加简单。

使用方法:

import tensorflow.contrib.slimas slim

组成部分:

arg_scope: 使得用户可以在同一个arg_scope中使用默认的参数

data,evaluation,layers,learning,losses,metrics,nets,queues,regularizers,variables

定义模型

在slim中,组合使用variables, layers和scopes可以简洁的定义模型。

(1)variables: 定义于variables.py。生成一个weight变量, 用truncated normal初始化它, 并使用l2正则化,并将其放置于CPU上, 只需下面的代码即可:

weights= slim.variable('weights',shape=[10,10,3 ,3],

initializer=tf.truncated_normal_initializer(stddev=0.1),

regularizer=slim.l2_regularizer(0.05),

device='/CPU:0')

原生tensorflow包含两类变量:普通变量和局部变量。大部分变量都是普通变量,它们一旦生成就可以通过使用saver存入硬盘,局部变量只在session中存在,不会保存。

slim进一步的区分了变量类型,定义了model variables,这种变量代表了模型的参数。模型变量通过训练活着微调而得到学习,或者在评测或前向传播中可以从ckpt文件中载入。

非模型参数在实际前向传播中不需要的参数,比如global_step。同样的,移动平均反应了模型参数,但它本身不是模型参数。例子见下:

# Model Variables

weights= slim.model_variable('weights',shape=[10,10,3 ,3],initializer=tf.truncated_normal_initializer(stddev=0.1),regularizer=slim.l2_regularizer(0.05),device='/CPU:0')

model_variables= slim.get_model_variables()

# model_variables包含weights# Regular variables

my_var= slim.variable('my_var',shape=[20,1],initializer=tf.zeros_initializer())

regular_variables_and_model_variables= slim.get_variables()

#get_variables()得到模型参数和常规参数

当我们通过slim的layers或着直接使用slim.model_variable创建变量时,tf会将此变量加入tf.GraphKeys.MODEL_VARIABLES这个集合中,当你需要构建自己的变量时,可以通过以下代码

将其加入模型参数。

my_model_variable= CreateViaCustomCode()# Letting TF-Slim know about the additional variable.

slim.add_model_variable(my_model_variable)

(2)layers:抽象并封装了常用的层,并且提供了repeat和stack操作,使得定义网络更加方便。

下面的代码利用repeat实现了三个卷积层的堆叠

net= slim.repeat(net,3, slim.conv2d,256, [3,3],scope='conv3')

repeat不仅只实现了相同操作相同参数的重复,它还将scope进行了展开,例子中的scope被展开为 'conv3/conv3_1', 'conv3/conv3_2' and 'conv3/conv3_3'。

slim.stack操作使得我们可以重复的讲同一个操作以不同参数一次作用于一些层,这些层的输入输出时串联起来的。比如:

slim.stack(x, slim.fully_connected, [32,64,128],scope='fc')

slim.stack(x, slim.conv2d, [(32, [3,3]), (32, [1,1]), (64, [3,3]), (64, [1,1])],scope='core')

(3)scopes:除了tensorflow中的name_scope和variable_scope, tf.slim新增了arg_scope操作,这一操作符可以让定义在这一scope中的操作共享参数,即如不制定参数的话,则使用默认参数。且参数可以被局部覆盖。使得代码更加简洁,如下:

with slim.arg_scope([slim.conv2d],padding='SAME',weights_initializer=tf.truncated_normal_initializer(stddev=0.01)weights_regularizer=slim.l2_regularizer(0.0005)):

net= slim.conv2d(inputs,64, [11,11],scope='conv1')

net= slim.conv2d(net,128, [11,11],padding='VALID',scope='conv2')

net= slim.conv2d(net,256, [11,11],scope='conv3')

而且,我们也可以嵌套多个arg_scope在其中使用多个操作。

训练模型Tensorflow的模型训练需要模型,损失函数,梯度计算,以及根据loss的梯度迭代更新参数。

(1)losses使用现有的loss:

loss= slim.losses.softmax_cross_entropy(predictions, labels)

对于多任务学习的loss,可以使用:

# Define the loss functions and get the total loss.

classification_loss= slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)

sum_of_squares_loss= slim.losses.sum_of_squares(depth_predictions, depth_labels)# The following two lines have the same effect:

total_loss= classification_loss+ sum_of_squares_loss

total_loss= slim.losses.get_total_loss(add_regularization_losses=False)

如果使用了自己定义的loss,而又想使用slim的loss管理机制,可以使用:

pose_loss= MyCustomLossFunction(pose_predictions, pose_labels)

slim.losses.add_loss(pose_loss)

total_loss= slim.losses.get_total_loss()

#total_loss中包涵了pose_loss

(2) 训练循环

slim在learning.py中提供了一个简单而有用的训练模型的工具。我们只需调用

slim.learning.create_train_op和slim.learning.train就可以完成优化过程。

g= tf.Graph()# Create the model and specify the losses......

total_loss= slim.losses.get_total_loss()

optimizer= tf.train.GradientDescentOptimizer(learning_rate)# create_train_op ensures that each time we ask for the loss, the update_ops# are run and the gradients being computed are applied too.

train_op= slim.learning.create_train_op(total_loss, optimizer)

logdir=...# Where checkpoints are stored.

slim.learning.train(

train_op,

logdir,number_of_steps=1000,#迭代次数save_summaries_secs=300,#存summary间隔秒数save_interval_secs=600)#存模型建个秒数

(3)训练的例子:

import tensorflowas tfslim= tf.contrib.slimvgg= tf.contrib.slim.nets.vgg...train_log_dir=...ifnot tf.gfile.Exists(train_log_dir):  tf.gfile.MakeDirs(train_log_dir)with tf.Graph().as_default():# Set up the data loading:  images, labels=...# Define the model:  predictions= vgg.vgg16(images,is_training=True)# Specify the loss function:  slim.losses.softmax_cross_entropy(predictions, labels)  total_loss= slim.losses.get_total_loss()  tf.summary.scalar('losses/total_loss', total_loss)# Specify the optimization scheme:  optimizer= tf.train.GradientDescentOptimizer(learning_rate=.001)# create_train_op that ensures that when we evaluate it to get the loss,# the update_ops are done and the gradient updates are computed.  train_tensor= slim.learning.create_train_op(total_loss, optimizer)# Actually runs training.

slim.learning.train(train_tensor, train_log_dir)

根据已有模型进行微调

(1)利用tf.train.Saver()从checkpoint恢复模型

# Create some variables.v1= tf.Variable(...,name="v1")v2= tf.Variable(...,name="v2")...# Add ops to restore all the variables.restorer= tf.train.Saver()# Add ops to restore some variables.restorer= tf.train.Saver([v1, v2])# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session()as sess:# Restore variables from disk.  restorer.restore(sess,"/tmp/model.ckpt")print("Model restored.")# Do some work with the model...

(2)部分恢复模型参数

# Create some variables.v1= slim.variable(name="v1",...)v2= slim.variable(name="nested/v2",...)...# Get list of variables to restore (which contains only 'v2'). These are all# equivalent methods:variables_to_restore= slim.get_variables_by_name("v2")# orvariables_to_restore= slim.get_variables_by_suffix("2")# orvariables_to_restore= slim.get_variables(scope="nested")# orvariables_to_restore= slim.get_variables_to_restore(include=["nested"])# orvariables_to_restore= slim.get_variables_to_restore(exclude=["v1"])# Create the saver which will be used to restore the variables.restorer= tf.train.Saver(variables_to_restore)with tf.Session()as sess:# Restore variables from disk.  restorer.restore(sess,"/tmp/model.ckpt")print("Model restored.")# Do some work with the model...

(3)当图的变量名与checkpoint中的变量名不同时,恢复模型参数当从checkpoint文件中恢复变量时,Saver在checkpoint文件中定位到变量名,并且把它们映射到当前图中的变量中。之前的例子中,我们创建了Saver,并为其提供了变量列表作为参数。这时,在checkpoint文件中定位的变量名,是隐含地从每个作为参数给出的变量的var.op.name而获得的。这一方式在图与checkpoint文件中变量名字相同时,可以很好的工作。而当名字不同时,必须给Saver提供一个将checkpoint文件中的变量名映射到图中的每个变量的字典,例子见下:

# Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights'defname_in_checkpoint(var):return'vgg16/'+ var.op.name# Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'defname_in_checkpoint(var):if"weights"in var.op.name:return var.op.name.replace("weights","params1")if"bias"in var.op.name:return var.op.name.replace("bias","params2")

variables_to_restore= slim.get_model_variables()

variables_to_restore= {name_in_checkpoint(var):varfor varin variables_to_restore}

restorer= tf.train.Saver(variables_to_restore)with tf.Session()as sess:# Restore variables from disk.

restorer.restore(sess,"/tmp/model.ckpt")

(4)在一个不同的任务上对网络进行微调

比如我们要将1000类的imagenet分类任务应用于20类的Pascal VOC分类任务中,我们只导入部分层,见下例:

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,445评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,889评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,047评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,760评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,745评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,638评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,011评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,669评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,923评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,655评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,740评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,406评论 4 320
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,995评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,961评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,023评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,483评论 2 342

推荐阅读更多精彩内容