The implementation steps are as follows:
- Generate an image description with a large multimodal model and write it to the database;
- Concatenate the image description, the OCR text, and the image's title/tags into a single string, encode that string into a vector with the m3e-base model, and write the vector to Milvus;
- Serve an API with FastAPI that queries Milvus and returns the images whose descriptions are most similar to the query text.
Step 1: the script below generates image descriptions in two ways, via the Lingyiwanwu (零一万物) API and via CogAgent-Chat; CogAgent-Chat requires a GPU with more than 13 GB of VRAM.
#import torch                    # torch/PIL/modelscope are only needed by the commented-out CogAgent-Chat path
#from PIL import Image
#from modelscope import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
#import argparse
#import os                       # os/requests/paddleocr likewise belong to the commented-out local path
#import requests
import time
#from paddleocr import PaddleOCR
import utils
#import baidu_trans as baiduTrans
from openai import OpenAI

API_BASE = "https://api.lingyiwanwu.com/v1"
API_KEY = "token"  # placeholder: replace with your Lingyiwanwu API key
client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=API_KEY,
    base_url=API_BASE
)
#parser = argparse.ArgumentParser()
#parser.add_argument("--quant", choices=[4, 8], type=int, default=None, help='quantization bits')
#parser.add_argument("--from_pretrained", type=str, default="/root/autodl-fs/ZhipuAI/cogagent-chat",
# help='pretrained ckpt')
#parser.add_argument("--local_tokenizer", type=str, default="/root/autodl-fs/AI-ModelScope/vicuna-7b-v1___5",
# help='tokenizer path')
## parser.add_argument("--local_tokenizer", type=str, default="/root/autodl-fs/ZhipuAI/chatglm2-6b", help='tokenizer path')
#parser.add_argument("--fp16", action="store_true")
#parser.add_argument("--bf16", action="store_true")
#
#args, unknown = parser.parse_known_args()
#
#MODEL_PATH = args.from_pretrained
#TOKENIZER_PATH = args.local_tokenizer
#DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
#
#tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
#if args.bf16:
# torch_type = torch.bfloat16
#else:
# torch_type = torch.float16
#
#print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
#
## quantization configuration for NF4 (4 bits)
#bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.bfloat16
#)
#
#if args.quant:
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH,
# torch_dtype=torch_type,
# low_cpu_mem_usage=True,
# quantization_config=bnb_config,
# trust_remote_code=True
# ).eval()
#else:
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH,
# torch_dtype=torch_type,
# low_cpu_mem_usage=True,
# load_in_4bit=args.quant is not None,
# trust_remote_code=True
# ).to(DEVICE).eval()
#
#ocr = PaddleOCR(use_angle_cls=True, lang="ch")
conn = utils.db_connect("online")
sql = "select * from img_source where status=1"
res = conn.select(sql)
for v in res:
    image_id = v['id']
    print(image_id)
    if v['path'] is None:
        continue
    # splitArr = v['path'].split(".")  # only used by the commented-out local-download path below
    cdnUrl = "https://cdn.vcbeat.top/" + v['path']
# print(cdnUrl)
# file = "image/" + str(image_id) + "." + splitArr[1]
# print(file)
# file_input = requests.get(cdnUrl)
# os.makedirs('image', exist_ok=True)
# with open(file, 'wb') as f:
# f.write(file_input.content)
# f.close()
#
# image = Image.open(file).convert('RGB')
# history = []
# query = "Provide a detailed description of what is depicted in this picture"
# input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
# inputs = {
# 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
# 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
# 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
# 'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
# }
# if 'cross_images' in input_by_model and input_by_model['cross_images']:
# inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
#
# # add any transformers params here.
# gen_kwargs = {"max_length": 2048,
# "temperature": 0.9,
# "do_sample": False}
# with torch.no_grad():
    #     outputs = model.generate(**inputs, **gen_kwargs)
    #     outputs = outputs[:, inputs['input_ids'].shape[1]:]
    #     response = tokenizer.decode(outputs[0])
    #     response = response.split("</s>")[0]
    #     print("\nCog:", response)
    # content = baiduTrans.common_translate(response)
    #
    # result = ocr.ocr(file, cls=True)
    # texts = []
    # for idx in range(len(result)):
    #     res = result[idx]
    #     if res is not None:
    #         for line in res:
    #             texts.append(line[1][0])
    # ocr_desc = ",".join(texts)
    # Active path: ask Lingyiwanwu's yi-vl-plus vision model to describe the image.
    # The prompt asks: "What does this picture show? If there is text in the image, also return that text."
    user_content = [
        {"type": "image_url", "image_url": {"url": cdnUrl}},
        {"type": "text", "text": "这张图片描述了什么内容,如果图片上面有文字,请同时回复文字的内容"}
    ]
    completion = client.chat.completions.create(
        model="yi-vl-plus",
        messages=[{"role": "user", "content": user_content}]
    )
    ai_desc = completion.choices[0].message.content
    print(ai_desc)
    fields = {
        "ai_desc": ai_desc,
        # "ocr_desc": ocr_desc
    }
    utils.update_data(conn, "img_source", image_id, fields)
    time.sleep(30)  # throttle to stay within the API's rate limits
Summary: the commented-out code generates the image description with the CogAgent-Chat model. Comparing the two, Lingyiwanwu's descriptions are noticeably richer, though they occasionally hallucinate details that are not in the image.
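Both this script and the ones below depend on a small `utils` module that is not shown in the post. Here is a minimal sketch of what it might look like, inferred from the calls above; the pymysql/pymilvus choices, connection parameters, and exact signatures are all assumptions, not the author's actual code.

# utils.py -- minimal sketch; hosts, credentials, and signatures are assumed
import pymysql
from pymilvus import connections


class DB:
    def __init__(self, conn):
        self.conn = conn

    def select(self, sql):
        # rows come back as dicts so callers can do v['id'], v['path'], ...
        with self.conn.cursor() as cur:
            cur.execute(sql)
            return cur.fetchall()

    def execute(self, sql, args=None):
        with self.conn.cursor() as cur:
            cur.execute(sql, args)
        self.conn.commit()


def db_connect(name):
    # assumed: "name" selects one of several configured environments
    conn = pymysql.connect(host="127.0.0.1", user="user", password="password",
                           database="demo", charset="utf8mb4",
                           cursorclass=pymysql.cursors.DictCursor)
    return DB(conn)


def update_data(conn, table, row_id, fields):
    # build "UPDATE <table> SET col=%s, ... WHERE id=%s" from the fields dict
    cols = ", ".join("{}=%s".format(k) for k in fields)
    sql = "UPDATE {} SET {} WHERE id=%s".format(table, cols)
    conn.execute(sql, list(fields.values()) + [row_id])


def milvus_connect(name):
    # assumed: "name" selects a config entry; connect under pymilvus's default alias
    connections.connect(host="127.0.0.1", port="19530")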
Step 2: the script that writes the vectors into the Milvus database. It loads m3e-base from a local directory, so the model needs to be downloaded once beforehand; assuming the moka-ai/m3e-base checkpoint on Hugging Face, that could look like this:
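# one-time download of the embedding model into ./m3e-base
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("moka-ai/m3e-base")
model.save("./m3e-base")

The indexing script itself: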
import time
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
from sentence_transformers import SentenceTransformer
import utils
from milvus.milvus_ann_search import search, insert, query
def run():
    conn_online = utils.db_connect("online")
    model = SentenceTransformer('./m3e-base')  # m3e-base sentence-embedding model
    utils.milvus_connect("milvus")
    collection_name = "img_source_new"
    sql = "select * from img_source where status=1 and path IS NOT NULL"
    res = conn_online.select(sql)
    for v in res:
        print(v['id'])
        content = []
        # title/ocr_desc/ai_desc may be NULL in the DB as well as "", so test truthiness
        if v['title']:
            content.append(v['title'])
        if v['ocr_desc']:
            content.append(v['ocr_desc'])
        if v['ai_desc']:
            content.append(v['ai_desc'])
        if not content:
            continue
        desc = ",".join(content)
        all_embeddings = model.encode([desc])
        create_time = int(time.time())
        # column-oriented insert: [vectors], [image ids], [timestamps]
        data = [
            [all_embeddings[0].tolist()],
            [v['id']],
            [create_time]
        ]
        mr = insert(collection_name, data)
        print(mr)


run()
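The `milvus.milvus_ann_search` helpers (`search`, `insert`, `query`) are also not shown in the post. A possible pymilvus-based sketch follows, with a collection schema inferred from the column order of the insert above (auto-generated primary key, a 768-dimensional `desc_vec` matching m3e-base's output, `img_id`, `create_time`); the index type, metric, and parameters are assumptions.

# milvus/milvus_ann_search.py -- a possible sketch; schema, metric, and index params are assumed
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType

DIM = 768  # m3e-base embedding dimension


def create_collection(name):
    # one-off setup matching the column order used by insert() below
    fields = [
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("desc_vec", DataType.FLOAT_VECTOR, dim=DIM),
        FieldSchema("img_id", DataType.INT64),
        FieldSchema("create_time", DataType.INT64),
    ]
    coll = Collection(name, CollectionSchema(fields))
    coll.create_index("desc_vec", {"index_type": "IVF_FLAT",
                                   "metric_type": "L2",
                                   "params": {"nlist": 128}})
    return coll


def insert(collection_name, data):
    # data is column-oriented: [vectors], [img_ids], [create_times]
    coll = Collection(collection_name)
    mr = coll.insert(data)
    coll.flush()
    return mr


def search(collection_name, anns_field, vectors, limit, expr, output_fields):
    coll = Collection(collection_name)
    coll.load()
    return coll.search(data=[[float(x) for x in v] for v in vectors],
                       anns_field=anns_field,
                       param={"metric_type": "L2", "params": {"nprobe": 10}},
                       limit=limit, expr=expr, output_fields=output_fields)


def query(collection_name, expr, output_fields):
    coll = Collection(collection_name)
    coll.load()
    return coll.query(expr=expr, output_fields=output_fields)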
Step 3: expose an API service that returns the 20 most similar images.
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import utils
from milvus.milvus_ann_search import search, insert, query
# conn_online = utils.db_connect("online")  # not used by this service; kept for reference
model = SentenceTransformer('./m3e-base')  # must be the same embedding model used at index time
utils.milvus_connect("milvus")


class SearchData(BaseModel):
    search: str = ""


app = FastAPI()


# image search: embed the query text and return the 20 nearest images
@app.post("/get_images/")
async def get_images(topic: SearchData):
    search_text = topic.search
    print("search=" + search_text)
    collection_name = "img_source_new"
    all_embeddings = model.encode([search_text])
    anns_field = "desc_vec"
    results = search(collection_name, anns_field, all_embeddings, 20, None, ['id', 'img_id'])
    arr_id = []
    for hits in results:
        for hit in hits:
            img_id = hit.entity.get('img_id')
            arr_id.append((img_id, hit.distance))
            print("img_id={}\tdistance={}".format(img_id, hit.distance))
    return {"msg": "success", "code": 200, "response": arr_id}
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8100)
The service can then be tested as shown below; you can also use the returned distance to control how similar the results must be, and skip hits whose distance is too large.
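For example, a quick client-side test (the host and port match the uvicorn settings above; the query string is just an illustrative example):

import requests

resp = requests.post("http://127.0.0.1:8100/get_images/",
                     json={"search": "手术机器人"})
data = resp.json()
print(data)  # {"msg": "success", "code": 200, "response": [[img_id, distance], ...]}

# Assuming a distance metric where smaller means more similar (e.g. L2),
# filtering out far-away hits is a one-liner; the cutoff value here is
# arbitrary and should be tuned on your own data.
MAX_DISTANCE = 1.0
close_hits = [(img_id, d) for img_id, d in data["response"] if d <= MAX_DISTANCE]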