Files
2025-08-16 20:31:38 +08:00

185 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import fitz
import cv2
import numpy as np
from loguru import logger
from flask import Flask
from flask import request
from PIL import Image
import time
import subprocess
from paddleocr import PaddleOCR, PPStructure, draw_structure_result, save_structure_res
import logging
from pre_processor import pre_process
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)-4s %(message)s",
)

# Layout categories used by this service. They differ from PaddleOCR's own
# category definitions: every category has a fixed id assigned up front,
# numbered from 1 in the order listed below.
categories = [
    {"id": category_id, "name": category_name}
    for category_id, category_name in enumerate(
        (
            "Title",
            "Heading",
            "Text",
            "List",
            "Table",
            "Figure",
            "FigureCaption",
            "TableCaption",
            "Header",
            "Footer",
            "Reference",
            "Equation",
            "Toc",
        ),
        start=1,
    )
]
# Model directories and runtime settings; each can be overridden through an
# environment variable (DET_MODEL_DIR, REC_MODEL_DIR, CLS_MODEL_DIR, LANGUAGE,
# WITH_PREPROCESSING, PORT).
DET_MODEL_DIR = os.getenv("DET_MODEL_DIR", "models/ch_PP-OCRv4_det_infer")
REC_MODEL_DIR = os.getenv("REC_MODEL_DIR", "models/ch_PP-OCRv4_rec_infer")
CLS_MODEL_DIR = os.getenv("CLS_MODEL_DIR", "models/ch_ppocr_mobile_v2.0_cls_infer")
lang = os.environ.get("LANGUAGE", "zh")  # service language code, mapped by convert_lang
WITH_PREPROCESSING = os.getenv("WITH_PREPROCESSING", "False").lower() == "true"
PORT = int(os.getenv("PORT", 80))
# Flask debug mode makes app.run() start Werkzeug's interactive debugger, which
# serves a (PIN-protected) Python console from error pages. The server binds to
# 0.0.0.0, so debug is opt-in via FLASK_DEBUG=true instead of always on.
FLASK_DEBUG = os.getenv("FLASK_DEBUG", "False").lower() == "true"
app = Flask(__name__)
app.config['DEBUG'] = FLASK_DEBUG
app.logger.setLevel(logging.INFO)
def convert_lang(lang):
    """Translate the service's language code into the code passed to PaddleOCR.

    Only ``"zh"`` is rewritten (to ``"ch"``, PaddleOCR's code for Chinese);
    any other value is returned unchanged.
    """
    return "ch" if lang == "zh" else lang
logging.info(f"DET_MODEL_DIR: {DET_MODEL_DIR}, REC_MODEL_DIR: {REC_MODEL_DIR}, CLS_MODEL_DIR: {CLS_MODEL_DIR}, lang: {lang}, WITH_PREPROCESSING: {WITH_PREPROCESSING}")
# torch is used here to report at startup whether a CUDA device is visible.
import torch
logging.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")


def _run_diagnostic(argv):
    """Run a startup diagnostic command on a best-effort basis.

    The command is executed directly (no shell) and its output goes to the
    inherited stdout/stderr. A non-zero exit status is ignored, and a binary
    that is missing or not executable is logged as a warning instead of
    aborting startup.
    """
    try:
        subprocess.run(argv, check=False)
    except OSError as err:
        logging.warning(f"Diagnostic command {argv} could not be run: {err}")


# NOTE(review): ixsmi is presumably the accelerator's status tool (an
# nvidia-smi analogue) — confirm. Both commands only print device info to logs.
_run_diagnostic(["ixsmi"])
_run_diagnostic(["ls", "-l", "/dev"])

# Single shared OCR engine, built once at import time and used by ppocr_infer.
ocr_engine = PaddleOCR(show_log=True, mix=False, lang=convert_lang(lang),
                       det_model_dir=DET_MODEL_DIR,
                       rec_model_dir=REC_MODEL_DIR,
                       cls_model_dir=CLS_MODEL_DIR)
def ppocr_infer(img):
    """Invoke the shared PaddleOCR engine on one image.

    The configured language is logged on every call. The raw output of
    ``ocr_engine.ocr`` is returned unchanged; callers index ``[0]`` to get
    the result for the single image passed in.
    """
    logging.info(lang)
    return ocr_engine.ocr(img)
def scale_bounding_box(points, scaling_size, *, apply_scaling=False):
    """Turn a detected text polygon into an axis-aligned box.

    Args:
        points: Sequence of ``(x, y)`` vertices of the detected region.
            PaddleOCR supplies four corners, presumably clockwise from
            top-left — confirm; the result does not depend on the order.
        scaling_size: Factor by which the box is scaled about its centre.
            Ignored unless ``apply_scaling`` is true. Scaling was disabled by
            the original author's FIXME, but callers still pass this value.
        apply_scaling: Opt-in switch for the scaling step. The default
            (False) keeps the current unscaled behaviour.

    Returns:
        ``[x_min, y_min, x_max, y_max]``.
    """
    # Use the extremes over all vertices: for a rotated quadrilateral the
    # first and third corners alone do not enclose the whole region.
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)

    # FIXME(zhanghao): no scale
    if not apply_scaling:
        return [x_min, y_min, x_max, y_max]

    # Scale width and height by scaling_size while keeping the centre fixed.
    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    half_width = (x_max - x_min) * scaling_size / 2
    half_height = (y_max - y_min) * scaling_size / 2
    return [
        center_x - half_width,
        center_y - half_height,
        center_x + half_width,
        center_y + half_height,
    ]
def do_predict_img(img_path, score_threshold=0.8):
    """Run OCR on an image file and return its confidently recognised lines.

    Args:
        img_path: Path of the image. It is read with OpenCV for OCR and, for
            the empty-result fallback, with PIL to get the image size.
        score_threshold: Recognition score a line must strictly exceed to be
            kept. Defaults to 0.8, the previously hard-coded value.

    Returns:
        ``(ans, boxs)``. ``ans`` is a list of dicts with keys ``bbox``,
        ``type``, ``content``, ``page`` and ``score``. ``boxs`` is the list of
        the same bboxes. If no line passes the threshold, a single empty
        "Text" item covering the whole image is returned, so both lists always
        contain at least one element.

    Raises:
        ValueError: If OpenCV cannot read ``img_path``.
    """
    start_time = time.time()
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports unreadable/unsupported files by returning None
        # instead of raising; fail clearly rather than feeding None to OCR.
        raise ValueError(f"Cannot read image: {img_path}")
    if WITH_PREPROCESSING:
        processed_img = pre_process(image)
        logging.info(f"Preprocessing takes {time.time() - start_time} s")
    else:
        processed_img = image
        logging.info("Skip Preprocessing")

    start_time = time.time()
    # [0] selects the result for the single image passed in.
    result = ppocr_infer(processed_img)[0]
    logging.info(f"ppocr_infer takes {time.time() - start_time} s")

    ans = []
    boxs = []
    # result may be falsy (e.g. None when nothing is detected), hence `or []`.
    # Each line is indexed as (points, (text, score)).
    for line in result or []:
        points, recognition = line[0], line[1]
        text, score = recognition[0], recognition[1]
        if score > score_threshold:
            bbox = scale_bounding_box(points, 0.80)
            ans.append({
                "bbox": bbox,
                "type": "Text",
                "content": text,
                "page": 1,
                "score": score,
            })
            boxs.append(bbox)
    logging.info(str(len(ans)))

    if not ans:
        # Close the PIL file handle instead of leaking it; only the size is needed.
        with Image.open(img_path) as img:
            width, height = img.width, img.height
        ans.append({
            "bbox": [0, 0, width, height],
            "type": 'Text',
            "content": "",
            "page": 1,
            "score": 1.0,
        })
        boxs.append([0, 0, width, height])
    return ans, boxs
def draw_bboxes(img_path, bboxes):
    """Load the image at ``img_path`` and draw each box on it.

    ``bboxes`` is an iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
    Coordinates are truncated to int for OpenCV, and each box is drawn as a
    green rectangle with a 2-pixel line. Returns the annotated image.
    NOTE(review): cv2.imread returns None for unreadable files, which
    cv2.rectangle will then reject.
    """
    canvas = cv2.imread(img_path)
    green = (0, 255, 0)
    for left, top, right, bottom in bboxes:
        top_left = (int(left), int(top))
        bottom_right = (int(right), int(bottom))
        cv2.rectangle(canvas, top_left, bottom_right, green, 2)
    return canvas
def _safe_upload_name(name):
    """Reduce a client-supplied file name to a bare, usable file name.

    Only the final path component is kept, so a value such as
    ``"../../x.png"`` cannot escape the upload directory. Returns None for
    empty names and for components that are not real file names (``"."``,
    ``".."``).
    """
    if not name:
        return None
    base = os.path.basename(name)
    if base in ("", ".", ".."):
        return None
    return base


@app.route("/predict", methods=["POST"])
def predict():
    """Handle an OCR request.

    Expects a multipart form with an ``image`` file and an ``image_name``
    field (the file name including its extension). If ``image_name`` is
    missing or unusable, the upload's own filename is used instead. Requests
    that carry a ``pdf`` file field are answered with an empty result.

    Returns:
        ``{"success": True, "result": [...]}`` with the items from
        ``do_predict_img``, or ``{"success": False, "error": ...}`` with HTTP
        400 when no usable file name can be determined.
    """
    logging.info(str(request.files.keys()))
    if "pdf" in request.files:
        # PDF uploads are not processed here; respond with an empty result.
        return {"success": True, "result": []}

    file = request.files['image']
    # image_name is client-controlled; sanitize it before using it in a path.
    name = _safe_upload_name(request.form.get("image_name")) or _safe_upload_name(file.filename)
    if name is None:
        return {"success": False, "error": "missing or invalid image_name"}, 400

    upload_dir = './temp_docs'
    # Create the storage directory on first use instead of failing in file.save().
    os.makedirs(upload_dir, exist_ok=True)
    path = f'{upload_dir}/{name}'  # storage path
    # NOTE(review): saved files are never deleted, and concurrent requests
    # using the same name overwrite each other's file — consider unique temp
    # names plus cleanup.
    start_time = time.time()
    file.save(path)
    logging.info(f"Saving file takes {time.time() - start_time} s")

    # Run OCR on the saved image and return the prediction.
    start_time = time.time()
    res, _ = do_predict_img(path)
    logging.info(f"Predict takes {time.time() - start_time} s")
    return {"success": True, "result": res if res is not None else []}
if __name__ == '__main__':
    # Listen on all interfaces. The reloader is off, presumably so the module
    # (and the OCR model it builds at import time) is not loaded twice.
    app.run(host="0.0.0.0", port=PORT, use_reloader=False)