The past couple of years have widely been dubbed the years of AI, so here is a hands-on write-up of the text-to-image model Stable Diffusion, recording the implementation process from 0 to 1.
The end result looks like this:
Note: for brevity, Stable Diffusion is abbreviated as SD below.
Environment
The environment used in this article:
- OS: Windows 11
- GPU: NVIDIA GeForce RTX 3060 Laptop GPU
- Python 3.10.6, for running SD
- Python 3.7, for running the bot
- Sublime Text 3
- go-cqhttp
- Yes酱
- stable-diffusion-webui
1. Running Stable Diffusion Locally
1.1 Download the source code
Clone the stable-diffusion-webui source with git:
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
If git is not installed on your machine, you can instead download the source archive from the stable-diffusion-webui GitHub page and extract it:
1.2 Install Python 3.10.6
The project officially recommends Python 3.10.6. In my tests, neither Python 3.7 nor 3.10.13 works (both fail with ModuleNotFoundError: No module named 'importlib.metadata'); other versions were not tried.
On Windows, simply download the installer and double-click it to install:
1.3 Update the GPU driver
Note: with an outdated driver, launching the project may fail with: AssertionError: Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check
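You can also check the Torch side of that error directly; a minimal check, assuming you run it in a Python environment where torch is installed (for example the webui's own venv):

import torch

# True means Torch can see and use the GPU; the AssertionError above
# corresponds to this returning False.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # e.g. "NVIDIA GeForce RTX 3060 Laptop GPU"
    print(torch.cuda.get_device_name(0))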
First, check your GPU model:
- Press Win+I to open Windows Settings
- In the search box on the left of the Settings page, type "Device Manager" and click the suggestion that appears
- In Device Manager, expand "Display adapters" and note the GPU model:
(screenshot: checking the GPU model)
- Since this machine has an NVIDIA card, go to the NVIDIA official website for the latest driver: enter the model on the page, click search, then locate the driver and download it:
(screenshots: NVIDIA driver search; driver download)
- Double-click the downloaded installer to install it
1.4 Download a base model (checkpoint) file
Download via Baidu Netdisk (extraction code: 8vz6)
Base model files end in .ckpt or .safetensors; after downloading, place the file in the stable-diffusion-webui/models/Stable-diffusion directory created in step 1.1.
1.5 Run SD
Go to the project directory and double-click webui-user.bat; it will download the dependencies and load the model automatically.
- If you hit the error Couldn't Install Torch, switch pip to a mirror with the command below, then double-click webui-user.bat again:
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
- If you hit the error:
OSError: Can't load tokenizer for 'openai/clip-vit-large-patch14'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'openai/clip-vit-large-patch14' is the correct path to a directory containing all relevant files for a CLIPTokenizer tokenizer.
manually download the clip-vit-large-patch14 package (extraction code: rvyc), create a folder named openai under the project root, extract the package into it, then double-click webui-user.bat again:
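After that step, the layout under the project root should look roughly like this (file list abbreviated; the exact contents depend on the downloaded package):

stable-diffusion-webui/
  openai/
    clip-vit-large-patch14/
      config.json
      tokenizer_config.json
      vocab.json
      merges.txt
      ...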
1.6 SD running successfully
After double-clicking webui-user.bat, a black cmd window appears; if you see output like the following, SD is up and running:
Now open a browser and visit http://127.0.0.1:7860/ to access SD: pick a checkpoint, type an image description, and click the Generate button on the right to produce an image (in my tests, the DreamShaper_8_pruned.safetensors model was the most efficient, at only about 4s per image).
PS: If you only wanted to learn how to run SD locally, you can stop here. The rest of the article covers: 1) having a bot call the SD API to generate images; and 2) adding Chinese-to-English translation so the bot understands Chinese prompts.
2. Hooking the SD API into a QQ Bot
2.1 Start SD in API mode
- In the SD directory, make a copy of webui-user.bat, rename it to webui-user-api.bat, and set COMMANDLINE_ARGS=--api (see the sketch after this list):
(screenshot: modified SD startup script)
- Double-click webui-user-api.bat to restart SD, then open http://127.0.0.1:7860/docs in a browser to reach SD's API documentation page:
(screenshot: SD API docs)
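For reference, the edited webui-user-api.bat would look roughly like this (a sketch based on the stock webui-user.bat layout):

@echo off
set PYTHON=
set GIT=
set VENV_DIR=
set COMMANDLINE_ARGS=--api
call webui.bat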
2.2 Calling the SD API from a script
- Create a script anywhere locally and call the SD API with code like the following:
import requests
import base64

# Define the URL and the payload to send.
url = "http://127.0.0.1:7860"
payload = {
    "prompt": "puppy dog",
    "steps": 5
}

# Send said payload to said URL through the API.
response = requests.post(url=f'{url}/sdapi/v1/txt2img', json=payload)
r = response.json()

# Decode and save the image.
with open("output.png", 'wb') as f:
    f.write(base64.b64decode(r['images'][0]))
- Run the script and the generated image output.png appears in the current directory:
(screenshot: SD API test run)
Clearly, the image from this bare-bones request is very blurry, nothing like what the web UI produces. So how do we get API-generated images up to web-UI quality?
- Open the SD web UI, right-click anywhere on the page, click Inspect, switch to the Network tab, then click the Generate button on the page and capture the request it sends:
(screenshot: capturing the SD request)
Inspecting the captured data, the payload the web UI sends is very close to what our API script sends. So we copy the field values from a request that produced good results in the web UI straight into the script:
payload = {
    "prompt": "puppy dog",
    "seed": 3646806933,
    "subseed": 3965033073,
    "subseed_strength": 0,
    "width": 512,
    "height": 512,
    "sampler_name": "DPM++ 2M",
    "cfg_scale": 7,
    "steps": 20,
    "batch_size": 1,
    "restore_faces": False,
    "face_restoration_model": None,
    "sd_model_name": "DreamShaper_8_pruned",
    "sd_model_hash": "879db523c3",
    "sd_vae_name": None,
    "sd_vae_hash": None,
    "seed_resize_from_w": -1,
    "seed_resize_from_h": -1,
    "denoising_strength": 0.7
}
Among the models provided above, DreamShaper_8_pruned proved the fastest in my tests, averaging only about 4s per image.
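If you want the sd_model_name field to match a checkpoint your local instance actually has, the API documented on the /docs page above also exposes a model-listing endpoint; a minimal sketch:

import requests

url = "http://127.0.0.1:7860"
# List the checkpoints known to the local webui instance.
for m in requests.get(f"{url}/sdapi/v1/sd-models").json():
    print(m["model_name"], m["hash"])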
2.3 Integrating a translation API
SD only understands English, so generating images from Chinese prompts requires translating them first. I use the Tencent Cloud translation API, which includes a free quota of 5 million characters per month, generally enough.
# -*- coding: utf-8 -*-
import hashlib
import hmac
import json
import time
from datetime import datetime
from http.client import HTTPSConnection


def sign(key, msg):
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


global_config = {
    "tx_sid": "xxx",
    "tx_skey": "xxx"
}


def tx_translate(msg, s="zh", t="en"):
    secret_id = global_config["tx_sid"]
    secret_key = global_config["tx_skey"]
    token = ""
    service = "tmt"
    host = "tmt.tencentcloudapi.com"
    region = "ap-beijing"
    version = "2018-03-21"
    action = "TextTranslate"
    params = {
        "SourceText": msg,
        "Source": s,
        "Target": t,
        "ProjectId": 0
    }
    payload = json.dumps(params)
    algorithm = "TC3-HMAC-SHA256"
    timestamp = int(time.time())
    date = datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d")

    # ************* Step 1: build the canonical request string *************
    http_request_method = "POST"
    canonical_uri = "/"
    canonical_querystring = ""
    ct = "application/json; charset=utf-8"
    canonical_headers = "content-type:%s\nhost:%s\nx-tc-action:%s\n" % (ct, host, action.lower())
    signed_headers = "content-type;host;x-tc-action"
    hashed_request_payload = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    canonical_request = (http_request_method + "\n" +
                         canonical_uri + "\n" +
                         canonical_querystring + "\n" +
                         canonical_headers + "\n" +
                         signed_headers + "\n" +
                         hashed_request_payload)

    # ************* Step 2: build the string to sign *************
    credential_scope = date + "/" + service + "/" + "tc3_request"
    hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
    string_to_sign = (algorithm + "\n" +
                      str(timestamp) + "\n" +
                      credential_scope + "\n" +
                      hashed_canonical_request)

    # ************* Step 3: compute the signature *************
    secret_date = sign(("TC3" + secret_key).encode("utf-8"), date)
    secret_service = sign(secret_date, service)
    secret_signing = sign(secret_service, "tc3_request")
    signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

    # ************* Step 4: assemble the Authorization header *************
    authorization = (algorithm + " " +
                     "Credential=" + secret_id + "/" + credential_scope + ", " +
                     "SignedHeaders=" + signed_headers + ", " +
                     "Signature=" + signature)

    # ************* Step 5: build and send the request *************
    headers = {
        "Authorization": authorization,
        "Content-Type": "application/json; charset=utf-8",
        "Host": host,
        "X-TC-Action": action,
        "X-TC-Timestamp": timestamp,
        "X-TC-Version": version
    }
    if region:
        headers["X-TC-Region"] = region
    if token:
        headers["X-TC-Token"] = token
    try:
        req = HTTPSConnection(host)
        req.request("POST", "/", headers=headers, body=payload.encode("utf-8"))
        resp = req.getresponse()
        return json.loads(resp.read())['Response']['TargetText']
    except Exception as err:
        print(f"[tx_translate] translate [{msg}] from {s} to {t} failed: {err}")
        return msg
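A quick sanity check of the helper, assuming real credentials have been filled into global_config:

if __name__ == "__main__":
    # Expected to print an English rendering of the Chinese prompt,
    # e.g. roughly "a puppy wearing sunglasses"
    print(tx_translate("一只戴墨镜的小狗"))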
2.4 Integrating the SD API into the QQ bot
The bot is built on Yes酱, which is extended by writing an extra module.
The module code for this part is as follows:
import os
import re
import base64
import time
import requests
import hashlib
from random import choice
from common.common import logger
from common.config import global_config
from data.talk_data.base_talk import others_answer
from send_message.send_message import send_message
from .tx_translate import tx_translate


def has_chinese(string):
    # True if the string contains at least one CJK character
    pattern = re.compile(u'[\u4e00-\u9fa5]+')
    result = re.search(pattern, string)
    return bool(result)


def truncate_string(pattern, text):
    # Match the pattern at the start of the string
    match = re.match(pattern, text)
    if match:
        # On a match, return the remainder after the matched prefix
        return True, text[len(match.group()):]
    else:
        # No match: return the original string unchanged
        return False, text


def stable_diffusion(msg, sender, ws, group_id, diatype):
    # Trigger prefixes: 文生图 (text-to-image) / 生成图片 (generate image) / 画图 (draw)
    matched, processed_msg = truncate_string(r"^(文生图|生成图片|画图)", msg)
    if not matched:
        return [False, None]
    try:
        msg = processed_msg
        # "[This answer was provided by AI: Stable Diffusion]"
        msg_flag = "\n【该回答由AI:Stable Diffusion提供】"
        msg_md5 = hashlib.md5(msg.encode(encoding='utf-8')).hexdigest()
        save_path = global_config.sd_path + f"{msg_md5}.png"
        # Reuse the cached image if this exact prompt was generated before
        if os.path.exists(save_path):
            local_img_url = "[CQ:image,file=file:///" + save_path + "]"
            return [True, local_img_url + msg_flag]
        # SD only understands English, so translate Chinese prompts first
        if has_chinese(msg):
            msg = tx_translate(msg)
        payload = {
            "prompt": msg,
            "seed": 3646806933,
            "subseed": 3965033073,
            "subseed_strength": 0,
            "width": 512,
            "height": 512,
            "sampler_name": "DPM++ 2M",
            "cfg_scale": 7,
            "steps": 20,
            "batch_size": 1,
            "restore_faces": False,
            "face_restoration_model": None,
            "sd_model_name": "DreamShaper_8_pruned",
            "sd_model_hash": "879db523c3",
            "sd_vae_name": None,
            "sd_vae_hash": None,
            "seed_resize_from_w": -1,
            "seed_resize_from_h": -1,
            "denoising_strength": 0.7
        }
        start_time = time.time()
        # Send the payload to the SD txt2img endpoint.
        response = requests.post(url=f'{global_config.sd_url}/sdapi/v1/txt2img', json=payload)
        r = response.json()
        end_time = time.time()
        logger.debug(f"[stable_diffusion] from msg: {msg}, generated img: {msg_md5}, cost: {end_time - start_time}s")
        # Decode and save the image.
        with open(save_path, 'wb') as f:
            f.write(base64.b64decode(r['images'][0]))
        local_img_url = "[CQ:image,file=file:///" + save_path + "]"
        return [True, local_img_url + msg_flag]
    except Exception as e:
        logger.error(f"in request 2 sd, an error occurred: {e}")
        # "uh oh, something went wrong~"
        return [False, "啊这,出了一点问题~"]
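For a quick smoke test of the entry point without the full bot running (a hypothetical standalone call; it assumes the framework imports above resolve, and the sender/ws/group_id/diatype arguments are unused inside this function, so stubs suffice):

ok, reply = stable_diffusion("画图一只戴帽子的猫", None, None, None, None)
print(ok, reply)  # on success: True plus a [CQ:image,...] reply string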