185 lines
5.9 KiB
Python
185 lines
5.9 KiB
Python
import os
|
||
import fitz
|
||
import cv2
|
||
import numpy as np
|
||
from loguru import logger
|
||
from flask import Flask
|
||
from flask import request
|
||
from PIL import Image
|
||
import time
|
||
import subprocess
|
||
from paddleocr import PaddleOCR, PPStructure, draw_structure_result, save_structure_res
|
||
import logging
|
||
from pre_processor import pre_process
|
||
# Root logger: INFO and above, with a millisecond timestamp and the
# emitting file:line in every record.
logging.basicConfig(
    format="%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)-4s %(message)s",
    level=logging.INFO,
)
|
||
|
||
# Our own category definitions; they intentionally differ from PaddleOCR's.
# Every category has a fixed id, assigned in list order starting from 1.
# NOTE(review): presumably the vocabulary for the "type" field of returned
# regions (do_predict_img emits "Text") -- confirm against consumers.
categories = [
    {"id": category_id, "name": category_name}
    for category_id, category_name in enumerate(
        (
            "Title",
            "Heading",
            "Text",
            "List",
            "Table",
            "Figure",
            "FigureCaption",
            "TableCaption",
            "Header",
            "Footer",
            "Reference",
            "Equation",
            "Toc",
        ),
        start=1,
    )
]
|
||
|
||
# Runtime configuration, each value overridable through an environment
# variable. The three *_MODEL_DIR paths are passed to PaddleOCR as
# det_model_dir / rec_model_dir / cls_model_dir.
DET_MODEL_DIR = os.environ.get("DET_MODEL_DIR", "models/ch_PP-OCRv4_det_infer")
REC_MODEL_DIR = os.environ.get("REC_MODEL_DIR", "models/ch_PP-OCRv4_rec_infer")
CLS_MODEL_DIR = os.environ.get("CLS_MODEL_DIR", "models/ch_ppocr_mobile_v2.0_cls_infer")

# Service language code (mapped for PaddleOCR by convert_lang).
# NOTE(review): LANGUAGE is also the GNU gettext locale-priority variable and
# may already be set in a base image (e.g. "en_US:en"); confirm it is not
# inherited unintentionally.
lang = os.getenv("LANGUAGE", "zh")

# Only the literal string "true" (case-insensitive) turns preprocessing on.
WITH_PREPROCESSING = os.environ.get("WITH_PREPROCESSING", "False").lower() == "true"

# HTTP port the development server listens on.
PORT = int(os.environ.get("PORT", "80"))
|
||
|
||
app = Flask(__name__)

# Debug mode turns on Werkzeug's interactive in-browser debugger, which can
# execute arbitrary Python. The server binds to 0.0.0.0, so debug mode must
# be opt-in rather than hard-coded on. Enable it locally with FLASK_DEBUG=1
# (or "true"/"yes"); empty, "0", "false" and "no" leave it off, matching
# Flask's own interpretation of FLASK_DEBUG.
app.config['DEBUG'] = os.getenv("FLASK_DEBUG", "").lower() not in ("", "0", "false", "no")

app.logger.setLevel(logging.INFO)
|
||
|
||
def convert_lang(lang):
    """Translate the service language code into the one passed to PaddleOCR.

    Only ``"zh"`` is rewritten (to ``"ch"``); every other value is returned
    unchanged.
    """
    return "ch" if lang == "zh" else lang
|
||
|
||
# Record the effective configuration once, at import time.
logging.info(f"DET_MODEL_DIR: {DET_MODEL_DIR}, REC_MODEL_DIR: {REC_MODEL_DIR}, CLS_MODEL_DIR: {CLS_MODEL_DIR}, lang: {lang}, WITH_PREPROCESSING: {WITH_PREPROCESSING}")

# torch is imported here rather than with the other imports. In this module
# it is only used to report whether CUDA is visible to the process.
import torch

logging.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")

# Device diagnostics are written straight to the process's stdout (output is
# not captured or logged). They run through the shell without check=True, so
# a failing or missing command only prints a shell error.
# NOTE(review): `ixsmi` is presumably the accelerator vendor's device-status
# tool -- confirm it exists in the deployment image.
subprocess.run("ixsmi", shell=True, text=True)
subprocess.run("ls -l /dev", shell=True, text=True)

# Module-level OCR engine; ppocr_infer() runs every OCR call through it.
ocr_engine = PaddleOCR(
    lang=convert_lang(lang),
    det_model_dir=DET_MODEL_DIR,
    rec_model_dir=REC_MODEL_DIR,
    cls_model_dir=CLS_MODEL_DIR,
    show_log=True,
    mix=False,
)
|
||
|
||
def ppocr_infer(img):
    """Call the OCR model on ``img`` and return its raw output.

    The configured language is logged on every call. The result of
    ``ocr_engine.ocr`` is returned exactly as produced (callers index it
    with ``[0]``).
    """
    logging.info(lang)
    return ocr_engine.ocr(img)
|
||
|
||
def scale_bounding_box(points, scaling_size, *, apply_scaling=False):
    """Convert a detected text polygon into an axis-aligned bounding box.

    Args:
        points: Sequence of ``(x, y)`` vertices of a detected region
            (presumably the four corners from PaddleOCR's detector).
        scaling_size: Factor by which to shrink/grow the box around its
            centre. It is ignored unless ``apply_scaling`` is true.
        apply_scaling: Opt-in switch for the scaling step. It defaults to
            False, which preserves the historical behaviour of returning the
            box unscaled.

    Returns:
        ``[x_min, y_min, x_max, y_max]``.
    """
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    # Take min/max over every vertex. Assuming points[0] is the top-left
    # corner and points[2] the bottom-right only holds for axis-aligned
    # boxes; for a rotated quadrilateral that shortcut clips the region.
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)

    # FIXME(zhanghao): no scale -- scaling stays disabled by default.
    if not apply_scaling:
        return [x_min, y_min, x_max, y_max]

    # Scale width and height by `scaling_size`, keeping the centre fixed.
    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    half_width = (x_max - x_min) * scaling_size / 2
    half_height = (y_max - y_min) * scaling_size / 2
    return [
        center_x - half_width,
        center_y - half_height,
        center_x + half_width,
        center_y + half_height,
    ]
|
||
|
||
def do_predict_img(img_path):
    """Run OCR on one image file and return its confident text regions.

    The image is loaded with OpenCV. It is passed through ``pre_process``
    only when ``WITH_PREPROCESSING`` is true, then handed to ``ppocr_infer``.
    Recognised lines whose score is strictly greater than 0.8 are kept.

    Args:
        img_path: Path of the image file on disk.

    Returns:
        A tuple ``(ans, boxs)``:

        * ``ans`` -- list of region dicts with keys ``"bbox"``
          (``[x_min, y_min, x_max, y_max]``), ``"type"`` (always ``"Text"``),
          ``"content"`` (recognised text), ``"page"`` (always 1) and
          ``"score"``.
        * ``boxs`` -- the bounding boxes of those regions, in the same order.

        If no line passes the threshold, each list instead holds a single
        entry covering the whole image (size as reported by PIL). That entry
        has empty content and score 1.0, so callers always get at least one
        region.

    Raises:
        ValueError: If OpenCV cannot read or decode ``img_path``.
    """
    start_time = time.time()
    image = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising. Fail
    # here with a clear message rather than deep inside preprocessing/OCR.
    if image is None:
        raise ValueError(f"cv2.imread could not read image: {img_path}")

    if WITH_PREPROCESSING:
        processed_img = pre_process(image)
        # This duration also includes the cv2.imread call above.
        logging.info(f"Preprocessing takes {time.time() - start_time} s")
    else:
        processed_img = image
        logging.info("Skip Preprocessing")

    start_time = time.time()
    # Only the first element of the OCR output is used (one input image).
    result = ppocr_infer(processed_img)[0]
    logging.info(f"ppocr_infer takes {time.time() - start_time} s")

    ans = []
    boxs = []
    # `result` may be empty or None when nothing was detected.
    # Each line is [points, (text, score)].
    for line in result or []:
        points, recognition = line[0], line[1]
        text, score = recognition[0], recognition[1]
        if score > 0.8:
            # Only compute the box for lines that are actually kept.
            bbox = scale_bounding_box(points, 0.80)
            ans.append({
                "bbox": bbox,
                "type": "Text",
                "content": text,
                "page": 1,
                "score": score,
            })
            boxs.append(bbox)

    logging.info(str(len(ans)))

    if not ans:
        # PIL opens files lazily and keeps the handle open until closed. The
        # context manager releases it as soon as the size has been read.
        with Image.open(img_path) as img:
            width, height = img.width, img.height
        ans.append({
            "bbox": [0, 0, width, height],
            "type": 'Text',
            "content": "",
            "page": 1,
            "score": 1.0,
        })
        boxs.append([0, 0, width, height])

    return ans, boxs
|
||
|
||
def draw_bboxes(img_path, bboxes):
    """Load ``img_path`` and outline each box on it; return the drawn image.

    Args:
        img_path: Image file to read with ``cv2.imread``.
        bboxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes. The
            coordinates are truncated to ``int`` for drawing.

    Returns:
        The image array with a green rectangle (line width 2) per box.
    """
    canvas = cv2.imread(img_path)
    green = (0, 255, 0)  # colour tuple passed straight to cv2.rectangle
    thickness = 2
    for left, top, right, bottom in bboxes:
        top_left = (int(left), int(top))
        bottom_right = (int(right), int(bottom))
        cv2.rectangle(canvas, top_left, bottom_right, green, thickness)
    return canvas
|
||
|
||
@app.route("/predict", methods=["POST"])
def predict():
    """Handle ``POST /predict``: OCR an uploaded image.

    Multipart fields read by this handler:

    * ``pdf`` (file) -- if present, the request is acknowledged with an empty
      result; PDF input is not processed here.
    * ``image`` (file) -- the image to OCR. It is required when ``pdf`` is
      absent; a missing field makes Flask answer with HTTP 400.
    * ``image_name`` (form value) -- client-supplied file name including its
      extension. Only its final path component is used.

    The upload is saved to a uniquely named file under ``./temp_docs`` for
    the duration of the request and deleted afterwards.

    Returns:
        ``{"success": True, "result": [...]}``, where ``result`` is the region
        list produced by ``do_predict_img``.
    """
    import uuid  # only needed here, to build unique per-request file names

    logging.info(str(request.files.keys()))

    if "pdf" in request.files:
        return {"success": True, "result": []}

    file = request.files['image']
    # image_name is untrusted client input. Keep only its last path component
    # so names like "../../app.py" cannot write outside ./temp_docs, and fall
    # back to a fixed name when it is missing or degenerate.
    name = os.path.basename(request.form.get("image_name") or "")
    if name in ("", ".", ".."):
        name = "image"
    # A random prefix stops concurrent requests that send the same name from
    # overwriting each other's upload before it is read.
    os.makedirs("./temp_docs", exist_ok=True)
    path = os.path.join("./temp_docs", f"{uuid.uuid4().hex}_{name}")

    start_time = time.time()
    file.save(path)
    logging.info(f"Saving file takes {time.time() - start_time} s")

    try:
        # Run OCR on the saved image to get the predicted regions.
        start_time = time.time()
        # The second return value (plain boxes) can be passed to
        # draw_bboxes(path, ...) here for visual debugging.
        res, _boxes = do_predict_img(path)
        logging.info(f"Predict takes {time.time() - start_time} s")
    finally:
        # The upload is only needed for this request. Delete it so
        # ./temp_docs does not grow without bound; a failed delete is logged,
        # not raised.
        try:
            os.remove(path)
        except OSError:
            logging.warning(f"Could not remove temporary file {path}")

    return {"success": True, "result": res if res is not None else []}
|
||
|
||
|
||
if __name__ == '__main__':
    # Listen on all interfaces at PORT. The reloader is disabled, presumably
    # so the module (and the OCR engine it builds at import time) is not
    # imported a second time in a reloader child process.
    app.run(host="0.0.0.0", port=PORT, use_reloader=False)
|