1. Software Preparation
1.1 Download Anaconda
Anaconda is an open-source Python distribution that bundles conda, Python, and more than 180 scientific packages along with their dependencies. Its conda tool provides virtual-environment management, which makes building AI environments particularly convenient.
Download it from the official site https://www.anaconda.com/download (a proxy may be required); enter your email address on the page and the site will send a download link to your inbox.
Alternatively, you can download it from Baidu Netdisk:
Link: https://pan.baidu.com/s/1KTbeurKl-mZUq14mF1LgKA?pwd=6qs7  Extraction code: 6qs7
1.2 Download GOT-OCR2.0
Download it from GitHub: https://github.com/Ucas-HaoranWei/GOT-OCR2.0 (a proxy may be required)
1.3 Download the GOT-OCR GPU model weights
- BaiduYun, code: OCR2
1.4 Download CUDA 11.7
wget https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run
If you need a different version, pick it from the CUDA Toolkit archive:
https://developer.nvidia.com/cuda-toolkit-archive
2. Install CUDA
sudo chmod +x cuda_11.7.1_515.65.01_linux.run
sudo ./cuda_11.7.1_515.65.01_linux.run
An EULA prompt appears; type accept to continue.
On the component-selection screen, use the up/down arrow keys to move the cursor and Enter to toggle a component. Since my machine already had the NVIDIA driver installed, I selected only the Toolkit; select everything if you need it.
Move the cursor to Install and press Enter to start the installation.
Configure ~/.bashrc
vim ~/.bashrc
## Append the following lines to the end of the file
export PATH=/usr/local/cuda-11.7/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64:$LD_LIBRARY_PATH
## Reload the bashrc file
source ~/.bashrc
## Check the nvcc version
nvcc -V
If nvcc prints its release information (release 11.7), the installation is working.
3. Install Anaconda (these steps were captured on a CPU machine; adjust the directories and similar details to your own setup)
Upload the installer to the server and cd into the corresponding directory.
## Run the command below; the installer opens a long license agreement. Keep pressing Enter until the acceptance prompt appears.
bash Anaconda3-2024.10-1-Linux-x86_64.sh
Confirm the installation directory; ideally type a path on a partition with plenty of free space.
Then comes a long wait; when the final prompt appears, type yes and the installation is complete.
## Reload the shell configuration
source ~/.bashrc
## Verify the installation; seeing a version number printed is enough
conda -V
4. Build the virtual environment and install GOT-OCR
Create the virtual environment
## Create a virtual environment named mygot with Python 3.10
conda create -n mygot python=3.10 -y
## Activate the virtual environment
conda activate mygot
## After activation, the environment name appears at the start of the prompt, like this:
(mygot) [root@sybj-int-83 ollamaModels]#
Install GOT-OCR
## Change to the installation directory
cd /opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/
## Run the editable install from inside the directory
pip install -e .
## Install albumentations; without it the demo fails with a missing-module error
pip install albumentations==1.4.20
If the install finishes without errors, the setup succeeded; a quick import check is sketched below.
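Before moving on, a minimal import check (my own addition, not part of the official steps) can confirm that the editable install and albumentations are both usable inside the mygot environment; GOTQwenForCausalLM is imported here because the demo script later in this post loads it through GOT.model:
# check_install.py -- run inside the activated mygot environment
import albumentations
from GOT.model import GOTQwenForCausalLM  # provided by the `pip install -e .` step above

print("albumentations", albumentations.__version__)  # expect 1.4.20
print("GOT import OK:", GOTQwenForCausalLM.__name__)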
I could not find a suitable flash-attn build, so I skipped the two optional components mentioned in the official instructions. With that, the installation is done; test recognition directly:
python3 /opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py --model-name /opt/ollamaModels/GOT_weights/ --image-file /opt/ollamaModels/WechatIMG436.jpg --type format
It complained that the CUDA version was too low. Checking further, the torch version was too high: 2.6.0 had been installed by default, while 2.0.1 is what is needed.
Run the following commands to bring the components to matching versions (a quick sanity check follows the list):
pip install torch==2.0.1
pip install torchvision==0.15.2
pip install transformers==4.37.2
pip install tiktoken==0.6.0
pip install verovio==4.3.1
pip install accelerate==0.28.0
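After the versions are pinned, a quick check (my own addition) confirms that the expected versions are active and that torch can see the GPU before re-running the demo:
# check_versions.py -- verify the pinned packages and GPU visibility
import torch
import torchvision
import transformers

print("torch", torch.__version__)                # expect 2.0.1
print("torchvision", torchvision.__version__)    # expect 0.15.2
print("transformers", transformers.__version__)  # expect 4.37.2
print("CUDA available:", torch.cuda.is_available())
print("CUDA build:", torch.version.cuda)         # CUDA version this torch wheel was built against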
Testing again, it now reported that GPU memory was insufficient.
nvidia-smi
Check which processes are holding the GPU. In my case Ollama was serving a DeepSeek 8B model; I killed that process, re-ran the test, and the result came out. A small Python check of free VRAM is sketched below.
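If you prefer to confirm free GPU memory from Python rather than reading nvidia-smi, a small check like the following works (my own sketch; it only relies on the torch installed above):
# check_vram.py -- print free vs. total GPU memory before loading the weights
import torch

free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"free {free_bytes / 1024**3:.1f} GiB / total {total_bytes / 1024**3:.1f} GiB")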
My NVIDIA driver version is fairly old, so some operations in the default run_ocr_2.0.py are not supported and the script errors out. Below is my modified copy of the file, with the changed parts annotated.
---- modified version
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
import string
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
translation_table = str.maketrans(punctuation_dict)
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
# changed to float16 ---- modified
model.to(device='cuda', dtype=torch.float16)
# TODO vary old codes, NEED del
image_processor = BlipImageEvalProcessor(image_size=1024)
image_processor_high = BlipImageEvalProcessor(image_size=1024)
use_im_start_end = True
image_token_len = 256
image = load_image(args.image_file)
w, h = image.size
# print(image.size)
if args.type == 'format':
qs = 'OCR with format: '
else:
qs = 'OCR: '
if args.box:
bbox = eval(args.box)
if len(bbox) == 2:
bbox[0] = int(bbox[0]/w*1000)
bbox[1] = int(bbox[1]/h*1000)
if len(bbox) == 4:
bbox[0] = int(bbox[0]/w*1000)
bbox[1] = int(bbox[1]/h*1000)
bbox[2] = int(bbox[2]/w*1000)
bbox[3] = int(bbox[3]/h*1000)
if args.type == 'format':
qs = str(bbox) + ' ' + 'OCR with format: '
else:
qs = str(bbox) + ' ' + 'OCR: '
if args.color:
if args.type == 'format':
qs = '[' + args.color + ']' + ' ' + 'OCR with format: '
else:
qs = '[' + args.color + ']' + ' ' + 'OCR: '
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mode = "mpt"
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)
inputs = tokenizer([prompt])
# vary old codes, no use
image_1 = image.copy()
image_tensor = image_processor(image)
image_tensor_1 = image_processor_high(image_1)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# changed to float16 ---- modified
with torch.autocast("cuda", dtype=torch.float16):
output_ids = model.generate(
input_ids,
images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
do_sample=False,
num_beams = 1,
no_repeat_ngram_size = 20,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
if args.render:
print('==============rendering===============')
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
if '**kern' in outputs:
import verovio
from cairosvg import svg2png
import cv2
import numpy as np
tk = verovio.toolkit()
tk.loadData(outputs)
tk.setOptions({"pageWidth": 2100, "footer": 'none',
'barLineWidth': 0.5, 'beamMaxSlope': 15,
'staffLineWidth': 0.2, 'spacingStaff': 6})
tk.getPageCount()
svg = tk.renderToSVG()
svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
# changed to an absolute path ---- modified
svg_to_html(svg, "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results/demo.html")
if args.type == 'format' and '**kern' not in outputs:
if '\\begin{tikzpicture}' not in outputs:
# changed to an absolute path ---- modified
html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "content-mmd-to-html.html"
html_path_2 = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results/demo.html"
right_num = outputs.count('\\right')
left_num = outputs.count('\left')
if right_num != left_num:
outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
outputs = outputs.replace('"', '``').replace('$', '')
outputs_list = outputs.split('\n')
gt= ''
for out in outputs_list:
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
gt = gt[:-2]
with open(html_path, 'r') as web_f:
lines = web_f.read()
lines = lines.split("const text =")
new_web = lines[0] + 'const text =' + gt + lines[1]
else:
# changed to an absolute path ---- modified
html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "tikz.html"
html_path_2 = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results/demo.html"
outputs = outputs.translate(translation_table)
outputs_list = outputs.split('\n')
gt= ''
for out in outputs_list:
if out:
if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
while out[-1] == ' ':
out = out[:-1]
if out is None:
break
if out:
if out[-1] != ';':
gt += out[:-1] + ';\n'
else:
gt += out + '\n'
else:
gt += out + '\n'
with open(html_path, 'r') as web_f:
lines = web_f.read()
lines = lines.split("const text =")
new_web = lines[0] + gt + lines[1]
with open(html_path_2, 'w') as web_f_new:
web_f_new.write(new_web)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--type", type=str, required=True)
parser.add_argument("--box", type=str, default= '')
parser.add_argument("--color", type=str, default= '')
parser.add_argument("--render", action='store_true')
args = parser.parse_args()
eval_model(args)
5. API Service
Copy the content below into a file named GPU_API.py, switch to the virtual environment, and start the API service with python3 GPU_API.py. If it reports errors, install whatever component the message asks for. To start it in the background:
nohup conda run -n mygot python GPU_API.py > got-ocr-api.log 2>&1 &
Note that the command must be run from the directory containing the script, and the absolute paths inside the script must be replaced with the paths on your own server. The formatted text it returns is Mathpix Markdown.
from flask import Flask, request, jsonify, send_from_directory
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
import time
import random
import logging
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler
import uuid
app = Flask(__name__)
# Set up logging
log_dir = "/opt/ollamaModels/logs/"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "got_ocr.log")
handler = TimedRotatingFileHandler(log_file, when='midnight', interval=1, backupCount=30)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
translation_table = str.maketrans(punctuation_dict)
# Load model and other resources only once when the server starts
disable_torch_init()
model_name = os.path.expanduser("/opt/ollamaModels/GOT_weights/") # Update with your actual model path
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
model.to(device='cuda', dtype=torch.float16)
image_processor = BlipImageEvalProcessor(image_size=1024)
image_processor_high = BlipImageEvalProcessor(image_size=1024)
use_im_start_end = True
image_token_len = 256
conv_mode = "mpt"
def load_image(file_storage):
image = Image.open(file_storage.stream).convert('RGB')
return image
def generate_unique_filename(extension=".html"):
timestamp = int(time.time())
random_number = random.randint(1000, 9999)
filename = f"{timestamp}_{random_number}{extension}"
return filename
@app.before_request
def log_request_info():
if request.endpoint == 'ocr':
trace_id = request.headers.get('X-Trace-ID') or str(uuid.uuid4())
request.environ['trace_id'] = trace_id
logger.info(f"Request: {request.method} {request.url} TraceID: {trace_id}")
if request.data:
logger.info(f"Request Data: {request.data} TraceID: {trace_id}")
# Log form data
form_data = request.form.to_dict()
if form_data:
logger.info(f"Form Data: {form_data} TraceID: {trace_id}")
# Log files
files = request.files.to_dict()
if files:
file_names = [file.filename for file in files.values()]
logger.info(f"Files Uploaded: {file_names} TraceID: {trace_id}")
@app.after_request
def log_response_info(response):
if request.endpoint == 'ocr':
trace_id = request.environ.get('trace_id')
logger.info(f"Response Status: {response.status} TraceID: {trace_id}, Response Data: {response.get_data(as_text=True)}")
return response
@app.route('/ocr', methods=['POST'])
def ocr():
if 'image_file' not in request.files:
return jsonify({"error": "No file part"}), 400
file_storage = request.files['image_file']
if file_storage.filename == '':
return jsonify({"error": "No selected file"}), 400
type_ = request.form.get('type')
box = request.form.get('box', '')
color = request.form.get('color', '')
render = request.form.get('render', '').lower() in ['true', '1']
image = load_image(file_storage)
w, h = image.size
if type_ == 'format':
qs = 'OCR with format: '
else:
qs = 'OCR: '
if box:
bbox = eval(box)
if len(bbox) == 2:
bbox[0] = int(bbox[0]/w*1000)
bbox[1] = int(bbox[1]/h*1000)
if len(bbox) == 4:
bbox[0] = int(bbox[0]/w*1000)
bbox[1] = int(bbox[1]/h*1000)
bbox[2] = int(bbox[2]/w*1000)
bbox[3] = int(bbox[3]/h*1000)
if type_ == 'format':
qs = str(bbox) + ' ' + 'OCR with format: '
else:
qs = str(bbox) + ' ' + 'OCR: '
if color:
if type_ == 'format':
qs = '[' + color + ']' + ' ' + 'OCR with format: '
else:
qs = '[' + color + ']' + ' ' + 'OCR: '
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
args_conv_mode = conv_mode
conv = conv_templates[args_conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
image_1 = image.copy()
image_tensor = image_processor(image)
image_tensor_1 = image_processor_high(image_1)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.float16):
output_ids = model.generate(
input_ids,
images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
do_sample=False,
num_beams=1,
no_repeat_ngram_size=20,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
if render:
html_path_2 = os.path.join("/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results", generate_unique_filename())
if '**kern' in outputs:
import verovio
from cairosvg import svg2png
import cv2
import numpy as np
tk = verovio.toolkit()
tk.loadData(outputs)
tk.setOptions({"pageWidth": 2100, "footer": 'none',
'barLineWidth': 0.5, 'beamMaxSlope': 15,
'staffLineWidth': 0.2, 'spacingStaff': 6})
tk.getPageCount()
svg = tk.renderToSVG()
svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
svg_to_html(svg, html_path_2)
if type_ == 'format' and '**kern' not in outputs:
if '\\begin{tikzpicture}' not in outputs:
html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "content-mmd-to-html.html"
right_num = outputs.count('\\right')
left_num = outputs.count('\left')
if right_num != left_num:
outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
outputs = outputs.replace('"', '``').replace('$', '')
outputs_list = outputs.split('\n')
gt = ''
for out in outputs_list:
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
gt = gt[:-2]
with open(html_path, 'r') as web_f:
lines = web_f.read()
lines = lines.split("const text =")
new_web = lines[0] + 'const text =' + gt + lines[1]
else:
html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "tikz.html"
outputs = outputs.translate(translation_table)
outputs_list = outputs.split('\n')
gt = ''
for out in outputs_list:
if out:
if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
while out[-1] == ' ':
out = out[:-1]
if out is None:
break
if out:
if out[-1] != ';':
gt += out[:-1] + ';\n'
else:
gt += out + '\n'
else:
gt += out + '\n'
with open(html_path, 'r') as web_f:
lines = web_f.read()
lines = lines.split("const text =")
new_web = lines[0] + gt + lines[1]
with open(html_path_2, 'w') as web_f_new:
web_f_new.write(new_web)
base_url = "http://188.1.1.122:5000/results/"
full_html_path = base_url + os.path.basename(html_path_2)
return jsonify({"result": "Rendering complete", "outputs": outputs, "html_path": full_html_path})
else:
return jsonify({"result": "Processing complete", "outputs": outputs})
@app.route('/results/<path:filename>', methods=['GET'])
def serve_result(filename):
logger.info(f"Request: {request.method} {request.url}")
results_dir = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results"
return send_from_directory(results_dir, filename)
if __name__ == "__main__":
app.run(host='0.0.0.0', port=5000)
Example API call (a Python equivalent follows the curl example):
curl --location 'http://1.1.1.1:5000/ocr' \
--form 'type="format"' \
--form 'image_file=@"/Users/zhiaiyahong/Downloads/wodeshefenzheng_副本.jpg"' \
--form 'render="true"'
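For reference, an equivalent Python client is sketched below (my own addition; the host, port, and image path are placeholders to replace with your own). The form fields and the response keys (outputs, html_path) follow the GPU_API.py code above:
# ocr_client.py -- call the /ocr endpoint and print the recognized text
import requests

url = "http://1.1.1.1:5000/ocr"  # replace with your server address
with open("/path/to/test.jpg", "rb") as f:  # placeholder image path
    resp = requests.post(
        url,
        files={"image_file": f},
        data={"type": "format", "render": "true"},
    )
resp.raise_for_status()
payload = resp.json()
print(payload["outputs"])        # Mathpix Markdown text
print(payload.get("html_path"))  # rendered HTML URL, present when render=true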