Deploying GOT-OCR (GPU version) on Ubuntu 18.04

1. Software preparation

1.1 Download Anaconda

Anaconda is an open-source Python distribution that bundles conda, Python, and more than 180 scientific packages plus their dependencies. Its conda tool provides virtual environment management, which makes setting up AI environments very convenient.
You can download it from the official site at https://www.anaconda.com/download (a proxy may be required from mainland China). Enter your email address on the page and the site will send a download link to your inbox.


Alternatively, the installer can be downloaded from Baidu Netdisk:
Link: https://pan.baidu.com/s/1KTbeurKl-mZUq14mF1LgKA?pwd=6qs7  Extraction code: 6qs7
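
If you prefer to fetch the installer directly on the server, the Anaconda release archive can be used; a sketch (the URL pattern and version below are assumptions, check https://repo.anaconda.com/archive/ for the release you actually want):

##download the installer straight to the server (version/URL assumed, verify on repo.anaconda.com/archive)
wget https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh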

1.2 Download GOT-OCR2.0

Download it from GitHub: https://github.com/Ucas-HaoranWei/GOT-OCR2.0 (a proxy may be required from mainland China).
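
If the server has direct GitHub access, cloning works too. Note that the directory names used later in this article (GOT-OCR2.0-main/GOT-OCR-2.0-master) come from unpacking the GitHub zip; a clone produces GOT-OCR2.0/ instead, so adjust the later paths accordingly:

##clone the source instead of downloading the zip (adjust later paths to match)
cd /opt/ollamaModels/
git clone https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git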

1.3 Download the GOT-OCR GPU model weights
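
The test commands below expect the weights in /opt/ollamaModels/GOT_weights/. One way to fetch them is from Hugging Face with huggingface-cli; the repo id here is an assumption, so check the GOT-OCR2.0 README for the official weight link before downloading:

##download the GPU weights into the directory used later in this article (repo id assumed, verify against the README)
pip install -U huggingface_hub
huggingface-cli download ucaslcl/GOT-OCR2_0 --local-dir /opt/ollamaModels/GOT_weights/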

1.4 Download CUDA 11.7

wget https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run
If you need a different version, pick it from the CUDA Toolkit archive:
https://developer.nvidia.com/cuda-toolkit-archive


2. Install CUDA

sudo chmod +x cuda_11.7.1_515.65.01_linux.run
sudo ./cuda_11.7.1_515.65.01_linux.run

The installer opens a text-mode page asking you to accept the EULA; type accept.


On the component selection screen, use the up/down arrow keys to move the cursor and Enter to toggle a selection. Since this machine already has the NVIDIA driver installed, I selected only the Toolkit; select everything if you also need the driver.

Move the cursor to Install and press Enter to start the installation.

Configure ~/.bashrc

vim ~/.bashrc
##append the following two lines to the end of the file
export PATH=/usr/local/cuda-11.7/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64:$LD_LIBRARY_PATH
##reload .bashrc
source ~/.bashrc
##check the nvcc version
nvcc -V

If nvcc -V prints the CUDA 11.7 release information, the installation is working.



3. Install Anaconda (adjust directories and the like to your own environment)

Upload the installer to the server and cd into the directory where you put it.

##run the installer; it walks you through a long license text, keep pressing Enter until you are asked to accept
bash Anaconda3-2024.10-1-Linux-x86_64.sh
Type yes when asked to accept the license.

Confirm the installation directory; it is best to type the path manually and pick a location with plenty of disk space.



Then comes a long wait; when the installer finally prompts again (asking whether to initialize conda), type yes and the installation is complete.


##reload the shell configuration
source ~/.bashrc
##verify the installation; a version number should be printed
conda -V

4. Create the virtual environment and install GOT-OCR

Create the virtual environment

##create a virtual environment named mygot with Python 3.10
conda create -n mygot python=3.10 -y
##activate it
conda activate mygot
##after activation the environment name appears at the start of the prompt, like this:
(mygot) [root@sybj-int-83 ollamaModels]#

Install GOT-OCR

##switch to the source directory
cd /opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/
##install the package in editable mode
pip install -e .
##also install albumentations, otherwise the demo fails with a missing-module error
pip install albumentations==1.4.20

If pip finishes with a "Successfully installed ..." message, the installation succeeded.
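
As an extra sanity check (not part of the official steps), the package should now be importable inside the mygot environment; the module path below is taken from the demo script further down:

##smoke test the editable install
python3 -c "from GOT.utils.utils import disable_torch_init; print('GOT import OK')"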



I could not find a flash-attn build that fits this environment, so I skipped the two optional components the official README mentions (flash-attn being one of them). With that, the installation is done; test recognition directly:

python3 /opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py  --model-name  /opt/ollamaModels/GOT_weights/ --image-file  /opt/ollamaModels/WechatIMG436.jpg  --type format

It failed with a "CUDA version too low" error. On inspection, the installed torch was too new: pip had pulled in 2.6.0 by default, whereas 2.0.1 is what this setup needs.


Run the following commands to pin the related packages to matching versions:

pip install torch==2.0.1            
pip install torchvision==0.15.2     
pip install transformers==4.37.2    
pip install tiktoken==0.6.0         
pip install verovio==4.3.1          
pip install accelerate==0.28.0      
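
Before re-running the demo, it is worth confirming that this torch build can actually see the GPU (a quick check, not part of the original steps):

##verify the torch version and CUDA availability inside the mygot environment
python3 -c "import torch; print(torch.__version__, torch.cuda.is_available())"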

Testing again, it now reports insufficient GPU memory.


Run nvidia-smi to see which processes are holding the GPU. In my case it was ollama serving a DeepSeek 8B model, so I simply killed that process.
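
A minimal sketch of that cleanup; <PID> stands for whatever process id nvidia-smi actually reports on your machine:

##list GPU processes, then kill the one holding the memory
nvidia-smi
kill -9 <PID>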

Run the test again and the recognition result is printed.
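
The demo script (shown below) also accepts --box, --color, and --render. With --render, the formatted result is additionally written to results/demo.html under the repository directory; a sketch using the same paths as above:

python3 /opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py --model-name /opt/ollamaModels/GOT_weights/ --image-file /opt/ollamaModels/WechatIMG436.jpg --type format --render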

My NVIDIA driver and GPU are fairly old, so some of the operations the stock run_ocr_2.0.py uses are not supported and the script errors out at runtime. Here is my modified copy of the file; every change is marked with a "# ---- modified" comment.

import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria

from PIL import Image

import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor

from transformers import TextStreamer
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
import string

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'

DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'


 
translation_table = str.maketrans(punctuation_dict)


def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(args):
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)


    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()

    
    # ---- modified: use float16 instead of bfloat16
    model.to(device='cuda',  dtype=torch.float16)


    # TODO vary old codes, NEED del 
    image_processor = BlipImageEvalProcessor(image_size=1024)

    image_processor_high =  BlipImageEvalProcessor(image_size=1024)

    use_im_start_end = True

    image_token_len = 256

    image = load_image(args.image_file)

    w, h = image.size
    # print(image.size)
    
    if args.type == 'format':
        qs = 'OCR with format: '
    else:
        qs = 'OCR: '

    if args.box:
        bbox = eval(args.box)
        if len(bbox) == 2:
            bbox[0] = int(bbox[0]/w*1000)
            bbox[1] = int(bbox[1]/h*1000)
        if len(bbox) == 4:
            bbox[0] = int(bbox[0]/w*1000)
            bbox[1] = int(bbox[1]/h*1000)
            bbox[2] = int(bbox[2]/w*1000)
            bbox[3] = int(bbox[3]/h*1000)
        if args.type == 'format':
            qs = str(bbox) + ' ' + 'OCR with format: '
        else:
            qs = str(bbox) + ' ' + 'OCR: '

    if args.color:
        if args.type == 'format':
            qs = '[' + args.color + ']' + ' ' + 'OCR with format: '
        else:
            qs = '[' + args.color + ']' + ' ' + 'OCR: '

    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs 
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs



    conv_mode = "mpt"
    args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    print(prompt)


    inputs = tokenizer([prompt])


    # vary old codes, no use
    image_1 = image.copy()
    image_tensor = image_processor(image)


    image_tensor_1 = image_processor_high(image_1)


    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # ---- modified: run generation under float16 autocast
    with torch.autocast("cuda", dtype=torch.float16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams = 1,
            no_repeat_ngram_size = 20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
            )
        

        if args.render:
            print('==============rendering===============')

            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
            
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            if '**kern' in outputs:
                import verovio
                from cairosvg import svg2png
                import cv2
                import numpy as np
                tk = verovio.toolkit()
                tk.loadData(outputs)
                tk.setOptions({"pageWidth": 2100, "footer": 'none',
               'barLineWidth': 0.5, 'beamMaxSlope': 15,
               'staffLineWidth': 0.2, 'spacingStaff': 6})
                tk.getPageCount()
                svg = tk.renderToSVG()
                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
                # ---- modified: use an absolute output path
                svg_to_html(svg, "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results/demo.html")

            if args.type == 'format' and '**kern' not in outputs:

                
                if  '\\begin{tikzpicture}' not in outputs:
                    # ---- modified: use absolute paths
                    html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "content-mmd-to-html.html"
                    html_path_2 = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results/demo.html"
                    right_num = outputs.count('\\right')
                    left_num = outputs.count('\left')

                    if right_num != left_num:
                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')


                    outputs = outputs.replace('"', '``').replace('$', '')

                    outputs_list = outputs.split('\n')
                    gt= ''
                    for out in outputs_list:
                        gt +=  '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n' 
                    
                    gt = gt[:-2]

                    with open(html_path, 'r') as web_f:
                        lines = web_f.read()
                        lines = lines.split("const text =")
                        new_web = lines[0] + 'const text ='  + gt  + lines[1]
                else:
                    # ---- modified: use absolute paths
                    html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "tikz.html"
                    html_path_2 = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results/demo.html"
                    outputs = outputs.translate(translation_table)
                    outputs_list = outputs.split('\n')
                    gt= ''
                    for out in outputs_list:
                        if out:
                            if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
                                # strip trailing spaces (guarded so a whitespace-only line cannot raise IndexError)
                                while out and out[-1] == ' ':
                                    out = out[:-1]
    
                                if out:
                                    if out[-1] != ';':
                                        gt += out[:-1] + ';\n'
                                    else:
                                        gt += out + '\n'
                            else:
                                gt += out + '\n'


                    with open(html_path, 'r') as web_f:
                        lines = web_f.read()
                        lines = lines.split("const text =")
                        new_web = lines[0] + gt + lines[1]

                with open(html_path_2, 'w') as web_f_new:
                    web_f_new.write(new_web)





if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--type", type=str, required=True)
    parser.add_argument("--box", type=str, default= '')
    parser.add_argument("--color", type=str, default= '')
    parser.add_argument("--render", action='store_true')
    args = parser.parse_args()

    eval_model(args)

5. API service

Copy the code below into a file named GPU_API.py. Activate the virtual environment and start the API service with python3 GPU_API.py; if startup reports a missing module, install whatever component the error message names. To run it in the background, use nohup conda run -n mygot python GPU_API.py > got-ocr-api.log 2>&1 & (run this from the directory that contains the script). Also replace the absolute paths inside the script with the paths on your own server. The formatted text returned by the API is Mathpix Markdown.

from flask import Flask, request, jsonify, send_from_directory
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
import time
import random
import logging
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler
import uuid

app = Flask(__name__)

# Set up logging
log_dir = "/opt/ollamaModels/logs/"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "got_ocr.log")

handler = TimedRotatingFileHandler(log_file, when='midnight', interval=1, backupCount=30)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

translation_table = str.maketrans(punctuation_dict)

# Load model and other resources only once when the server starts
disable_torch_init()
model_name = os.path.expanduser("/opt/ollamaModels/GOT_weights/")  # Update with your actual model path

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
model.to(device='cuda', dtype=torch.float16)

image_processor = BlipImageEvalProcessor(image_size=1024)
image_processor_high = BlipImageEvalProcessor(image_size=1024)

use_im_start_end = True
image_token_len = 256

conv_mode = "mpt"

def load_image(file_storage):
    image = Image.open(file_storage.stream).convert('RGB')
    return image

def generate_unique_filename(extension=".html"):
    timestamp = int(time.time())
    random_number = random.randint(1000, 9999)
    filename = f"{timestamp}_{random_number}{extension}"
    return filename

@app.before_request
def log_request_info():
    if request.endpoint == 'ocr':
        trace_id = request.headers.get('X-Trace-ID') or str(uuid.uuid4())
        request.environ['trace_id'] = trace_id
        
        logger.info(f"Request: {request.method} {request.url} TraceID: {trace_id}")
        if request.data:
            logger.info(f"Request Data: {request.data} TraceID: {trace_id}")
        
        # Log form data
        form_data = request.form.to_dict()
        if form_data:
            logger.info(f"Form Data: {form_data} TraceID: {trace_id}")

        # Log files
        files = request.files.to_dict()
        if files:
            file_names = [file.filename for file in files.values()]
            logger.info(f"Files Uploaded: {file_names} TraceID: {trace_id}")

@app.after_request
def log_response_info(response):
    if request.endpoint == 'ocr':
        trace_id = request.environ.get('trace_id')
        logger.info(f"Response Status: {response.status} TraceID: {trace_id}, Response Data: {response.get_data(as_text=True)}")
    return response

@app.route('/ocr', methods=['POST'])
def ocr():
    if 'image_file' not in request.files:
        return jsonify({"error": "No file part"}), 400

    file_storage = request.files['image_file']
    if file_storage.filename == '':
        return jsonify({"error": "No selected file"}), 400

    type_ = request.form.get('type')
    box = request.form.get('box', '')
    color = request.form.get('color', '')
    render = request.form.get('render', '').lower() in ['true', '1']

    image = load_image(file_storage)
    w, h = image.size

    if type_ == 'format':
        qs = 'OCR with format: '
    else:
        qs = 'OCR: '

    if box:
        bbox = eval(box)
        if len(bbox) == 2:
            bbox[0] = int(bbox[0]/w*1000)
            bbox[1] = int(bbox[1]/h*1000)
        if len(bbox) == 4:
            bbox[0] = int(bbox[0]/w*1000)
            bbox[1] = int(bbox[1]/h*1000)
            bbox[2] = int(bbox[2]/w*1000)
            bbox[3] = int(bbox[3]/h*1000)
        if type_ == 'format':
            qs = str(bbox) + ' ' + 'OCR with format: '
        else:
            qs = str(bbox) + ' ' + 'OCR: '

    if color:
        if type_ == 'format':
            qs = '[' + color + ']' + ' ' + 'OCR with format: '
        else:
            qs = '[' + color + ']' + ' ' + 'OCR: '

    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs 
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    args_conv_mode = conv_mode
    conv = conv_templates[args_conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])

    image_1 = image.copy()
    image_tensor = image_processor(image)
    image_tensor_1 = image_processor_high(image_1)

    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.float16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
        )

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    if render:
        html_path_2 = os.path.join("/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results", generate_unique_filename())

        if '**kern' in outputs:
            import verovio
            from cairosvg import svg2png
            import cv2
            import numpy as np
            tk = verovio.toolkit()
            tk.loadData(outputs)
            tk.setOptions({"pageWidth": 2100, "footer": 'none',
                           'barLineWidth': 0.5, 'beamMaxSlope': 15,
                           'staffLineWidth': 0.2, 'spacingStaff': 6})
            tk.getPageCount()
            svg = tk.renderToSVG()
            svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")

            svg_to_html(svg, html_path_2)

        if type_ == 'format' and '**kern' not in outputs:
            if '\\begin{tikzpicture}' not in outputs:
                html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "content-mmd-to-html.html"
                right_num = outputs.count('\\right')
                left_num = outputs.count('\left')

                if right_num != left_num:
                    outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')

                outputs = outputs.replace('"', '``').replace('$', '')

                outputs_list = outputs.split('\n')
                gt = ''
                for out in outputs_list:
                    gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
                gt = gt[:-2]

                with open(html_path, 'r') as web_f:
                    lines = web_f.read()
                    lines = lines.split("const text =")
                    new_web = lines[0] + 'const text =' + gt + lines[1]
            else:
                html_path = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/render_tools/" + "tikz.html"
                outputs = outputs.translate(translation_table)
                outputs_list = outputs.split('\n')
                gt = ''
                for out in outputs_list:
                    if out:
                        if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
                            # strip trailing spaces (guarded so a whitespace-only line cannot raise IndexError)
                            while out and out[-1] == ' ':
                                out = out[:-1]

                            if out:
                                if out[-1] != ';':
                                    gt += out[:-1] + ';\n'
                                else:
                                    gt += out + '\n'
                        else:
                            gt += out + '\n'

                with open(html_path, 'r') as web_f:
                    lines = web_f.read()
                    lines = lines.split("const text =")
                    new_web = lines[0] + gt + lines[1]

            with open(html_path_2, 'w') as web_f_new:
                web_f_new.write(new_web)

        base_url = "http://188.1.1.122:5000/results/"
        full_html_path = base_url + os.path.basename(html_path_2)
        return jsonify({"result": "Rendering complete", "outputs": outputs, "html_path": full_html_path})
    else:
        return jsonify({"result": "Processing complete", "outputs": outputs})

@app.route('/results/<path:filename>', methods=['GET'])
def serve_result(filename):
    logger.info(f"Request: {request.method} {request.url}")
    results_dir = "/opt/ollamaModels/GOT-OCR2.0-main/GOT-OCR-2.0-master/results"
    return send_from_directory(results_dir, filename)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)

Example API call:

curl --location 'http://1.1.1.1:5000/ocr' \
--form 'type="format"' \
--form 'image_file=@"/Users/zhiaiyahong/Downloads/wodeshefenzheng_副本.jpg"' \
--form 'render="true"'
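
Based on the Flask handler above, a successful call with render=true returns JSON of roughly this shape (the values here are illustrative; outputs carries the recognized Mathpix Markdown and html_path points at the /results/ route served by the same app):

{"result": "Rendering complete", "outputs": "...recognized text...", "html_path": "http://188.1.1.122:5000/results/1730000000_1234.html"}

Without render=true, only result and outputs are returned.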
