From 621bebc473ea3d4fe4d101d8d5042c4179ff2406 Mon Sep 17 00:00:00 2001
From: Zhang Hao
Date: Sat, 16 Aug 2025 20:31:38 +0800
Subject: [PATCH] add ocr service

---
 README.md        |   3 +-
 app.py           | 204 +++++++++++++++++++++++++++++++++++++++++++++++
 pre_processor.py |  24 ++++++
 requirements.txt |   5 ++
 4 files changed, 234 insertions(+), 2 deletions(-)
 create mode 100644 app.py
 create mode 100644 pre_processor.py
 create mode 100644 requirements.txt

diff --git a/README.md b/README.md
index 212ee4f..eaa4509 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1 @@
-# enginex-bi_series-rapidocr
-
+# enginex-bi_series-paddleocr
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..c85364f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,204 @@
+import os
+import fitz
+import cv2
+import numpy as np
+from loguru import logger
+from flask import Flask
+from flask import request
+from PIL import Image
+import time
+import subprocess
+from paddleocr import PaddleOCR, PPStructure, draw_structure_result, save_structure_res
+import logging
+from pre_processor import pre_process
+logging.basicConfig(level=logging.INFO, format="%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)-4s %(message)s")
+
+# Our category definitions; they differ from Paddle's.
+categories=[ # Categories are fixed up front; each one has a fixed id.
+{"id": 1,"name": "Title"},
+{"id": 2,"name": "Heading"},
+{"id": 3,"name": "Text"},
+{"id": 4,"name": "List"},
+{"id": 5,"name": "Table"},
+{"id": 6,"name": "Figure"},
+{"id": 7,"name": "FigureCaption"},
+{"id": 8,"name": "TableCaption"},
+{"id": 9,"name": "Header"},
+{"id": 10,"name": "Footer"},
+{"id": 11,"name": "Reference"},
+{"id": 12,"name": "Equation"},
+{"id": 13,"name": "Toc"}]
+
+DET_MODEL_DIR = os.getenv("DET_MODEL_DIR", "models/ch_PP-OCRv4_det_infer")
+REC_MODEL_DIR = os.getenv("REC_MODEL_DIR", "models/ch_PP-OCRv4_rec_infer")
+CLS_MODEL_DIR = os.getenv("CLS_MODEL_DIR", "models/ch_ppocr_mobile_v2.0_cls_infer")
+lang = os.environ.get("LANGUAGE", "zh")
+WITH_PREPROCESSING = os.getenv("WITH_PREPROCESSING", "False").lower() == "true"
+PORT = int(os.getenv("PORT", 80))
+# Uploaded images are written here before OCR runs.
+TEMP_DIR = "./temp_docs"
+# Recognized lines must score strictly above this confidence to be returned.
+MIN_SCORE = 0.8
+
+app = Flask(__name__)
+# Debug mode turns on the interactive Werkzeug debugger, and this service
+# listens on 0.0.0.0, so it is opt-in via FLASK_DEBUG=true rather than always on.
+app.config['DEBUG'] = os.getenv("FLASK_DEBUG", "False").lower() == "true"
+app.logger.setLevel(logging.INFO)
+
+def convert_lang(lang):
+    """Map the service language code to PaddleOCR's ("zh" -> "ch"); others pass through."""
+    if lang == "zh":
+        return "ch"
+    else:
+        return lang
+
+logging.info(f"DET_MODEL_DIR: {DET_MODEL_DIR}, REC_MODEL_DIR: {REC_MODEL_DIR}, CLS_MODEL_DIR: {CLS_MODEL_DIR}, lang: {lang}, WITH_PREPROCESSING: {WITH_PREPROCESSING}")
+try:
+    import torch
+    logging.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
+except ImportError:
+    # torch is used only for this startup diagnostic and is not listed in
+    # requirements.txt, so a missing install must not stop the service.
+    logging.warning("torch is not installed; skipping CUDA availability check")
+# Startup diagnostics only: fixed command strings, no request data involved.
+subprocess.run("ixsmi", shell=True, text=True)
+subprocess.run("ls -l /dev", shell=True, text=True)
+
+ocr_engine = PaddleOCR(show_log=True, mix=False, lang=convert_lang(lang),
+                    det_model_dir=DET_MODEL_DIR,
+                    rec_model_dir=REC_MODEL_DIR,
+                    cls_model_dir=CLS_MODEL_DIR)
+
+# Model invocation.
+def ppocr_infer(img):
+    """Run the shared PaddleOCR engine on ``img`` and return its raw result."""
+    logging.info(lang)
+    result = ocr_engine.ocr(img)
+    return result
+
+def scale_bounding_box(points, scaling_size):
+    """Convert a detection box into ``[x_min, y_min, x_max, y_max]``.
+
+    ``points[0]`` supplies the minimum corner and ``points[2]`` the maximum
+    corner. ``scaling_size`` is kept for interface compatibility but is
+    ignored: scaling is currently disabled (see the FIXME below).
+    """
+    x_min = points[0][0]
+    y_min = points[0][1]
+    x_max = points[2][0]
+    y_max = points[2][1]
+
+    # FIXME(zhanghao): no scale
+    return [x_min, y_min, x_max, y_max]
+
+def do_predict_img(img_path):
+    """Run OCR on the image stored at ``img_path``.
+
+    Returns ``(ans, boxs)``. ``ans`` is a list of dicts with the keys ``bbox``,
+    ``type``, ``content``, ``page`` and ``score``, one per recognized line
+    scoring above ``MIN_SCORE``; ``boxs`` lists the same bboxes. If no line
+    qualifies, one empty "Text" entry covering the whole image is returned.
+
+    Raises:
+        ValueError: if ``img_path`` cannot be decoded as an image.
+    """
+    start_time = time.time()
+    image = cv2.imread(img_path)
+    if image is None:
+        # cv2.imread reports an unreadable file by returning None, not by raising.
+        raise ValueError(f"Cannot read image: {img_path}")
+
+    if WITH_PREPROCESSING:
+        processed_img = pre_process(image)
+        logging.info(f"Preprocessing takes {time.time() - start_time} s")
+    else:
+        processed_img = image
+        logging.info("Skip Preprocessing")
+
+    start_time = time.time()
+    pages = ppocr_infer(processed_img)
+    # Only one image is passed in, so only the first entry is used; an empty
+    # or None result is treated as "nothing recognized" instead of crashing.
+    result = pages[0] if pages else None
+    logging.info(f"ppocr_infer takes {time.time() - start_time} s")
+    ans = []
+    boxs = []
+    if result:
+        for points, recognition in result:
+            text = recognition[0]
+            score = recognition[1]
+
+            if score > MIN_SCORE:
+                bbox = scale_bounding_box(points, 0.80)
+                ans.append({
+                    "bbox": bbox,
+                    "type": "Text",
+                    "content": text,
+                    "page": 1,
+                    "score": score
+                })
+                boxs.append(bbox)
+
+    logging.info(str(len(ans)))
+
+    if not ans:
+        # The context manager closes the file handle once the size is read.
+        with Image.open(img_path) as img:
+            width, height = img.width, img.height
+        ans.append({
+            "bbox": [0, 0, width, height],
+            "type": 'Text',
+            "content": "",
+            "page": 1,
+            "score": 1.0
+        })
+        boxs.append([0, 0, width, height])
+    return ans, boxs
+
+def draw_bboxes(img_path, bboxes):
+    """Return the image at ``img_path`` with each ``[x_min, y_min, x_max, y_max]`` box drawn on it."""
+    img = cv2.imread(img_path)
+    for bbox in bboxes:
+        x_min, y_min, x_max, y_max = bbox
+        cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2)  # green rectangle, line width 2
+    return img
+
+@app.route("/predict", methods=["POST"])
+def predict():
+    """Handle ``POST /predict``: OCR the uploaded ``image`` file.
+
+    Requests carrying a ``pdf`` file get an empty result. The optional
+    ``image_name`` form field names the stored upload.
+    """
+    logging.info(str(request.files.keys()))
+    if "pdf" in request.files:
+        return {"success": True, "result": []}
+
+    file = request.files['image']
+    # image_name is client-controlled: keep only its last path component so the
+    # upload cannot be written outside TEMP_DIR. Fall back to the upload's own
+    # filename, then to a generated name, when the field is missing.
+    raw_name = request.form.get("image_name") or file.filename or f"upload_{time.time_ns()}"
+    name = os.path.basename(raw_name.replace("\\", "/"))
+    if name in ("", ".", ".."):
+        return {"success": False, "result": [], "error": "invalid image_name"}, 400
+    os.makedirs(TEMP_DIR, exist_ok=True)
+    path = os.path.join(TEMP_DIR, name)  # storage path
+    start_time = time.time()
+    file.save(path)
+    logging.info(f"Saving file takes {time.time() - start_time} s")
+    # Run OCR on the saved image and return the prediction.
+    start_time = time.time()
+    try:
+        res, boxss = do_predict_img(path)
+    except ValueError as err:
+        logging.warning(str(err))
+        return {"success": False, "result": [], "error": "unreadable image"}, 400
+    logging.info(f"Predict takes {time.time() - start_time} s")
+
+    return {"success": True, "result": res if res is not None else []}
+
+
+if __name__ == '__main__':
+    app.run("0.0.0.0", PORT, use_reloader=False)
diff --git a/pre_processor.py b/pre_processor.py
new file mode 100644
index 0000000..54e8007
--- /dev/null
+++ b/pre_processor.py
@@ -0,0 +1,24 @@
+import cv2
+import numpy as np
+import os
+
+NOISE_THRESHOLD = int(os.getenv("NOISE_THRESHOLD", "10"))  # example threshold, in units of the noise standard deviation
+
+def pre_process(image):
+    """Denoise ``image`` when its estimated noise level exceeds ``NOISE_THRESHOLD``; otherwise return it unchanged."""
+    if len(image.shape) == 3:  # color image (has three channels)
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    else:
+        gray = image
+
+    # Noise-level estimate (simply uses the standard deviation as a proxy).
+    noise_level = np.std(gray)  # evaluated on the grayscale image
+    if noise_level > NOISE_THRESHOLD:
+        if len(image.shape) == 3:  # color image
+            processed_img = cv2.fastNlMeansDenoisingColored(image)
+        else:  # grayscale image
+            processed_img = cv2.fastNlMeansDenoising(image)
+    else:
+        processed_img = image
+
+    return processed_img
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..b0ac283
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+paddleocr==2.6.1
+pillow==10.2.0
+opencv-python==4.6.0.66
+Flask==3.0.2
+loguru==0.7.2
\ No newline at end of file