介绍
YOLO(You Only Look Once)是一种使用卷积神经网络进行目标检测的算法。YOLO 系列模型集成度很高、使用简单,是实际开发中常用的目标检测模型。但 YOLO 模型本身无法直接在 iOS 中使用,因此本文将讲解如何使用 YOLO 训练模型,并将训练好的模型转化为 Core ML 模型,然后在项目中使用。
下载YOLO模型
- 在 huggingface 或者 Ultralytics 网站下载 YOLOv8 模型。
- 根据需要下载不同精度的模型,共有 5 种不同精度的模型。
注意:由于是在端侧使用,因此本文以
yolov8n.pt
为例进行讲解。
训练YOLO模型
- 准备自定义目标检测数据集。
- 打开终端,使用如下命令训练模型。
yolo task=detect mode=train model=yolov8n.pt data=poker/data.yaml epochs=3 imgsz=640
- 训练完成之后得到一个新的模型文件,它才是最终需要转换的模型文件。
转换为Core ML
- 由于训练完成的模型文件无法直接使用,因此需要进一步将其转换为 Apple 官方的支持的 Core ML 模型。
from ultralytics import YOLO
model = YOLO(f"xxx.pt")
# 1. 导出为新的mlpackage格式
model.export(format="coreml", nms=True, imgsz=[640, 640])
# 2. 导出为老的mlmodel格式
model.export(format="mlmodel", nms=True, imgsz=[640, 640])
- 转换完成之后得到一个 Core ML 模型文件,它才是 iOS 项目中最终需要的模型文件。
模型测试
在项目中使用中之前,可以使用 Create ML 进行模型测试。双击打开转换好的模型文件,使用验证数据集进行验证,并查看效果。
开发使用
通过测试之后,就可以在项目中使用该模型,步骤如下:
- 将模型文件拷贝到项目工程中。
- 使用
Vision
框架对模型初始化。 - 创建
VNCoreMLRequest
并指定completionHandler
回调处理。 - 创建
VNImageRequestHandler
,传入目标照片或者通过摄像头捕获需要检测的目标。 - 检测到目标之后,通过
VNRecognizedObjectObservation
获取目标检测的内容与位置信息。
核心代码
import AVFoundation
import UIKit
import Vision
class ViewController: UIViewController {
// 设置模型
func setupModels() {
guard let modelURL = Bundle.main.url(forResource: "Xxx", withExtension: "mlmodelc") else {
return
}
do {
let visionModel = try VNCoreMLModel(for: MLModel(contentsOf: modelURL))
let objectRecognition = VNCoreMLRequest(model: visionModel) { request, _ in
DispatchQueue.main.async {
if let results = request.results {
self.drawVisionRequestResults(results)
}
}
}
requests = [objectRecognition]
} catch {
print("Model loading went wrong: \(error)")
}
}
// 识别处理
func drawVisionRequestResults(_ results: [Any]) {
for observation in results where observation is VNRecognizedObjectObservation {
guard let objectObservation = observation as? VNRecognizedObjectObservation else {
continue
}
let topLabelObservation = objectObservation.labels[0]
let objectBounds = VNImageRectForNormalizedRect(objectObservation.boundingBox, Int(bufferSize.width), Int(bufferSize.height))
print("置信度:", topLabelObservation.confidence)
print("内容:", topLabelObservation.identifier)
print("边框", objectBounds)
}
}
func exifOrientationFromDeviceOrientation() -> CGImagePropertyOrientation {
let curDeviceOrientation = UIDevice.current.orientation
let exifOrientation: CGImagePropertyOrientation
switch curDeviceOrientation {
case UIDeviceOrientation.portraitUpsideDown:
exifOrientation = .left
case UIDeviceOrientation.landscapeLeft:
exifOrientation = .upMirrored
case UIDeviceOrientation.landscapeRight:
exifOrientation = .down
case UIDeviceOrientation.portrait:
exifOrientation = .up
default:
exifOrientation = .up
}
return exifOrientation
}
}
extension ViewController: AVCaptureVideoDataOutputSampleBufferDelegate {
// 摄像头捕获后的代理方法
func captureOutput(_ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) {
guard let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer) else {
return
}
let exifOrientation = exifOrientationFromDeviceOrientation()
let imageRequestHandler = VNImageRequestHandler(cvPixelBuffer: pixelBuffer, orientation: exifOrientation, options: [:])
do {
try imageRequestHandler.perform(requests)
} catch {
print(error)
}
}
}