TensorFlow 训练多任务多标签模型

        在学习了使用 TensorFlow 的 CNN 进行图像分类之后,现在对这些方法做一个简单的拓展,即来处理多任务多标签的情形。为了便于说明,我们假设现在要对 0-9 这 10 个数字 和 A-Z (排除 I、O) 这 24 个字母进行识别,所有的数据都使用 captcha 生成(读过 TensorFlow 训练 CNN 分类器 这篇文章的读者应该不陌生了)。以下的代码(命名为 generate_train_data.py)使用 captcha 生成了 100000 万张 28 x 28 的图像,每张图像都是带有大量噪声的一个字符(所有字符见下面代码中的 alphabets 列表,所有的图像保存在文件夹 ./datasets/images 中,每张图像命名为 image图像序号_类标号.jpg,其中的类标号为该字符在列表 alphabets 中的下标)。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018

@author: shirhe-lyh
"""

import cv2
import numpy as np

from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image
    
    
if __name__ == '__main__':
    output_dir = './datasets/images/'
    alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
                 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 
                 'U', 'V', 'W', 'X', 'Y', 'Z']
    for i in range(100000):
        label = np.random.randint(0, 34)
        image = generate_captcha(alphabets[label])
        image_name = 'image{}_{}.jpg'.format(i+1, label)
        output_path = output_dir + image_name
        cv2.imwrite(output_path, image)

我们的目的是训练一个简单的 CNN 模型将对这些图像进行分类,由于这个问题很简单,直接训练一个 34 类的分类器就达成目标了。但类别数越大,训练就越困难,因此我们采取另一种分化的策略,将这个 34 类的问题分为两个子问题,分别是:1.只识别数字;2.只识别字母。之所以可以这么分,是因为 数字字母 的差别很大,完全可以认为它们属于两种不同的范畴,从而可以看成独立的分类任务来处理。这样我们现在的问题是:怎样同时识别 10 个数字和 24 个字母?这是一个多任务多标签问题:我们要处理识别数字和识别字母这两个任务,其中每个任务都是涉及多个标签(分别是 10 个标签和 24 个标签)

        虽然这篇文章举例的这个问题非常简单,但这个方法(再加上预训练模型技巧)可以用于更加复杂的问题,比如 阿里的 FashionAI 服饰属性识别全球挑战赛,感兴趣的朋友可以用 ResNet-50 预训练模型去微调一个 8 任务模型。

        本文的所有代码见 github:multi_task_test,欢迎访问交流并反馈问题!

一、多分支 CNN 模型定义

        虽然我们要处理的是两个独立的任务,但我们希望这两个任务共用大部分的神经网络层,这样既可以节省计算量,一般来说,也可以提升准确率。因此,我们将要定义的神经网络结构设计为(所有共用的层在文章 TensorFlow-slim 训练 CNN 分类模型 中用来识别 0-9 这 10 个数字):

两分支输出的 CNN 用于识别数字和字母两个任务

当获取了一张图像(数字或字母)之后,将它送入第一个卷积层(conv1)、第二个卷积层(conv2)、······,直到第二个全连接层(fc2),到此为止,这些层都是两个任务共用的,它们的作用是用来提取图像特征。然后,针对两个不同的任务,将网络分为两个分支,一个用于输出该图像是各个数字的概率(digits_output),另一个用于输出该图像是各个字母的概率(letters_output)。网络的具体定义如下(网络各层的名字可能和上图不一致):

def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.
        
    Outputs of this function can be passed to loss or postprocess functions.
        
   Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.
            
    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    net = preprocessed_inputs
    net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
    net = slim.flatten(net, scope='flatten')
    net = slim.dropout(net, keep_prob=0.5,
                       is_training=self._is_training)
    net = slim.fully_connected(net, 512, scope='fc1')
    net = slim.fully_connected(net, 512, scope='fc2')
    prediction_dict = {}
    for class_name, num_classes in self.num_classes_dict.items():
        logits = slim.fully_connected(net, num_outputs=num_classes, 
                                      activation_fn=None, 
                                      scope='Predict/' + class_name)
        prediction_dict[class_name] = logits
    return prediction_dict

        从以上代码可以看到,多任务多标签任务的 CNN 定义也非常简单,只需要引入一个 for 循环即可。接下来,要定义损失函数和准确率函数

        在生成图像的时候,图片名字命名的模式是 image图像序号_类标号.jpg,比如,假设第 1 张图像是字母 G,那么它的类标号是 16 = 10 + 7 - 1,因此它的名字是 image1_16.jpg。但这个类标号 16 是基于所有 34 个类来说的,实际上,如果只限于字母来说,它的类标号应该是 6。之所以对数字和字母使用统一的类标号,其实是为了便于定义损失和准确率函数。原因在于:对字母 G,因为我们现在是独立处理数字和字母两个分支任务,因此 G 应该只对分类字母的分支贡献损失,而不应当对分类数字的分支产生损失。如果统一对数字和字母分配类标号,那么 G 的类标号 16 的独热(one-hot)编码是:

0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

其中的 - 是为了便于看清两个任务的分界线,实际请忽略。此时,在计算损失时,将这个独热编码一分为二:

0 0 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

前一部分对应于 G 在分类 0-9 这 10 个数字的任务内的(非严格独热)编码,因为全部为 0,因此在计算分类交叉熵的时候损失为 0,这是我们期望的;后一部分恰好是 G 在分类 A-Z(排除 I、O)这 24 个字母的任务内的独热编码,正好用于计算分类交叉熵,也是我们期望的,可见统一分配类标号在计算损失时是非常方便的。了解了这一点之后,损失函数的定义如下:

 def loss(self, prediction_dict, groundtruth_lists):
    """Compute scalar loss tensors with respect to provided groundtruth.
        
    Args:
        prediction_dict: A dictionary holding prediction tensors.
        groundtruth_lists: A list of tensors holding groundtruth
            information, with one entry for each task.
                
    Returns:
        A dictionary mapping strings (loss names) to scalar tensors
            representing loss values.
    """
    onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
    for class_name in self.num_classes_dict:
        weights = tf.cast(tf.greater(
            tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
            dtype=tf.float32)
        slim.losses.softmax_cross_entropy(
            logits=prediction_dict[class_name], 
            onehot_labels=onehot_labels_dict[class_name],
            weights=weights,
            scope='Loss/' + class_name)
    loss = slim.losses.get_total_loss()
    loss_dict = {'loss': loss}
    return loss_dict
    
def _onehot_groundtruth_dict(self, groundtruth_lists):
    """Transform groundtruth lables to one-hot formats.
        
    Args:
        groundtruth_lists: A dict of tensors holding groundtruth
            information, with one entry for task.
                
    Returns:
        onehot_labels_dict: A dictionary mapping strings (class names) 
            to one-hot lable tensors.
    """
    one_hot = tf.one_hot(
        groundtruth_lists, depth=sum(self.num_classes_dict.values()))
    onehot_labels_dict = {}
    start_index = 0
    for class_name in self._class_order:
        onehot_labels_dict[class_name] = tf.slice(
            one_hot, [0, start_index], 
            [-1, self.num_classes_dict[class_name]])
        start_index += self.num_classes_dict[class_name]
    return onehot_labels_dict

其中,函数 _onehot_groundtruth_dict 用于将统一分配的类标号对应的独热编码分为数字和字母这两个任务对应的两个独热编码,之后的 loss 函数就可以用来计算正常的分类交叉熵损失。为了确保全 0 的独热编码对应 0 的损失,定义了 weights 这一个变量,它的作用是:当编码为全 0 时,该样本对应的损失权重为 0,因此贡献的损失为 0,即不属于这个分类任务的样本对这个分类任务的损失贡献为 0(虽然理论上全 0 的独热编码对应的分类交叉熵为 0,但为了确保这点而不出现意外,weights 是非常必要的)。

        至于,准确率函数的定义则更简单,想法如下:当一张图像经过神经网络预测后,我们得到两个分支任务的概率输出,我们不关心它来源于哪个任务,因为这不影响准确率的计算;分别对两个任务的概率输出取 tf.argmax 得到在每个任务内的预测类标号,然后对这两个预测的类标号再计算它在对应任务内的独热编码,把这两个独热编码与上面计算损失时切割得到的两个独热编码分别按对应元素求和,如果求和结果中出现 2 说明预测结果正确,否则错误;对一个批量中的所有图像累计处理之后,即可算出准确率。继续上面的例子,前面已经说过,G 的类标号 16 对应的独热编码一分为二的结果为:

0 0 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

假如现在神经网络的两个分支预测的类标号分别为 1 和 6,那么它们分别对应独热编码:

0 1 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

以上独热编码按位置对应相加,得到:

0 1 0 0 0 0 0 0 0 0                        0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

前一个结果(即 0 1 0 0 0 0 0 0 0 0)所有位置上都没有出现 2,说明预测和实际的类标号没有重合,对准确率没有产生作用;后一个结果(即 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0) 中,第 6 个索引位置出现 2 说明预测和实际的类标号是一样的,因此预测正确,预测正确数加 1。显然,每一张图像要么加 0 (两个任务都预测错误)要么加 1(其中一个任务预测正确),因此这样计算准确率是正确的(不可能加 2,因为实际的两个独热编码中,其中的一个全是 0)。详细的细节请参考如下完整代码(将其命名为 model.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from abc import ABCMeta
from abc import abstractmethod

slim = tf.contrib.slim


class BaseModel(object):
    """Abstract base class for any model."""
    __metaclass__ = ABCMeta
    
    def __init__(self, num_classes_dict):
        """Constructor.
        
        Args:
            num_classes: Number of classes.
        """
        self._num_classes_dict = num_classes_dict
        
    @property
    def num_classes_dict(self):
        return self._num_classes_dict
    
    @abstractmethod
    def preprocess(self, inputs):
        """Input preprocessing. To be override by implementations.
        
        Args:
            inputs: A float32 tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.
            
        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size, 
                height, widht, num_channels] representing a batch of images.
        """
        pass
    
    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        pass
    
    @abstractmethod
    def postprocess(self, prediction_dict, **params):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        pass
    
    @abstractmethod
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        pass
    
        
class Model(BaseModel):
    """xxx definition."""
    
    def __init__(self,
                 is_training,
                 num_classes_dict={'digits': 10, 'letters': 24}):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        super(Model, self).__init__(num_classes_dict=num_classes_dict)
        
        self._is_training = is_training
        self._class_order = ['digits', 'letters']
        
    def preprocess(self, inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        preprocessed_inputs = tf.to_float(inputs)
        preprocessed_inputs = tf.subtract(preprocessed_inputs, 128.0)
        preprocessed_inputs = tf.div(preprocessed_inputs, 128.0)
        return preprocessed_inputs
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        net = preprocessed_inputs
        net = slim.repeat(net, 2, slim.conv2d, 32, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
        net = slim.flatten(net, scope='flatten')
        net = slim.dropout(net, keep_prob=0.5,
                           is_training=self._is_training)
        net = slim.fully_connected(net, 512, scope='fc1')
        net = slim.fully_connected(net, 512, scope='fc2')
        prediction_dict = {}
        for class_name, num_classes in self.num_classes_dict.items():
            logits = slim.fully_connected(net, num_outputs=num_classes, 
                                          activation_fn=None, 
                                          scope='Predict/' + class_name)
            prediction_dict[class_name] = logits
        return prediction_dict
    
    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        postprecessed_dict = {}
        for class_name in self.num_classes_dict:
            logits = prediction_dict[class_name]
#            logits = tf.nn.softmax(logits, name=class_name)
            postprecessed_dict[class_name] = logits
        return postprecessed_dict
    
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each task.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
        for class_name in self.num_classes_dict:
            weights = tf.cast(tf.greater(
                tf.reduce_sum(onehot_labels_dict[class_name], axis=1), 0),
                dtype=tf.float32)
            slim.losses.softmax_cross_entropy(
                logits=prediction_dict[class_name], 
                onehot_labels=onehot_labels_dict[class_name],
                weights=weights,
                scope='Loss/' + class_name)
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict
    
    def _onehot_groundtruth_dict(self, groundtruth_lists):
        """Transform groundtruth lables to one-hot formats.
        
        Args:
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for task.
                
        Returns:
            onehot_labels_dict: A dictionary mapping strings (class names) 
                to one-hot lable tensors.
        """
        one_hot = tf.one_hot(
            groundtruth_lists, depth=sum(self.num_classes_dict.values()))
        onehot_labels_dict = {}
        start_index = 0
        for class_name in self._class_order:
            onehot_labels_dict[class_name] = tf.slice(
                one_hot, [0, start_index], 
                [-1, self.num_classes_dict[class_name]])
            start_index += self.num_classes_dict[class_name]
        return onehot_labels_dict
    
    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.
        
        Args:
            postprocessed_dict: A dictionary containing the postprocessed 
                results
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            accuracy: The scalar accuracy.
        """
        onehot_labels_dict = self._onehot_groundtruth_dict(groundtruth_lists)
        num_corrections = 0.
        for class_name in self.num_classes_dict:
            predicted_argmax = tf.argmax(tf.nn.softmax(
                postprocessed_dict[class_name]), axis=1)
            onehot_predicted = tf.one_hot(
                predicted_argmax, depth=self.num_classes_dict[class_name])
            onehot_sum = tf.add(onehot_labels_dict[class_name],
                                onehot_predicted)
            correct = tf.greater(onehot_sum, 1)
            num = tf.reduce_sum(tf.cast(correct, tf.float32))
            num_corrections += num
        total_nums = tf.cast(tf.shape(groundtruth_lists)[0], dtype=tf.float32)
        accuracy = num_corrections / total_nums
        return accuracy 

        在定义 postprocess 函数时,我把语句:

logits = tf.nn.softmax(logits, name=class_name)

注释掉了(这显得这个函数没有任何用处),我的本意是为了观察 predict 函数中两个网络分支的最本原输出,主要考虑的是:当一张图片送到网络入口时,如果根本不知道它是数字还是字母,那么经过神经网络处理后,我们面临着两个任务的输出,要怎么判断它属于哪个任务中的哪个标签呢?如果我们已经知道这张图像来源于其中某一个任务,比如来源于数字任务,那么直接对数字任务分支的输出取 tf.argmax 就知道它对应的预测标签了。但现在的关键问题是,如果不知道它属于其中哪个任务,能否根据两个分支的输出直接判断出来呢?答案是可以的,尽管这是基于经验观察的。通过模型训练并导出为 .pb 文件之后,运行 evaluate.py 文件(很多次),可以观察两个分支的直接输出,你会发现两个任务中所有这些输出的最大值对应的标签就是网络的预测输出,也就是说:可以通过比较两个任务的所有输出,来预测图像来源于哪个任务(进而预测属于哪个标签)——所有输出的值中,最大值所在的任务就可以认为是图像来源的任务

二、模型训练与保存

        因为模型训练的代码和文章 TensorFlow-slim 训练 CNN 分类模型(续)train.py 的是一样的,这里直接忽略(也可以访问 github:multi_task_test 获取本文所有代码)。

        当你获取到代码后,首先在项目当前目录下新建文件夹 datasets/images,然后在当前目录下的终端运行

python3 generate_train_data.py

生成 100000 张训练图像。之后,继续运行

python3 generate_tfrecord.py \
    --images_path ./datasets/images/ \
    --output_path ./datasets/train.record

得到训练的 .record 文件。 此时,在项目目录下再新建文件夹 training,接着在终端执行如下命令

python3 train.py --record_path ./datasets/train.record --logdir ./training/

便开始了训练过程。如果你要可视化的观看损失和准确率的变化情况,在当前目录下的终端执行

tensorboard --logdir ./training/

得到本地浏览器链接,打开这个链接即可监控训练的全过程。比如,我训练 5000 多次之后,准确率和损失的图像如下:

Tensorboard 显示的准确率和损失曲线

        当你觉得训练的准确率已经足够高了,并且文件夹 training 中也保存好了当前训练次数的模型文件之后,使用 Ctrl + C 中断训练过程。接下来,就是将 training 中的训练模型文件 .ckpt 转化为 .pb 文件,然后测试训练效果了。有关自定义的将 .ckpt 格式转化为 .pb 格式的模型文件请访问文章 TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式。在那篇文章中,已经指出,需要针对不同的分类模型做出改变的地方主要是包含 model 参数的那些函数,尤其是由输入得到输出的函数 _add_output_tensor_nodes。比如,我们这篇文章有两个分支任务的输出,对应的函数 _add_output_tensor_nodes 修改为:

def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.
    
    Adjust according to specified implementations.
    
    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.
    
    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.
        
    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    for class_name, logits in postprocessed_tensors.items():
        outputs[class_name] = tf.identity(logits, name=class_name)
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs

其它函数不需要修改,完整文件请查看 github:multi_task_testexport.py 文件。然后,在项目的当前目录终端执行模型导出命令:

python3 export_inference_graph.py \
    --trained_checkpoint_prefix ./training/model.ckpt-5265 \
    --output_directory ./training/inference_graph_pb

你会在 training 文件夹中看到一个新的文件夹 inference_graph_pb,里面的文件 frozen_inference_graph.pb 就是我们用来做模型推断的文件。上面一条命令中的 model.ckpt-5265 请根据你自己的训练情况做修改,这里我是只训练了 5000 多次,然后使用训练了 5265 次的模型用于图像推断。

        当你一切都顺利执行之后,恭喜你来到最后一步,是时候验证一下你训练的模型的效果了。写个简单的模型验证文件 evaluate.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  2 14:02:05 2018

@author: shirhe-lyh
"""

"""Evaluate the trained CNN model.

Example Usage:
---------------
python3 evaluate.py \

    --frozen_graph_path: Path to model frozen graph.
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    alphabets = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
                 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 
                 'U', 'V', 'W', 'X', 'Y', 'Z']
    
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    
    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            inputs = model_graph.get_tensor_by_name('image_tensor:0')
            digits = model_graph.get_tensor_by_name('digits:0')
            digit_classes = tf.argmax(tf.nn.softmax(digits), axis=1)
            letters = model_graph.get_tensor_by_name('letters:0')
            letter_classes = tf.argmax(tf.nn.softmax(letters), axis=1)
            for i in range(10):
                label = np.random.randint(0, 34)
                image = generate_captcha(alphabets[label])
                image_np = np.expand_dims(image, axis=0)
                predicted_ = sess.run([digits, digit_classes,
                                       letters, letter_classes], 
                                           feed_dict={inputs: image_np})
                predicted_digits = np.round(predicted_[0], 2)
                predicted_digit_classes = predicted_[1]
                predicted_letters = np.round(predicted_[2], 2)
                predicted_letter_classes = predicted_[3]
                print(predicted_digits, '----', predicted_digit_classes)
                print(predicted_letters, '----', predicted_letter_classes)
                predicted_label = predicted_letter_classes[0] + 10
                if label < 10:
                    predicted_label = predicted_digit_classes[0]
                print(alphabets[predicted_label], ' vs ', alphabets[label])
            
            
if __name__ == '__main__':
    tf.app.run()

在终端执行如下命令,进行模型评估:

python3 evaluate.py \
    --frozen_graph_path ./training/inference_graph_pb/frozen_inference_graph.pb

你可以仔细的观察最后两个分支的直接输出,看看最大值对应的那个任务是否恰好是验证图像实际来源的任务。

预告:下一篇文章将要介绍如何用 TensorFlow 实现 生成对抗网络,敬请期待!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,837评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,551评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,417评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,448评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,524评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,554评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,569评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,316评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,766评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,077评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,240评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,912评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,560评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,176评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,425评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,114评论 2 366
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,114评论 2 352

推荐阅读更多精彩内容