import tensorflow as tf
import tensorflow.contrib.layers as tcl
from tensorflow.contrib.framework import arg_scope
import numpy as np
def resBlock(x, num_outputs, kernel_size = 4, stride=1, activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, scope=None):
    """Bottleneck residual block: 1x1 reduce -> kxk conv -> 1x1 expand, plus a shortcut.

    Args:
        x: Input tensor. Its channel count is read from axis 3, so a 4-D
            channels-last layout is assumed -- confirm against callers.
        num_outputs: Number of output channels; must be even because the two
            inner convolutions use half as many channels.
        kernel_size: Kernel size of the middle convolution.
        stride: Stride of the middle convolution and of the projection shortcut.
        activation_fn: Activation applied after the residual sum is normalized.
        normalizer_fn: Normalizer applied to the residual sum.
        scope: Optional variable scope name; 'resBlock' is the default name.

    Returns:
        The tensor produced by activation_fn(normalizer_fn(residual + shortcut)).

    Raises:
        AssertionError: If num_outputs is odd.
    """
    assert num_outputs % 2 == 0, 'num_outputs must be divisible by 2'
    # Floor division keeps the channel count an int; under Python 3 a plain `/`
    # would hand tcl.conv2d a float.
    bottleneck_channels = num_outputs // 2
    with tf.variable_scope(scope, 'resBlock'):
        shortcut = x
        # Project the shortcut whenever the residual branch changes the spatial
        # size (stride != 1) or the channel count, so the two can be added.
        if stride != 1 or x.get_shape()[3] != num_outputs:
            shortcut = tcl.conv2d(shortcut, num_outputs, kernel_size=1, stride=stride,
                                  activation_fn=None, normalizer_fn=None, scope='shortcut')
        # Activation/normalizer of the first two convolutions are not set here;
        # they come from tcl.conv2d's defaults or an enclosing arg_scope.
        x = tcl.conv2d(x, bottleneck_channels, kernel_size=1, stride=1, padding='SAME')
        x = tcl.conv2d(x, bottleneck_channels, kernel_size=kernel_size, stride=stride, padding='SAME')
        # The last convolution is left linear and un-normalized: normalization
        # and activation are applied once, after the shortcut is added.
        x = tcl.conv2d(x, num_outputs, kernel_size=1, stride=1, activation_fn=None, padding='SAME', normalizer_fn=None)
        x += shortcut
        x = normalizer_fn(x)
        x = activation_fn(x)
    return x
class resfcn256(object):
    """Encoder-decoder network: residual-block encoder, transposed-conv decoder.

    Calling an instance builds the graph under a variable scope named after
    the instance and returns a 3-channel, sigmoid-activated tensor.
    """

    def __init__(self, resolution_inp = 256, resolution_op = 256, channel = 3, name = 'resfcn256'):
        """Store the configuration; no graph ops are created here.

        resolution_inp, resolution_op and channel are only recorded on the
        instance -- the visible graph-building code does not read them.
        """
        self.resolution_inp = resolution_inp
        self.resolution_op = resolution_op
        self.channel = channel
        self.name = name

    def __call__(self, x, is_training = True):
        """Build the network on top of x and return the output tensor.

        Args:
            x: Input image batch (spatial sizes in the comments below assume
                a 256 x 256 x 3 input -- confirm against callers).
            is_training: Forwarded to every tcl.batch_norm layer.

        Returns:
            Output of the final sigmoid-activated transposed convolution.
        """
        base = 16
        # (num_outputs, stride) of each ReLU transposed convolution, in
        # creation order; a stride of 2 doubles the spatial size.
        decoder_plan = (
            (base * 32, 1),                                   # 8 x 8 x 512
            (base * 16, 2), (base * 16, 1), (base * 16, 1),   # 16 x 16 x 256
            (base * 8, 2), (base * 8, 1), (base * 8, 1),      # 32 x 32 x 128
            (base * 4, 2), (base * 4, 1), (base * 4, 1),      # 64 x 64 x 64
            (base * 2, 2), (base * 2, 1),                     # 128 x 128 x 32
            (base, 2), (base, 1),                             # 256 x 256 x 16
            (3, 1), (3, 1),                                   # 256 x 256 x 3
        )
        with tf.variable_scope(self.name), \
                arg_scope([tcl.batch_norm], is_training=is_training, scale=True), \
                arg_scope([tcl.conv2d, tcl.conv2d_transpose],
                          activation_fn=tf.nn.relu,
                          normalizer_fn=tcl.batch_norm,
                          biases_initializer=None,
                          padding='SAME',
                          weights_regularizer=tcl.l2_regularizer(0.05)):
            # Stem convolution: 256 x 256 x 16
            net = tcl.conv2d(x, num_outputs=base, kernel_size=4, stride=1)

            # Encoder: for each width, one stride-2 block (halves the spatial
            # size) followed by one stride-1 block; ends at 8 x 8 x 512.
            for multiplier in (2, 4, 8, 16, 32):
                for block_stride in (2, 1):
                    net = resBlock(net, num_outputs=base * multiplier,
                                   kernel_size=4, stride=block_stride)

            # Decoder: transposed convolutions back up to 256 x 256.
            for channels, up_stride in decoder_plan:
                net = tcl.conv2d_transpose(net, channels, 4, stride=up_stride)

            # Output layer swaps the scoped ReLU for a sigmoid.
            return tcl.conv2d_transpose(net, 3, 4, stride=1, activation_fn=tf.nn.sigmoid)

    @property
    def vars(self):
        """Global variables whose name contains this network's name."""
        matching = []
        for variable in tf.global_variables():
            if self.name in variable.name:
                matching.append(variable)
        return matching
def export_graph(checkpoint_dir, model_name):
    '''
    Freeze a trained resfcn256 checkpoint into a binary GraphDef (.pb) file.

    checkpoint_dir: path handed to Saver.restore (the checkpoint to load)
    model_name: file name of the .pb, written into the current directory
    '''
    graph = tf.Graph()
    with graph.as_default():
        # Input placeholder
        x = tf.placeholder(tf.float32, shape=[None, 256, 256, 3], name='input')
        network = resfcn256(256, 256)
        # Build in inference mode: the exported graph is only used for
        # prediction, so batch norm must use its stored moving statistics
        # instead of per-batch statistics.
        x_op = network(x, is_training=False)
        output = tf.identity(x_op, name='output_label')
        restore_saver = tf.train.Saver()
    with tf.Session(graph=graph) as sess:
        # Initialize variables (the restore below then overwrites them)
        sess.run(tf.global_variables_initializer())
        # Load the trained weights
        restore_saver.restore(sess, checkpoint_dir)
        # Replace every variable by a constant holding its restored value,
        # keeping only what is needed to compute the output node.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), [output.op.name])
        # Write the frozen graph as a binary .pb file
        tf.train.write_graph(output_graph_def, './', model_name, as_text=False)
# Call the function to generate the .pb file
export_graph(checkpoint_dir='256_256_resfcn256_weight', model_name='model.pb')
def inference(image, model_path='model/model.pb'):
    """Run a frozen .pb model on a single image and return its output.

    Args:
        image: Array indexed with three axes; it is divided by 255 and given a
            leading batch axis before being fed to the "input:0" tensor.
            Presumably an H x W x 3 image with values in 0..255 -- confirm
            against the exported model's input shape.
        model_path: Path of the frozen GraphDef file to load.

    Returns:
        The first (and only) element of the batch produced by "output_label:0".
    """
    # Read the serialized graph; the context manager closes the file handle.
    with tf.gfile.GFile(model_path, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())

    # Import into a private graph so repeated calls neither collide with nor
    # pile nodes into the process-wide default graph.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    input_tensor = graph.get_tensor_by_name("input:0")
    output_tensor = graph.get_tensor_by_name("output_label:0")

    # Scale by 1/255 and add the batch axis.
    batch = (image / 255.)[np.newaxis, :, :, :]

    # The session is closed on exit instead of leaking its resources.
    with tf.Session(graph=graph) as sess:
        out = sess.run(output_tensor, feed_dict={input_tensor: batch})
    return out[0]
export
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 运行时加载(commonJS) let {state,exists,readFile} =require('fs'...
- 参考地址 https://www.jianshu.com/p/be2d4eab3878 module.export...
- 一 .module.export暴露属性 module.export 是给要当前模块添加属性 , 其中的modul...
- 在JavaScript ES6中,export与export default均可用于导出常量、函数、文件、模块等,...
- import export 在js文件里面可以存在多个 export default 在文件里面只能存在一个