3.基于LSTM+CTC实现不定长文本图片OCR

上一篇实现了图片CNN多标签分类(4位定长验证码识别任务)
(地址:https://www.jianshu.com/p/596db72a7e00

本文继续优化,实现不定长文本图片的识别任务

下一篇考虑玩一玩GAN网络

本文所用到的10w不定长验证码文本数据集百度网盘下载地址(也可使用下文代码自行生成):https://pan.baidu.com/s/11BzIvuT4pYw3B0aFCK0ndQ

利用本文代码训练并生成的模型(对应项目中的my-model文件夹):https://pan.baidu.com/s/1AoKtZVyscWp3ZdOQU71qLA

项目简介:
需要预先安装pip install captcha==0.1.1,pip install opencv-python,pip install flask, pip install tensorflow/pip install tensorflow-gpu)
本文采用LSTM+CTC实现1-10位不定长验证码图片OCR(生成的验证码由随机的1-10位大写字母组成),本质上是一张图片多个标签的分类问题,且每个图片的标签数量不固定(数据如下图所示)


0_PIY.png

1_BCAVDPXT.png

2_N.png

整体训练逻辑:
1,将图像传入到LSTM中获得sequence,和sequence的长度(大致的原理是:将图像的width看做LSTM中的time_step,将图像的height看做每个time_step输入tensor的size)
2,将真实的y_label转为稀疏矩阵张量(此处的sparseTensor是个重点,同学们可以把代码中的153行y_train_tmp打印出来观察一下)
3,损失函数采用tf.nn.ctc_loss,然后对以上两步获得的数据进行训练,最终使得损失函数尽可能的减小

关于ctc_loss的原理可以百度科普一下,它的主要作用可以大概理解为将上层网络预测出的AAABBBBCCDEE收敛成ABBCDE,这里面牵涉到AAA到底收敛为几个A,BBBB又收敛为几个B,这也是他的核心

整体预测逻辑:
1,将图像传入到LSTM中获得sequence,和sequence的长度
2,将sequence,sequence的长度输入到tf.nn.ctc_beam_search_decoder函数预测出稀疏矩阵张量
3,将第二步得到的稀疏矩阵张量反向转化为sequence,并最终解码成A~Z的大写字母并输出

后续优化逻辑:
1,可以在LSTM之前先采用CNN对图像特征进行一次提取
2,TF自带的ctc_loss可以换成百度开源的Warp_CTC
3,针对少量原始图片为AAA结果最终识别为AA,丢掉了一个A的情况,是否可以把原先的标签['A', 'A', 'A']扩充为['A-left', 'A-middle', 'A-right', 'A-left', 'A-middle', 'A-right', 'A-left', 'A-middle', 'A-right']将每个字由原先的1个标签扩充为三个标签,此处抛砖引玉,可以自行尝试优化

优缺点:
1,LSTM+CTC考虑了一行文本从左到右的序列关系,这一点上比CNN更强,同时可以轻松实现不定长的OCR
2,也正是由于RNN网络考虑了时序间的关系,所以运算量相对于CNN网络大幅增加,收敛比较慢,有条件的同学还是上一块好点的GPU吧,能提升很多效率

运行命令:
自行生成验证码训练寄(本文生成了10w张,修改self.im_total_num变量):
python LstmCtcOcr.py create_dataset
对数据集进行训练:python LstmCtcOcr.py train
对新的图片进行测试:python LstmCtcOcr.py test
启动成http服务:python LstmCtcOcr.py start
利用flask框架将整个项目启动成web服务,使得项目支持http方式调用
启动服务后调用以下地址测试
http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/0_PIY.png
http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/1_BCAVDPXT.png
http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/2_N.png

项目目录结构:


项目结构.png

训练200个epoch之后,可以看到model在val上的acc已经能达到84%了,后续大家可以自行修改学习率和增大epoch次数来提升精度(True表示预测正确,左边为预测值,右边为真实标签):


lstm-ctc-199-epoch.png

整体代码如下(LstmCtcOcr.py文件):

# coding:utf-8

from captcha.image import ImageCaptcha
import numpy as np
import cv2
import tensorflow as tf
import random, os, sys
import operator


from flask import request
from flask import Flask
import json
app = Flask(__name__)

class LstmCtcOcr:
    def __init__(self):
        self.epoch_max = 200  # 最大迭代epoch次数
        self.batch_size = 16  # 训练时每个批次参与训练的图像数目,显存不足的可以调小
        self.lr = 5e-5  # 初始学习率
        self.save_epoch = 5  # 每相隔多少个epoch保存一次模型
        self.n_hidden = 256  # 隐藏神经元个数

        self.im_width = 256
        self.im_height = 64
        self.im_total_num = 100000  # 总共生成的验证码图片数量
        self.train_max_num = self.im_total_num  # 训练时读取的最大图片数目
        self.val_num = 30 * self.batch_size  # 不能大于self.train_max_num  做验证集用
        self.words_max_num = 10  # 每张验证码图片上的最大字母个数
        self.words = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.n_classes = len(self.words) + 1  # 26个字母 + blank
        self.x = None
        self.y = None

    def captchaOcr(self, img_path):
        """
        验证码识别
        :param img_path:
        :return:
        """
        im = cv2.imread(img_path)
        im = cv2.resize(im, (self.im_width, self.im_height))
        im = np.array([im[:, :, 0]], dtype=np.float32)
        im -= 147
        pred = self.sess.run(self.pred, feed_dict={self.x: im})
        sequence = self.sparseTensor2sequence(pred)
        return ''.join(sequence[0])


    def test(self, img_path):
        """
        测试接口
        :param img_path:
        :return:
        """
        self.batch_size = 1
        self.learning_rate = tf.placeholder(dtype=tf.float32)  # 动态学习率
        self.weight = tf.Variable(tf.truncated_normal([self.n_hidden, self.n_classes], stddev=0.1))
        self.bias = tf.Variable(tf.constant(0., shape=[self.n_classes]))
        self.x = tf.placeholder(tf.float32, [None, self.im_height, self.im_width])
        logits, seq_len = self.rnnNet(self.x, self.weight, self.bias)
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
        self.pred = tf.cast(decoded[0], tf.int32)

        saver = tf.train.Saver()
        # tfconfig = tf.ConfigProto(allow_soft_placement=True)
        # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.3  # 占用显存的比例
        # self.ses = tf.Session(config=tfconfig)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())  # 全局tf变量初始化

        # 加载w,b参数
        saver.restore(self.sess, './my-model/LstmCtcOcr-200')
        im = cv2.imread(img_path)
        im = cv2.resize(im, (self.im_width, self.im_height))
        im = np.array([im[:, :, 0]], dtype=np.float32)
        im -= 147
        pred = self.sess.run(self.pred, feed_dict={self.x: im})
        sequence = self.sparseTensor2sequence(pred)
        print(''.join(sequence[0]))


    def train(self):
        """
        训练
        :return:
        """
        x_train_list, y_train_list, x_val_list, y_val_list = self.getTrainDataset()

        print('开始转换tensor队列')
        x_train_list_tensor = tf.convert_to_tensor(x_train_list, dtype=tf.string)
        y_train_list_tensor = tf.convert_to_tensor(y_train_list, dtype=tf.int32)

        x_val_list_tensor = tf.convert_to_tensor(x_val_list, dtype=tf.string)
        y_val_list_tensor = tf.convert_to_tensor(y_val_list, dtype=tf.int32)

        x_train_queue = tf.train.slice_input_producer(tensor_list=[x_train_list_tensor], shuffle=False)
        y_train_queue = tf.train.slice_input_producer(tensor_list=[y_train_list_tensor], shuffle=False)

        x_val_queue = tf.train.slice_input_producer(tensor_list=[x_val_list_tensor], shuffle=False)
        y_val_queue = tf.train.slice_input_producer(tensor_list=[y_val_list_tensor], shuffle=False)

        train_im, train_label = self.dataset_opt(x_train_queue, y_train_queue)
        train_batch = tf.train.batch(tensors=[train_im, train_label], batch_size=self.batch_size, num_threads=2)

        val_im, val_label = self.dataset_opt(x_val_queue, y_val_queue)
        val_batch = tf.train.batch(tensors=[val_im, val_label], batch_size=self.batch_size, num_threads=2)

        print('准备训练')
        self.learning_rate = tf.placeholder(dtype=tf.float32)  # 动态学习率
        self.weight = tf.Variable(tf.truncated_normal([self.n_hidden, self.n_classes], stddev=0.1))
        self.bias = tf.Variable(tf.constant(0., shape=[self.n_classes]))

        # self.global_step = tf.Variable(0, trainable=False)  # 全局步骤计数

        # im_width看成LSTM的time_step ,im_height看成是每个time_step输入tensor的size
        self.x = tf.placeholder(tf.float32, [None, self.im_height, self.im_width])
        # 定义ctc_loss需要的稀疏矩阵
        self.y = tf.sparse_placeholder(tf.int32)

        logits, seq_len = self.rnnNet(self.x, self.weight, self.bias)

        # loss
        self.loss = tf.nn.ctc_loss(self.y, logits, seq_len)
        # cost
        self.cost = tf.reduce_mean(self.loss)
        # optimizer
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)


        # 前面说的划分块之后找每块的类属概率分布,ctc_beam_search_decoder方法,是每次找最大的K个概率分布
        # 还有一种贪心策略是只找概率最大那个,也就是K=1的情况ctc_ greedy_decoder
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
        self.pred = tf.cast(decoded[0], tf.int32)
        self.distance = tf.reduce_mean(tf.edit_distance(self.pred, self.y))

        print('开始训练')
        saver = tf.train.Saver()  # 保存tf模型
        with tf.Session() as self.sess:
            self.sess.run(tf.global_variables_initializer())
            coordinator = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=self.sess, coord=coordinator)

            batch_max = len(x_train_list) // self.batch_size
            print('batch:', batch_max)
            total_step = 0
            for epoch_num in range(self.epoch_max):
                lr_tmp = self.lr * (1 - (epoch_num / self.epoch_max) ** 2)  # 动态学习率
                print('lr:', lr_tmp)
                for batch_num in range(batch_max):
                    # print(epoch_num, batch_num)
                    x_train_tmp, y_train_tmp = self.sess.run(train_batch)
                    y_train_tmp = self.sequence2sparseTensor(y_train_tmp)  # 将labels转为稀疏矩阵张量
                    self.sess.run(self.optimizer, feed_dict={self.x: x_train_tmp, self.y: y_train_tmp, self.learning_rate: lr_tmp})

                    if total_step % 100 == 0 or total_step == 0:
                        print('epoch:%d/%d batch:%d/%d total_step:%d lr:%.10f' % (epoch_num, self.epoch_max, batch_num, batch_max, total_step, lr_tmp))
                        # train部分
                        train_loss, train_distance = self.sess.run([self.cost, self.distance], feed_dict={self.x: x_train_tmp, self.y: y_train_tmp})

                        # val部分
                        val_loss_list, val_distance_list, val_acc_list = [], [], []
                        for i in range(int(self.val_num / self.batch_size)):
                            x_val_tmp, y_val_tmp_true = self.sess.run(val_batch)
                            y_val_tmp = self.sequence2sparseTensor(y_val_tmp_true)  # 将labels转为稀疏矩阵张量
                            val_loss, val_distance, val_pred = self.sess.run([self.cost, self.distance, self.pred], feed_dict={self.x: x_val_tmp, self.y: y_val_tmp})
                            val_loss_list.append(val_loss)
                            val_distance_list.append(val_distance)
                            val_sequence = self.sparseTensor2sequence(val_pred)
                            ok = 0.
                            for idx, val_seq in enumerate(val_sequence):
                                val_pred_tmp = [self.words.find(x) if self.words.find(x) > -1 else 26 for x in val_seq]
                                val_y_true_tmp = [x for x in y_val_tmp_true[idx] if x != 26]

                                is_eq = operator.eq(val_pred_tmp, val_y_true_tmp)

                                if idx == 0:
                                    print(is_eq, [self.words[n] for n in val_pred_tmp], '<<==>>', [self.words[n] for n in val_y_true_tmp])

                                if is_eq:
                                    ok += 1
                            val_acc_list.append(ok / len(val_sequence))
                        val_acc_list = np.array(val_acc_list, dtype=np.float32)

                        print('train_loss:%.10f train_distance:%.10f' % (train_loss, train_distance))
                        print('  val_loss:%.10f   val_distance:%.10f val_acc:%.10f' % (np.mean(val_loss_list), np.mean(val_distance_list), np.mean(val_acc_list)))
                        print()
                        print()

                    total_step += 1

                # 保存模型
                if (epoch_num + 1) % self.save_epoch == 0:
                    saver.save(self.sess, './my-model/LstmCtcOcr', global_step=(epoch_num + 1))

            coordinator.request_stop()
            coordinator.join(threads)


    def rnnNet(self, inputs, weight, bias):
        """
        获取LSTM网络结构
        :param inputs:
        :param weight:
        :param bias:
        :return:
        """
        # 对于tf.nn.dynamic_rnn,默认time_major=false,此时inputs的shape=[batch_size, max_time_steps, features]
        # (batch_size, im_height, im_width) ==> (batch_size, im_width, im_height)
        inputs = tf.transpose(inputs, [0, 2, 1])

        # 变长序列的最大值
        # seq_len = np.ones(self.batch_size) * self.im_width
        seq_len = np.ones(self.batch_size) * self.im_width

        cell = tf.nn.rnn_cell.LSTMCell(self.n_hidden, forget_bias=0.8, state_is_tuple=True)

        # 动态rnn实现输入变长
        outputs1, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)

        # (self.batch_size * self.im_width, self.hidden)
        outputs = tf.reshape(outputs1, [-1, self.n_hidden])

        logits = tf.matmul(outputs, weight) + bias  # w * x + b
        logits = tf.reshape(logits, [self.batch_size, -1, self.n_classes])
        logits = tf.transpose(logits, (1, 0, 2))  # (im_width, batch_size, im_height)
        return logits, seq_len


    def sequence2sparseTensor(self, sequences, dtype=np.int32):
        """
        序列 转化为 稀疏矩阵
        :param sequences:
        :param dtype:
        :return:
        """
        values, indices= [], []
        for n, seq in enumerate(sequences):
            indices.extend(zip([n] * len(seq), range(len(seq))))
            values.extend(seq)
        indices = np.asarray(indices, dtype=np.int64)
        values = np.asarray(values, dtype=dtype)
        shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
        return indices, values, shape


    def sparseTensor2sequence(self, sparse_tensor):
        """
        稀疏矩阵 转化为 序列
        :param sparse_tensor:
        :return:
        """
        decoded_indexes = list()
        current_i = 0
        current_seq = []
        for offset, i_and_index in enumerate(sparse_tensor[0]):
            i = i_and_index[0]
            if i != current_i:
                decoded_indexes.append(current_seq)
                current_i = i
                current_seq = list()
            current_seq.append(offset)
        decoded_indexes.append(current_seq)
        result = []
        for index in decoded_indexes:
            result.append(self.sequence2words(index, sparse_tensor))
        return result


    def sequence2words(self, indexes, spars_tensor):
        """
        序列 转化为 文本
        :param indexes:
        :param spars_tensor:
        :return:
        """
        decoded = []
        for m in indexes:
            str_tmp = self.words[spars_tensor[1][m]]
            decoded.append(str_tmp)
        return decoded


    def dataset_opt(self, x_train_queue, y_train_queue):
        """
        处理图片和标签
        :param queue:
        :return:
        """
        queue = x_train_queue[0]
        contents = tf.read_file('./dataset/train/' + queue)
        im = tf.image.decode_jpeg(contents)
        tf.image.rgb_to_grayscale(im)
        im = tf.image.resize_images(images=im, size=[self.im_height, self.im_width])
        im = tf.reshape(im[:, :, 0], tf.stack([self.im_height, self.im_width]))
        im -= 147  # 去均值化
        return im, y_train_queue[0]


    def getTrainDataset(self):
        train_data_list = os.listdir('./dataset/train/')
        print('共有%d张训练图片, 读取%d张:' % (len(train_data_list), self.train_max_num))
        random.shuffle(train_data_list)  # 打乱顺序

        y_val_list, y_train_list = [], []
        x_val_list = train_data_list[:self.val_num]
        for x_val in x_val_list:
            words_tmp = x_val.split('.')[0].split('_')[1]
            words_tmp = words_tmp + '?' * (self.words_max_num - len(words_tmp))
            y_val_list.append([self.words.find(x) if self.words.find(x) > -1 else 26 for x in words_tmp])

        x_train_list = train_data_list[self.val_num:self.train_max_num]
        for x_train in x_train_list:
            words_tmp = x_train.split('.')[0].split('_')[1]
            words_tmp = words_tmp + '?' * (self.words_max_num - len(words_tmp))
            y_train_list.append([self.words.find(x) if self.words.find(x) > -1 else 26 for x in words_tmp])

        return x_train_list, y_train_list, x_val_list, y_val_list


    def createCaptchaDataset(self):
        """
        生成训练用图片数据集
        :return:
        """
        image = ImageCaptcha(width=self.im_width, height=self.im_height, font_sizes=(56,))
        for i in range(self.im_total_num):
            words_tmp = ''
            for j in range(random.randint(1, self.words_max_num)):
                words_tmp = words_tmp + random.choice(self.words)
            print(words_tmp, type(words_tmp))
            im_path = './dataset/train/%d_%s.png' % (i, words_tmp)
            print(im_path)
            image.write(words_tmp, im_path)




if __name__ == '__main__':
    opt_type = sys.argv[1:][0]

    instance = LstmCtcOcr()

    if opt_type == 'create_dataset':
        instance.createCaptchaDataset()
    elif opt_type == 'train':
        instance.train()
    elif opt_type == 'test':
        instance.test('./dataset/test/0_PIY.png')
    elif opt_type == 'start':
        # 将session持久化到内存中
        instance.test('./dataset/test/0_PIY.png')

        # 启动web服务
        # http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/1_BCAVDPXT.png
        @app.route('/captchaOcr', methods=['GET'])
        def captchaOcr():
            img_path = request.args.to_dict().get('img_path')
            print(img_path)
            ret = instance.captchaOcr(img_path)
            print(ret)
            return json.dumps({'img_path': img_path, 'ocr_ret': ret})

        app.run(host='0.0.0.0', port=5050, debug=False)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,921评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,635评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,393评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,836评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,833评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,685评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,043评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,694评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,671评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,670评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,779评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,424评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,027评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,984评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,214评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,108评论 2 351
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,517评论 2 343

推荐阅读更多精彩内容