Tensorflow-花分类-图像再训练-part-2-整理翻译

我们继续前一篇文章,来逐步完成图像再训练的整个案例。

在上一篇我们实现了bottleneck文件的创建,如果文件已经创建了,那么我们可以直接使用,否则我们就创建它,我们继续...


取得或创建瓶颈文件的函数get_or_create_bottleneck

这个函数其实只是读取存储的bottleneck文件数据,如果没有的话就立即创建。

下面是新增和修改的代码,可结合上一篇的代码运行测试:

#取得或创建瓶颈文件数据,如果没有就创建它。返回由bottleneck层产生的图片的numpy array数组
def get_or_create_bottleneck(sess, label_name, category, index, jpeg_data_tensor,
                             decoded_image_tensor, resized_input_tensor,
                             bottleneck_tensor):
    label_lists = image_lists[label_name]
    sub_dir = label_lists['dir'] #获取花分类名如'daisy'
    sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
    ensure_dir_exists(sub_dir_path) #确保路径文件夹存在    
    bottleneck_path = get_bottleneck_path(label_name, category,index)
    
    if not os.path.exists(bottleneck_path): #如果文件不存在就创建文件
        create_bottleneck_file(sess, label_name, category,index,jpeg_data_tensor,
                               decoded_image_tensor, resized_input_tensor,
                               bottleneck_tensor)    
    
    with open(bottleneck_path, 'r') as bottleneck_file: #读取瓶颈文件
        bottleneck_string = bottleneck_file.read()
    
    did_hit_error = False #遇到错误
    try:
        bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
    except ValueError:
        tf.logging.warning('在重建瓶颈文件时遇到非法浮点数')
        did_hit_error = True    
    
    if did_hit_error: #如果出错就重建瓶颈文件
        create_bottleneck_file(sess, label_name, category,index,jpeg_data_tensor,
                               decoded_image_tensor, resized_input_tensor,
                               bottleneck_tensor)        
        with open(bottleneck_path, 'r') as bottleneck_file:
            bottleneck_string = bottleneck_file.read() 
        bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
        
    return bottleneck_values

#入口函数
def main(_):
    module_spec = hub.load_module_spec(HUB_MODULE)
    graph, bottleneck_tensor, resized_input_tensor, wants_quantization = (
        create_module_graph(module_spec))
    
    with tf.Session(graph=graph) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)
        get_or_create_bottleneck(sess,'daisy','training', 65, jpeg_data_tensor,
                               decoded_image_tensor, resized_input_tensor,
                               bottleneck_tensor)

如果我们打印bottleneck_values就会得到一长串数字组成的数组。

[0.070175596, 0.2166954, 0.0072527127, 0.04728513, 1.1940469, 0.7925658, 2.029932, ...]

确保瓶颈文件都被缓存的函数cache_bottlenecks

因为在训练过程中,对同一个图片会反复多次读取(不对图像进行扭曲处理的话),如果我们对图片bottleneck缓存就能大大提高效率。

我们将用这个函数检查所有图片进行计算并保存。这个函数其实也只是循环调用前面的get_or_create_bottleneck函数。

下面是增加和修改的代码,可以运行测试:

#确保所有的training、testing、validation要用的bottleneck文件都已经被缓存
def cache_bottlenecks(sess,jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor):
    how_many_bottlenecks = 0
    ensure_dir_exists(bottleneck_dir)
    for label_name, label_lists in image_lists.items():
        for category in ['training', 'testing', 'validation']:
            category_list = label_lists[category] #针对每一个分类,比如daisy
            for index, unused_base_name in enumerate(category_list): #创建索引
                get_or_create_bottleneck(
                    sess, label_name, category,index,
                    jpeg_data_tensor, decoded_image_tensor,
                    resized_input_tensor, bottleneck_tensor)

                how_many_bottlenecks += 1
                if how_many_bottlenecks % 100 == 0:
                    tf.logging.info(str(how_many_bottlenecks) + '瓶颈文件被创建.') #每100张输出一次提示

#入口函数
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    module_spec = hub.load_module_spec(HUB_MODULE)
    graph, bottleneck_tensor, resized_input_tensor, wants_quantization = (
        create_module_graph(module_spec))
    
    with tf.Session(graph=graph) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)
        
        cache_bottlenecks(sess, jpeg_data_tensor,decoded_image_tensor, 
                          resized_input_tensor,bottleneck_tensor)

运行过程中会隔一会(处理100张)输出一行提示。整个过程可能需要十几分钟或更久,全部完成后会在/bottlenecks/文件夹下增加每个花类别的文件夹并且里面包含了很多很多txt文件。


随机获取一批bottleneck文件数据get_random_cached_bottlenecks

从所有的分类图片中随机选取一些bottleneck数据,主要是使用了get_or_create_bottleneck函数。

以下是新增和修改的代码,可以结合前面的代码运行测试:

#随机获取所有种类中随机bottleneck数据列表、对应的label_index和图片文件路径列表,
#how_many数量小于等于0时候获取全部
def get_random_cached_bottlenecks(sess, how_many, category, 
                                  jpeg_data_tensor,decoded_image_tensor,
                                  resized_input_tensor,bottleneck_tensor):
    class_count = len(image_lists.keys()) #有多少种花分类
    bottlenecks = []
    ground_truths = [] #label_index花分类索引号
    filenames = [] #图片文件路径列表
    
    if how_many >= 0:
        for unused_i in range(how_many):
            label_index = random.randrange(class_count) #随机一种花如daisy
            label_name = list(image_lists.keys())[label_index] #daisy
            image_index = random.randrange(MAX_IPC + 1) #每种类最大数量,如果超过后面会自动取余数
            image_name = get_image_path(label_name,category,image_index) #图片路径
            bottleneck = get_or_create_bottleneck( #读取bottleneck文件数据
              sess,  label_name, category, image_index,
              jpeg_data_tensor, decoded_image_tensor,
              resized_input_tensor, bottleneck_tensor)
            bottlenecks.append(bottleneck)
            ground_truths.append(label_index)
            filenames.append(image_name)
    else:
        for label_index, label_name in enumerate(image_lists.keys()):
            for image_index, image_name in enumerate(image_lists[label_name][category]): #建立某分类下图片索引
                image_name = get_image_path(label_name, category,image_index)
                bottleneck = get_or_create_bottleneck(
                    sess, label_name, category,image_index,  
                    jpeg_data_tensor, decoded_image_tensor,
                    resized_input_tensor, bottleneck_tensor)
                bottlenecks.append(bottleneck)
                ground_truths.append(label_index)
                filenames.append(image_name)
    
    return bottlenecks, ground_truths, filenames                   
                    
#入口函数
def main(_):
    tf.logging.set_verbosity(tf.logging.WARN)
    module_spec = hub.load_module_spec(HUB_MODULE)
    graph, bottleneck_tensor, resized_input_tensor, wants_quantization = (
        create_module_graph(module_spec))
    
    with tf.Session(graph=graph) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)
        
        result=get_random_cached_bottlenecks(sess, 5, 'training',
                                      jpeg_data_tensor,decoded_image_tensor,
                                      resized_input_tensor,bottleneck_tensor)

打印出result结果类似:

(
    [[0.0, 1.8930491, 0.0, 0.0,... 0.1771909, 0.045966692],...x5],#5个bottleneck数据
    [0, 3, 3, 0, 2], #对应的5个标签编号
    ['/Users/zhyuzh/desktop/MyProjects/.../flower_photos/daisy/16161045294_70c76ce846_n.jpg', ...x5] #对应的5张图片路径
)

随机获取变形的瓶颈数据get_random_distorted_bottlenecks

如果我们使用变形的图片进行训练,比如裁剪、放缩、翻转的图片,我们需要针对每个图片重新计算整个模型,所以我们不能使用原来缓存的图片bottleneck数据,我们需要使用另外的变形计算图来运行得到新的变形bottleneck数据,然后再把它投入到整个计算图进行训练。

首先我们回顾run_bottleneck_on_image函数:

def run_bottleneck_on_image(sess,image_data, image_data_tensor,
                            decoded_image_tensor, resized_input_tensor,
                            bottleneck_tensor):
    resized_input_values = sess.run(decoded_image_tensor, #解码JPEG,调整大小,放缩像素值
                                    {image_data_tensor: image_data}) #feed_dict
    bottleneck_values = sess.run(bottleneck_tensor, #使用识别网络运行它
                                 {resized_input_tensor: resized_input_values}) #feed_dict
    bottleneck_values = np.squeeze(bottleneck_values) #去掉冗余的数组嵌套,简化形状
    return bottleneck_values

这里我们看到参数传递进来的image_data图像数据,然后被喂食feed_dict到graph的decoded_image_tensor子图中运行得到resized_input_values。
然后再把结果喂食到graph的子图bottleneck_tensor中得到进一步结果。

这里需要新增的get_random_distorted_bottlenecks方法不使用现成的图像数据参数,而是从随机文件中读取,然后使用distorted_image子图对图像进行处理,然后同样将结果喂食到子图distorted_image_tensor中,得到进一步结果。

整体上这两个函数的实现思路是一样的。下面是新增的代码,请勿运行,稍后和后面的函数一起测试:

#随机获取变形的瓶颈数据,返回bottlenecks数组和对应的label_index数组
def get_random_distorted_bottlenecks(sess, how_many, category,
                                     input_jpeg_tensor,distorted_image_tensor, 
                                     resized_input_tensor, bottleneck_tensor):
    class_count = len(image_lists.keys()) #有几种花分类
    bottlenecks = [] #变形后的瓶颈数据
    ground_truths = [] #label_index标签编号
    for unused_i in range(how_many):
        label_index = random.randrange(class_count) #随机一个花分类
        label_name = list(image_lists.keys())[label_index] #daisy
        image_index = random.randrange(MAX_IPC + 1) #随机一张图
        image_path = get_image_path(label_name, category,image_index)
        
        if not tf.gfile.Exists(image_path):
            tf.logging.fatal('文件不存在 %s', image_path)
            
        #下面两句可以参考run_bottleneck_on_image函数
        jpeg_data = tf.gfile.FastGFile(image_path, 'rb').read()  #没有参数传递jpeg_data进来,要重新读取文件       
        distorted_image_data = sess.run(distorted_image_tensor,
                                        {input_jpeg_tensor: jpeg_data}) #feed_dict
        bottleneck_values = sess.run(bottleneck_tensor,
                                     {resized_input_tensor: distorted_image_data}) #feed_dict
        bottleneck_values = np.squeeze(bottleneck_values)
        bottlenecks.append(bottleneck_values)
        ground_truths.append(label_index)
        
    return bottlenecks, ground_truths

生成变形图片操作ops的函数add_input_distortions

在训练的过程中我们对图片进行一些变形(裁切、放缩、翻转或调整亮度),可以利用有限数量的图片模拟更多的真实情况,进而有效改进模型。

这个函数里我们构建了新的graph用来对图片数据新建调整变换:

#生成两个变形操作ops的函数input_jpeg_tensor,distorted_image_tensor
#注意,只是生成一个grah并返回需要运行这个graph的两个feed_dict入口
def add_input_distortions(module_spec,flip_left_right, 
                          random_crop, random_scale,random_brightness):
    input_height, input_width = hub.get_expected_image_size(module_spec) #获取已有模型中的宽高要求
    input_depth = hub.get_num_image_channels(module_spec) #获取模型中图片通道深度数
    jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput') #feed_dict输入口
    decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth) #读取图片数据
    decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,tf.float32) #数据类型转换
    decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0) #升维
    
    #对图片数据进行裁切和放缩
    margin_scale = 1.0 + (random_crop / 100.0) #random_crop参数范围0~100
    resize_scale = 1.0 + (random_scale / 100.0) #random_scale参数范围0~100
    margin_scale_value = tf.constant(margin_scale)  #转为张量  
    resize_scale_value = tf.random_uniform(shape=[],minval=1.0,maxval=resize_scale) #转为张量
    scale_value = tf.multiply(margin_scale_value, resize_scale_value)
    precrop_width = tf.multiply(scale_value, input_width)
    precrop_height = tf.multiply(scale_value, input_height)
    precrop_shape = tf.stack([precrop_height, precrop_width])
    precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
    precropped_image = tf.image.resize_bilinear(decoded_image_4d,precrop_shape_as_int)
    precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
    cropped_image = tf.random_crop(precropped_image_3d,[input_height, input_width, input_depth])
    
    #对图片进行翻转
    if flip_left_right:
        flipped_image = tf.image.random_flip_left_right(cropped_image)
    else:
        flipped_image = cropped_image
        
    #调整图片亮度
    brightness_min = 1.0 - (random_brightness / 100.0) #random_brightness参数范围0~100
    brightness_max = 1.0 + (random_brightness / 100.0)
    brightness_value = tf.random_uniform(shape=[],minval=brightness_min,maxval=brightness_max)
    brightened_image = tf.multiply(flipped_image, brightness_value)
    distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
    
    return jpeg_data, distort_result
                    
#入口函数
def main(_):
    tf.logging.set_verbosity(tf.logging.WARN)
    module_spec = hub.load_module_spec(HUB_MODULE)
    graph, bottleneck_tensor, resized_input_tensor, wants_quantization = (
        create_module_graph(module_spec))
    
    with tf.Session(graph=graph) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        input_jpeg_tensor,distorted_image_tensor = add_input_distortions(module_spec,True, 40, 60, 36)
        
        result=get_random_distorted_bottlenecks(sess, 5, 'training',
                                                input_jpeg_tensor,distorted_image_tensor,
                                                resized_input_tensor, bottleneck_tensor)

打印result得到类似下面的输出:

(
    [
        array([0.28729123, 1.5005591 ,  ...,  0.7306035], dtype=float32), 
        array([1.1294909 , 4.1503158 ,  ...,  0.02221665], dtype=float32), 
        array([0.8439883, 2.8368084 ,  ..., 0.04721354], dtype=float32), 
        array([0.00706462, 0.83330685,  ..., 0.99900544], dtype=float32), 
        array([0.00402025, 3.205397  , ..., 0.99900544,], dtype=float32)
    ], #5个被变形后的图片数据
    [2, 1, 3, 0, 0] #对应的5个花类型索引号
)

小结

本篇主要添加了以下几个函数用来深入处理bottleneck文件:

  • 取得或创建瓶颈文件的函数get_or_create_bottleneck
  • 确保瓶颈文件都被缓存的函数cache_bottlenecks
  • 随机获取一批bottleneck文件数据get_random_cached_bottlenecks
  • 随机获取一批变形的瓶颈数据get_random_distorted_bottlenecks
  • 生成变形图片操作ops的函数add_input_distortions

探索人工智能的新边界

如果您发现文章错误,请不吝留言指正;
如果您觉得有用,请点喜欢;
如果您觉得很有用,感谢转发~


END

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,734评论 6 505
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,931评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,133评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,532评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,585评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,462评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,262评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,153评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,587评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,792评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,919评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,635评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,237评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,855评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,983评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,048评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,864评论 2 354

推荐阅读更多精彩内容

  • 图像识别往往包含数以百万计的参数,从头训练需要大量打好标签的图片,还需要大量的计算力(往往数百小时的GPU时间)。...
    zhyuzh3d阅读 3,317评论 4 11
  • 姓名:尤学强 学号:17101223374 转载自:http://mp.weixin.qq.com/s/C6cID...
    51fb659a6d6f阅读 3,557评论 0 16
  • 简单粗暴地说,小图标,用png储存最好。 png可以储存透明,完爆gif的地方在于失真小,没锯齿;劣势是不支持动画...
    SpursGo阅读 2,206评论 0 0
  • 又是一年端午节,昨天特意带着孩子去体验一下这个节日。 吃完饭,九点多出发,走着到了公园,已经远远的听见鼓声,远远的...
    飞飞来啦阅读 193评论 0 0
  • 这个故事不过是青春里的一段往事,只是在某天悉数想起。 故事里有我曾经爱过的男孩,他的身影消失在2008年的最后一声...
    消灭神经病阅读 487评论 0 1