"""Flask service exposing PaddleOCR text recognition on uploaded images.

POST /predict accepts a multipart upload with an ``image`` file (plus an
optional ``image_name`` form field) and returns the recognized text lines
with their bounding boxes as JSON.
"""
import os
import fitz
import cv2
import numpy as np
from loguru import logger
from flask import Flask
from flask import request
from PIL import Image
import time
import subprocess
from paddleocr import PaddleOCR, PPStructure, draw_structure_result, save_structure_res
import logging
from pre_processor import pre_process

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)-4s %(message)s",
)

# Our category definitions are listed below; they differ from Paddle's.
# Categories are fixed in advance and each one has a stable id.
categories = [
    {"id": 1, "name": "Title"},
    {"id": 2, "name": "Heading"},
    {"id": 3, "name": "Text"},
    {"id": 4, "name": "List"},
    {"id": 5, "name": "Table"},
    {"id": 6, "name": "Figure"},
    {"id": 7, "name": "FigureCaption"},
    {"id": 8, "name": "TableCaption"},
    {"id": 9, "name": "Header"},
    {"id": 10, "name": "Footer"},
    {"id": 11, "name": "Reference"},
    {"id": 12, "name": "Equation"},
    {"id": 13, "name": "Toc"},
]

DET_MODEL_DIR = os.getenv("DET_MODEL_DIR", "models/ch_PP-OCRv4_det_infer")
REC_MODEL_DIR = os.getenv("REC_MODEL_DIR", "models/ch_PP-OCRv4_rec_infer")
CLS_MODEL_DIR = os.getenv("CLS_MODEL_DIR", "models/ch_ppocr_mobile_v2.0_cls_infer")
lang = os.environ.get("LANGUAGE", "zh")
WITH_PREPROCESSING = os.getenv("WITH_PREPROCESSING", "False").lower() == "true"
PORT = int(os.getenv("PORT", 80))

# Directory where uploaded images are written before being read back for OCR.
TEMP_DIR = "./temp_docs"
# Recognized lines must score strictly above this value to be returned.
SCORE_THRESHOLD = 0.8
# Scale factor handed to scale_bounding_box (currently ignored there).
BBOX_SCALE = 0.80

app = Flask(__name__)
# NOTE(review): DEBUG=True combined with app.run("0.0.0.0", ...) below appears
# to expose the Werkzeug debugger on every interface -- confirm this is
# intended before deploying outside a trusted network.
app.config['DEBUG'] = True
app.logger.setLevel(logging.INFO)


def convert_lang(lang):
    """Map the service language code to the one passed to PaddleOCR.

    ``"zh"`` becomes ``"ch"``; every other value is returned unchanged.
    """
    if lang == "zh":
        return "ch"
    return lang


logging.info(f"DET_MODEL_DIR: {DET_MODEL_DIR}, REC_MODEL_DIR: {REC_MODEL_DIR}, CLS_MODEL_DIR: {CLS_MODEL_DIR}, lang: {lang}, WITH_PREPROCESSING: {WITH_PREPROCESSING}")

# Kept at its original position (after the paddleocr import) so the import
# order of the native libraries is unchanged.
import torch

logging.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")

# Startup diagnostics; output goes straight to the process stdout/stderr.
# shell=True is kept so that a missing `ixsmi` binary only prints a shell
# error instead of raising FileNotFoundError and aborting startup.
subprocess.run("ixsmi", shell=True, text=True)
subprocess.run("ls -l /dev", shell=True, text=True)

ocr_engine = PaddleOCR(
    show_log=True,
    mix=False,
    lang=convert_lang(lang),
    det_model_dir=DET_MODEL_DIR,
    rec_model_dir=REC_MODEL_DIR,
    cls_model_dir=CLS_MODEL_DIR,
)


def ppocr_infer(img):
    """Run the shared PaddleOCR engine on ``img`` and return its raw result."""
    logging.info(lang)
    return ocr_engine.ocr(img)


def scale_bounding_box(points, scaling_size):
    """Convert a detection polygon to ``[x_min, y_min, x_max, y_max]``.

    Args:
        points: Sequence of ``(x, y)`` corners. Only ``points[0]`` and
            ``points[2]`` are read, so this presumably expects the corners
            ordered with index 0 and index 2 diagonally opposite -- verify
            against the OCR engine's output.
        scaling_size: Accepted for interface compatibility; currently unused.

    Returns:
        The box as a four-element list, unscaled.
    """
    x_min, y_min = points[0][0], points[0][1]
    x_max, y_max = points[2][0], points[2][1]
    # FIXME(zhanghao): no scale. Scaling about the box center is disabled;
    # the unreachable implementation that followed this return was removed.
    return [x_min, y_min, x_max, y_max]


def do_predict_img(img_path):
    """Run OCR on the image at ``img_path``.

    Args:
        img_path: Path of an image file readable by OpenCV.

    Returns:
        A ``(ans, boxs)`` tuple. ``ans`` is a list of dicts with the keys
        ``bbox``, ``type``, ``content``, ``page`` and ``score``; ``boxs`` is
        the list of the same bounding boxes. When no line scores above
        ``SCORE_THRESHOLD``, a single empty-text entry covering the whole
        image is returned instead.

    Raises:
        ValueError: If OpenCV cannot read ``img_path`` as an image.
    """
    start_time = time.time()
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of inside the OCR engine.
        raise ValueError(f"Cannot read image: {img_path}")

    if WITH_PREPROCESSING:
        processed_img = pre_process(image)
        logging.info(f"Preprocessing takes {time.time() - start_time} s")
    else:
        processed_img = image
        logging.info("Skip Preprocessing")

    start_time = time.time()
    # The engine returns one entry per input image; only one image is passed.
    result = ppocr_infer(processed_img)[0]
    logging.info(f"ppocr_infer takes {time.time() - start_time} s")

    ans = []
    boxs = []
    # `result` may be None/empty when nothing was detected.
    for entry in result or []:
        polygon, recognition = entry[0], entry[1]
        text, score = recognition[0], recognition[1]
        bbox = scale_bounding_box(polygon, BBOX_SCALE)
        if score > SCORE_THRESHOLD:
            ans.append({
                "bbox": bbox,
                "type": "Text",
                "content": text,
                "page": 1,
                "score": score,
            })
            boxs.append(bbox)

    logging.info(str(len(ans)))

    if not ans:
        # Fall back to one empty text block spanning the full image. The
        # context manager closes the file handle PIL opens lazily.
        with Image.open(img_path) as img:
            width, height = img.width, img.height
        ans.append({
            "bbox": [0, 0, width, height],
            "type": 'Text',
            "content": "",
            "page": 1,
            "score": 1.0,
        })
        boxs.append([0, 0, width, height])

    return ans, boxs


def draw_bboxes(img_path, bboxes):
    """Return the image at ``img_path`` with each bbox drawn on it.

    Each bbox is ``[x_min, y_min, x_max, y_max]``; rectangles are drawn in
    green with a line width of 2. Debugging helper, not used by the route.
    """
    img = cv2.imread(img_path)
    for x_min, y_min, x_max, y_max in bboxes:
        cv2.rectangle(
            img,
            (int(x_min), int(y_min)),
            (int(x_max), int(y_max)),
            (0, 255, 0),
            2,
        )
    return img


def _safe_upload_name(name, fallback=None):
    """Reduce a client-supplied file name to a single safe path component.

    Directory parts (with either slash style) and NUL bytes are stripped so
    the result cannot escape the upload directory. ``fallback`` is tried when
    ``name`` is missing or reduces to nothing usable; ``"upload"`` is the
    last resort.
    """
    for candidate in (name, fallback):
        if not candidate:
            continue
        normalized = str(candidate).replace("\\", "/").replace("\x00", "")
        base = os.path.basename(normalized)
        if base not in ("", ".", ".."):
            return base
    return "upload"


@app.route("/predict", methods=["POST"])
def predict():
    """Handle an OCR request.

    Requests carrying a ``pdf`` file part get an empty result. Otherwise the
    ``image`` file part is saved under ``TEMP_DIR`` (named after the
    ``image_name`` form field, which should include the extension) and OCR
    results are returned as ``{"success": True, "result": [...]}``.
    """
    logging.info(str(request.files.keys()))
    if "pdf" in request.files:
        return {"success": True, "result": []}

    file = request.files['image']
    # `image_name` is untrusted input: keep only its final path component so
    # the upload cannot be written outside TEMP_DIR.
    name = _safe_upload_name(request.form.get("image_name"), file.filename)
    os.makedirs(TEMP_DIR, exist_ok=True)
    path = os.path.join(TEMP_DIR, name)

    start_time = time.time()
    file.save(path)
    logging.info(f"Saving file takes {time.time() - start_time} s")

    # Run OCR on the saved image and return the predictions.
    start_time = time.time()
    res, boxss = do_predict_img(path)
    logging.info(f"Predict takes {time.time() - start_time} s")

    # For visual debugging, draw_bboxes(path, boxss) returns the image with
    # the detected boxes drawn on it.
    return {"success": True, "result": res if res is not None else []}


if __name__ == '__main__':
    app.run("0.0.0.0", PORT, use_reloader=False)