使用tensorflow model maker训练目标检测模型

一、环境配置

1.1 使用conda创建一个新的隔离环境

因为我用的是conda环境,所以又新建了一个专门tensorflow model maker的环境

# 创建环境
conda create -n tf_model_maker python=3.9
# 激活环境
conda activate tf_model_maker
# 退出当前环境
conda deactivate
# 删除环境使用
conda remove -n tf_model_maker --all

1.2 配置tensorflow model maker环境

apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0

此处没有使用nightly版本,不知道是有什么bug,使用nightly版本有些库引用出问题了,所以换回非nightly版本

1.3 导包

import numpy as np
import os

from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

执行输出:

/root/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

二、数据集整理

我使用的数据格式是coco格式的,已经处理成csv文件了,csv文件格式是:

filename,width,height,class,xmin,ymin,xmax,ymax
00232f5be5eb8a0f2c34a4a63f73d678.jpeg,683,1024,ball,224,756,511,1024
.....

目标csv数据格式:https://cloud.google.com/vision/automl/object-detection/docs/csv-format

set,path,label,xmin,ymin,,,xmax,ymax,,

  • TRAIN或者VAL或者TEST:训练数据、验证数据、测试数据标记

  • 图片文件全路径:此处必须要用全路径

  • label:标记名称

  • 图片中对象的边界框

    • 使用 2 个包含一组 x、y 坐标的顶点(如果这些点是矩形的对角点)(xmin, ymin,,,xmax,ymax,,)
    • 或使用全部 4 个顶点 (xmin,ymin,xmax,ymin,xmax,ymax,xmin,ymax)

    这些坐标必须是 0 到 1 范围内的浮点数,其中 0 表示最小 x 或 y 值,1 表示最大 x 或 y 值。

    例如,(0,0) 表示左上角,(1,1) 表示右下角;整个图片的边界框表示为 (0,0,,,1,1,,) 或 (0,0,1,0,1,1,0,1)

TRAIN,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
VAL,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
TEST,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,

数据处理代码:

import codecs
import csv
import cv2
import os

image_path = '/root/xxx/images/'

def makeData(old_file,new_file,key):

    file = open(new_file,'w')
    with file:
        w = csv.writer(file)

        with codecs.open(old_file, encoding='utf-8-sig') as f:
            for row in csv.DictReader(f, skipinitialspace=True):
                width=float(row['width'])
                height=float(row['height'])
                label=row['class']
                xmin=float(row['xmin'])/width
                ymin=float(row['ymin'])/height
                xmax=float(row['xmax'])/width
                ymax=float(row['ymax'])/height
                filename=row['filename']
                print(filename)
                img_path = os.path.join(image_path, filename)
                
                if os.path.exists(img_path) is True:
                    name = filename.replace(".jpeg","").replace(".jpg","")
                    save_path = os.path.join(image_path, name+".jpg")
                    img = cv2.imread(img_path)
                    cv2.imwrite(save_path,img)
                    new_row=[key,save_path,label,xmin,ymin,'','',xmax,ymax,'','',]
                    print(new_row)
                    w.writerow(new_row)

我拿到的图库中有些图片是直接修改的后缀,真实格式和后缀不同,也重新处理了一下,还有些图片不存在了,也过滤了一下

makeData('/root/xxx/train.csv',
         '/root/xxx/new_train.csv',
        'TRAIN')

makeData('/root/xxx/test.csv',
         '/root/xxx/new_test.csv',
        'TEST')

makeData('/root/xxx/test.csv',
         '/root/xxx/new_vaild.csv',
        'VAL')

然后我把new_train.csv、new_test.csv、new_vaild.csv中取了部分数据,手动合并到一个名为data.csv的文件里了

train_data,validation_data,test_data = object_detector.DataLoader.from_csv('/root/xxx/data.csv')

三、准备预训练模型

由于物体检测模型只支持EfficientDet系列的模型,我试过EfficientDet-Lite2发现在手机端的速度不是很理想,高端机差不多需要100ms左右识别出来,最终选择了速度更快的EfficientDet-Lite0

Model architecture Size(MB)* Latency(ms)** Average Precision***
EfficientDet-Lite0 4.4 37 25.69%
EfficientDet-Lite1 5.8 49 30.55%
EfficientDet-Lite2 7.2 69 33.97%
EfficientDet-Lite3 11.4 116 37.70%
EfficientDet-Lite4 19.9 260 41.96%

** Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.*

3.1、选择预训练模型

spec = model_spec.get('efficientdet_lite0')

此处在国内的服务器上是会提示超时报错终止,原因就是被墙了,所以要根据提示修改源码成镜像文件路径

3.2、修改源码

# 预训练模型配置文件
vim ~/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py

# 找到efficientdet_lite0_spec配置文件
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1',
)
# 把uri换一下
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/lite0/feature-vector/1.tar.gz',
)    

关键是替换uri,再重新执行spec = model_spec.get('efficientdet_lite0')

四、训练模型

model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
Epoch 1/50
540/540 [==============================] - 253s 399ms/step - det_loss: 0.6041 - cls_loss: 0.3679 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.6678 - learning_rate: 0.0090 - gradient_norm: 4.1991 - val_det_loss: 1.2947 - val_cls_loss: 0.8470 - val_box_loss: 0.0090 - val_reg_l2_loss: 0.0645 - val_loss: 1.3592
Epoch 2/50
540/540 [==============================] - 214s 397ms/step - det_loss: 0.3937 - cls_loss: 0.2513 - box_loss: 0.0028 - reg_l2_loss: 0.0651 - loss: 0.4588 - learning_rate: 0.0100 - gradient_norm: 3.2312 - val_det_loss: 0.3262 - val_cls_loss: 0.2136 - val_box_loss: 0.0023 - val_reg_l2_loss: 0.0656 - val_loss: 0.3918
Epoch 3/50
540/540 [==============================] - 213s 394ms/step - det_loss: 0.3450 - cls_loss: 0.2250 - box_loss: 0.0024 - reg_l2_loss: 0.0660 - loss: 0.4110 - learning_rate: 0.0099 - gradient_norm: 2.8205 - val_det_loss: 0.2999 - val_cls_loss: 0.2096 - val_box_loss: 0.0018 - val_reg_l2_loss: 0.0664 - val_loss: 0.3663
。。。。。

评估模型

model.evaluate(test_data)

输出:

{'AP': 0.82879966,
 'AP50': 0.9893871,
 'AP75': 0.9637165,
 'APs': 0.50417614,
 'APm': 0.83946806,
 'APl': 0.8315978,
 'ARmax1': 0.7818135,
 'ARmax10': 0.8720247,
 'ARmax100': 0.87727976,
 'ARs': 0.7034483,
 'ARm': 0.89498526,
 'ARl': 0.87662005,
 'AP_/ball': 0.82879966}

五、导出tflite模型

model.export(export_dir='/root/xxx/tf')

会在/root/xxx/tf文件夹下生成model.tflite文件

评估模型:

model.evaluate_tflite('model.tflite', test_data)

输出

{'AP': 0.817586,
 'AP50': 0.98929125,
 'AP75': 0.95808136,
 'APs': 0.4901086,
 'APm': 0.8326331,
 'APl': 0.81800973,
 'ARmax1': 0.77594024,
 'ARmax10': 0.8460072,
 'ARmax100': 0.84688306,
 'ARs': 0.63793105,
 'ARm': 0.86342186,
 'ARl': 0.84720457,
 'AP_/ball': 0.817586}

可以看出导出tflite之后模型的识别度从0.82879966下降到了0.817586,也还算能接受

tflite模型测试:

# Imports
from tflite_support.task import vision
from tflite_support.task import core
from tflite_support.task import processor

# Initialization
base_options = core.BaseOptions(file_name='/root/xxx/tf/model.tflite')
detection_options = processor.DetectionOptions(max_results=2)
options = vision.ObjectDetectorOptions(base_options=base_options, detection_options=detection_options)
detector = vision.ObjectDetector.create_from_options(options)

# Alternatively, you can create an object detector in the following manner:
# detector = vision.ObjectDetector.create_from_file(model_path)

# Run inference
image = vision.TensorImage.create_from_file('/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpeg')
detection_result = detector.detect(image)

image = vision.TensorImage.create_from_file('/root/xxx/11.png')
detection_result = detector.detect(image)
print(detection_result)
资料

https://tensorflow.google.cn/lite/models/modify/model_maker/object_detection

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,444评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,421评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,036评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,363评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,460评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,502评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,511评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,280评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,736评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,014评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,190评论 1 342
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,848评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,531评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,159评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,411评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,067评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,078评论 2 352

推荐阅读更多精彩内容