站在巨人的肩膀上完成对象检测功能。ML.NET,一个为 dotnet er 量身打造的跨平台机器学习框架

我们的目标是【检测图片中的修沟,并把它的头用矩形框起来】,为了避免有小伙伴被劝退,先上结论,稳定一手军心。

参考文献

模型准备

  • 本文使用的模型是 yolov5 对象检测。
  • 训练过程比较简单,就跳过了。
  • 导出 onnx 文件。
python export.py --weights  best.pt  --img-size 640 --batch-size 1  --include onnx

ONNX模型的输出格式

  • Sigmoid(S型函数)。

    数学公式:f(x) = 1 / (1 + e^(-x);
    通俗点讲:1除以(1+ e的负x次方),结果无限接近(但不等于)0 和 1。
    代码表示:1 / (1 + Math.Pow(Math.E, -x))。

  • 张量(数组)形式的输出。
    没有固定的数学公式,能不能算出来全看运气(手动狗头)。

  • 我们代码中的使用到的模型输出是张量(数组)形式。

using

using Microsoft.ML;
using Microsoft.ML.Data;
using SixLabors.Fonts;
using SixLabors.ImageSharp.Drawing.Processing;
using System.Runtime.InteropServices;

代码模型


public class ObjectPrediction
{
    public float X { get; set; }
    public float Y { get; set; }
    public float Width { get; set; }
    public float Height { get; set; }
    public float Score { get; set; }
    public string Label { get; set; } = "";

    public RectangleF GetRectangle()
    {
        return new (X, Y, Width, Height);
    }

    public float GetArea()
    {
        return Width * Height;
    }
}

public class OutputData
{
    [ColumnName("output0")]
    public float[] Outputs { get; set; } = Array.Empty<float>();
}

public class InputData
{
    [ColumnName("images")]
    [VectorType(1 * 3 * 640 * 640)]
    public float[] Images { get; set; } = Array.Empty<float>();
}

图像转张量(NCHW)

因为n(即batch)是1,所以使用三维数组计算即可。
除以 255 是为了保证每个颜色值的范围在 0 到 1 之间。

static float[] ImageToNCHW(Image<Rgb24> image)
{
    var pxData = new float[3, image.Height, image.Width];

    for (int h = 0; h < image.Height; h++)
    {
        for (int w = 0; w < image.Width; w++)
        {
            var px = image[w, h];

            pxData[0, h, w] = px.R / 255f;
            pxData[1, h, w] = px.G / 255f;
            pxData[2, h, w] = px.B / 255f;
        }
    }

    return MemoryMarshal.CreateReadOnlySpan(ref pxData[0, 0, 0], 3 * image.Height * image.Width).ToArray();
}

模型输出张量转换 (NCHW还原)

static IEnumerable<ObjectPrediction> ParseDetect(
            float[] output,
            (int Width, int Height) modelImage,
            (int Width, int Height) originalImage,
            int dimensions,
            float confidence,
            params string[] labels 
            )
{
    // 记录原图和模型的比例
    var xGain = modelImage.Width / (float)originalImage.Width;
    var yGain = modelImage.Height / (float)originalImage.Height;

    var xPadding = (modelImage.Width - (originalImage.Width * xGain)) / 2;
    var yPadding = (modelImage.Height - (originalImage.Height * yGain)) / 2;

    for (int i = 0; i < output.Length / dimensions; i++)
    {
        // TODO: 
        // output[i * dimensions + 0]:目标的中心点横坐标  张量中的位置为:output[i, dimensions, 0]
        // output[i * dimensions + 1]:目标的中心点纵坐标  张量中的位置为:output[i, dimensions, 1]
        // output[i * dimensions + 2]:目标的宽度         张量中的位置为:output[i, dimensions, 2]
        // output[i * dimensions + 3]:目标的高度         张量中的位置为:output[i, dimensions, 3]
        // output[i * dimensions + 4]:目标的置信度       张量中的位置为:output[i, dimensions, 4]

        if (output[i * dimensions + 4] <= confidence) continue;

        for (int j = 5; j < dimensions; j++)
        {
            output[i * dimensions + j] *= output[i * dimensions + 4];
        }

        for (int k = 5; k < dimensions; k++)
        {
            // batch chanel height width (即nchw) 按照模型和原图的比例还原 
            var xMin = (output[i * dimensions + 0] - (output[i * dimensions + 2] / 2) - xPadding) / xGain; 
            var yMin = (output[i * dimensions + 1] - (output[i * dimensions + 3] / 2) - yPadding) / yGain;  

            var xMax = (output[i * dimensions + 0] + (output[i * dimensions + 2] / 2) - xPadding) / xGain;  
            var yMax = (output[i * dimensions + 1] + (output[i * dimensions + 3] / 2) - yPadding) / yGain;  

            // 防止结果超过图像边界
            xMin = Clamp(xMin, 0, originalImage.Width - 0); 
            yMin = Clamp(yMin, 0, originalImage.Height - 0); 
            xMax = Clamp(xMax, 0, originalImage.Width - 1);
            yMax = Clamp(yMax, 0, originalImage.Height - 1); 

            var prediction = new ObjectPrediction
            {
                Label = labels[k - 5],
                X = xMin,
                Y = yMin,
                Width = xMax - xMin,
                Height = yMax - yMin,
                Score = output[i * dimensions + k]
            };

            yield return prediction;
        }
    }

    static float Clamp(float value, float min, float max)
    {
        return (value < min) ? min : (value > max) ? max : value;
    }
}

根据重叠阈值从集合中删除重叠的项

笛卡尔积+回溯,比较简单。

static IEnumerable<ObjectPrediction> ClearOverlappingItems(IEnumerable<ObjectPrediction> predictions, float maxOverlap = 0.45f)
{
    var items = predictions.ToList();

    for (int i = 0; i < items.Count; i++)
    {
        var current = items[i];

        for (int j = i + 1; j < items.Count; j++)
        {
            var item = items[j];

            var intersection = RectangleF.Intersect(current.GetRectangle(), item.GetRectangle());

            var intersectionArea = intersection.Width * intersection.Height;
            var unionArea = current.GetArea() + item.GetArea() - intersectionArea;
            var overlapRatio = intersectionArea / unionArea;

            if (overlapRatio >= maxOverlap)
            {
                if (current.Score < item.Score)
                {
                    items.RemoveAt(i);
                    i--;
                    break;
                }
                else
                {
                    items.RemoveAt(j);
                    j--;
                }
            }
        }
    }

    return items;
}

上 ML.NET,一把梭哈

var mlContext = new MLContext();

var pipeline = mlContext.Transforms
    .ApplyOnnxModel(modelFile: "best.onnx", outputColumnNames: new[] { "output0" }, inputColumnNames: new[] { "images" });

using var image = Image.Load<Rgb24>("dog.jpg");

using var modelImage = image.Clone();
modelImage.Mutate(x => x.Resize(640, 640));

var schemaDefinition = SchemaDefinition.Create(typeof(InputData));

var bchw = ImageToNCHW(modelImage);

var dataFrame = mlContext.Data.LoadFromEnumerable(new[] { new InputData { Images = bchw, } }, schemaDefinition);

var predictions = pipeline.Fit(dataFrame).Transform(dataFrame);

var predictedData = mlContext.Data.CreateEnumerable<OutputData>(predictions, false);

foreach (var item in predictedData)
{
    var outputs = ParseDetect(
        output: item.Outputs,
        modelImage: (modelImage.Width, modelImage.Height),
        originalImage: (image.Width, image.Height),
        dimensions: 6,
        confidence: 0.3f,
        "dog"
            );

    var font = new Font(new FontCollection().Add("C:/Windows/Fonts/consola.ttf"), 16);

    foreach (var prediction in ClearOverlappingItems(outputs))
    {
        image.Mutate(a =>
        {
           a.DrawPolygon(
             new Pen(color: Color.Blue, 2),
               new(prediction.X, prediction.Y),
               new(prediction.X + prediction.Width, prediction.Y),
               new(prediction.X + prediction.Width, prediction.Y + prediction.Height),
               new(prediction.X, prediction.Y + prediction.Height)
            );
        });

        image.Mutate(a => a.DrawText($"{prediction.Label} ({prediction.Score})",
             font, Color.Blue, new PointF(prediction.X, prediction.Y)));
    }

   await image.SaveAsJpegAsync("result.jpg");
}

dotnet run ...

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容