ML.NET + PostgreSQL 实现一个以图搜图的功能

TODO:

涉及的理论较多,这里空白太小,写不下。
设计?我的设计大约的确就是 没有设计

先上结论:





需要准备什么?

  • ML.NET (处理模型输入输出,可以查看作者之前的文章)
  • PostgreSQL (数组、函数)
  • Onnx (可以去Github找找开源的模型)
  • 向量相关知识(线性代数)
  • pgvector(向量数据库,先进的思想和奇怪的手法)
  • React(界面展示,可选)
  • Bing 、GitHub(搜索上面的关键词)

核心代码

PG函数

  • 向量余弦相似度比较函数
CREATE EXTENSION IF NOT EXISTS plpython3u;

CREATE OR REPLACE FUNCTION cosineSimilarity(vector1 FLOAT[], vector2 FLOAT[]) RETURNS FLOAT AS $$
import math

norm1 = math.sqrt(sum(val * val for val in vector1))
norm2 = math.sqrt(sum(val * val for val in vector2))

dot_product = sum(val1 * val2 for val1, val2 in zip(vector1, vector2))

return dot_product / (norm1 * norm2)
$$ LANGUAGE plpython3u;

PG函数映射到EF

public static class CustomPostgreSQLFunctions
{
    [DbFunction("cosinesimilarity", "public")]
    public static float CosineSimilarity(float[] vector1, float[] vector2)
    {
        throw new NotSupportedException();
    }

    public static readonly MethodInfo CosineSimilarityMethodInfo =
      typeof(CustomPostgreSQLFunctions)
     .GetMethod("CosineSimilarity");
}

-----------------------华丽的分割线--------------------------

modelBuilder.HasDbFunction(CustomPostgreSQLFunctions.CosineSimilarityMethodInfo);

核心算法

using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Onnx;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using SixLabors.ImageSharp.Processing;
using System;
using System.Linq;
using System.Runtime.InteropServices;

namespace WTF.Services;

public class ImageFeatureService
{
    private static readonly SchemaDefinition _schemaDefinition = SchemaDefinition.Create(typeof(InputData));
    private static readonly MLContext _mlContext = new();
    private static readonly OnnxScoringEstimator _pipeline = _mlContext
                                              .Transforms
                                              .ApplyOnnxModel(
                                                 modelFile: "weights/resnet101-v1-7.onnx",
                                                 outputColumnNames: new[] { "resnetv18_dense0_fwd" },
                                                 inputColumnNames: new[] { "data" }
                                               );

    public float[] Predict(Image<Rgb24> image)
    {
        image.Mutate(x => x.Resize(224, 224));

        var chw = ImageToNCHW(image);

        var dataFrame = _mlContext.Data.LoadFromEnumerable(new[] { new InputData { Images = chw, } }, _schemaDefinition);

        var predictions = _pipeline.Fit(dataFrame).Transform(dataFrame);

        var predictedData = _mlContext.Data.CreateEnumerable<OutputData>(predictions, false);

        return predictedData.First().Outputs;
    }

    static float[] ImageToNCHW(Image<Rgb24> image)
    {
        var pxData = new float[3, image.Height, image.Width];

        for (int h = 0; h < image.Height; h++)
        {
            for (int w = 0; w < image.Width; w++)
            {
                var px = image[w, h];

                pxData[0, h, w] = px.R / 255f;
                pxData[1, h, w] = px.G / 255f;
                pxData[2, h, w] = px.B / 255f;
            }
        }

        return MemoryMarshal.CreateReadOnlySpan(ref pxData[0, 0, 0], 3 * image.Height * image.Width).ToArray();
    }

    class OutputData
    {
        [ColumnName("resnetv18_dense0_fwd")]
        public float[] Outputs { get; set; } = Array.Empty<float>();
    }

    class InputData
    {
        [ColumnName("data")]
        [VectorType(1, 3, 224, 224)]
        public float[] Images { get; set; } = Array.Empty<float>();
    }
}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容