1. 概述
训练一个服饰的分类器,这里采用yolo11x
的神经网络架构训练。
2. 数据集准备
2.1. 公开数据集
DeepFashion2
https://github.com/switchablenorms/DeepFashion2
https://github.com/switchablenorms/DeepFashion2/blob/master/evaluation/deepfashion2_to_coco.py
Coco
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco128.yaml
2.2. 数据标注工具
roboflow数据标注工具,可以export yolo格式数据集
https://roboflow.com/Refine high-quality datasets and visual AI models
https://github.com/voxel51/fiftyone
https://voxel51.com/
本文采用一个私有数据集,一共7w+图片,按照
6:2:2
分为训练数据集
、验证数据集
和测试数据集
。
3. 训练
3.1. 训练代码
from ultralytics import YOLO
from ultralytics import settings
import wandb
from wandb.integration.ultralytics import add_wandb_callback
settings.update(dict(
datasets_dir="/data/yolo/dataset/yolo",
weights_dir="/data/yolo/dataset/yolo/weights",
runs_dir="/data/yolo/dataset/yolo/runs",
wandb=True,
comet=False,
clearml=False,
tensorboard=False,
))
project = "fashion-detector-19cls"
# Load a model
model = YOLO("yolo11x.pt")
# model = YOLO("fashion-detector-19cls/yolo11x/weights/last.pt")
add_wandb_callback(model, enable_model_checkpointing=True)
# Train the model
data_config = "dataset/yolo/haier_fashion.yaml"
results = model.train(
data=data_config, epochs=100,
imgsz=640, device=[2,3,4], batch=60,
project=project, name="yolo11x",
# resume=True
)
注意:上述代码是预训练,如果做继续训练,只需要
加载续训练model
和resume=True
即可。
3.2. WandB训练过程指标分析
Epoch=100的训练结果
Metric | Value |
---|---|
lr/pg0 | 0.0002 |
lr/pg1 | 0.0002 |
lr/pg2 | 0.0002 |
metrics/mAP50(B) | 0.83648 |
metrics/mAP50-95(B) | 0.6932 |
metrics/precision(B) | 0.81303 |
metrics/recall(B) | 0.7881 |
train/box_loss | 0.45027 |
train/cls_loss | 0.23466 |
train/dfl_loss | 0.96325 |
val/box_loss | 0.65851 |
val/cls_loss | 0.49071 |
val/dfl_loss | 1.08395 |
4. 测试
- 服饰类别一共19个
Dress Coat Top Jacket Skirt Suspender Short Pant Swim-Suit Shoe Cap Glass Watch Bag Belt Glove Scarf Jewelry Non-Fashion
- 对照Test数据集与预测结果
注意:正常情况下,需要进行
模型评估
,通过mAP、Precission、Recall等指标评价模型的性能,尤其是泛化能力。这里先不做了。
5. 脚本说明
5.1. 准备数据
python -m "scripts.data_prepare"
5.2. 训练
python train.py
5.3. 预测
python predict.py
or
python check_data.py