add ocr service
This commit is contained in:
184
app.py
Normal file
184
app.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import os
|
||||
import fitz
|
||||
import cv2
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from flask import Flask
|
||||
from flask import request
|
||||
from PIL import Image
|
||||
import time
|
||||
import subprocess
|
||||
from paddleocr import PaddleOCR, PPStructure, draw_structure_result, save_structure_res
|
||||
import logging
|
||||
from pre_processor import pre_process
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)-4s %(message)s")
|
||||
|
||||
# 我们对类别定义如下,和paddle的不同
|
||||
categories=[ # 类别提前定好,每一类有一个固定的id
|
||||
{"id": 1,"name": "Title"},
|
||||
{"id": 2,"name": "Heading"},
|
||||
{"id": 3,"name": "Text"},
|
||||
{"id": 4,"name": "List"},
|
||||
{"id": 5,"name": "Table"},
|
||||
{"id": 6,"name": "Figure"},
|
||||
{"id": 7,"name": "FigureCaption"},
|
||||
{"id": 8,"name": "TableCaption"},
|
||||
{"id": 9,"name": "Header"},
|
||||
{"id": 10,"name": "Footer"},
|
||||
{"id": 11,"name": "Reference"},
|
||||
{"id": 12,"name": "Equation"},
|
||||
{"id": 13,"name": "Toc"}]
|
||||
|
||||
DET_MODEL_DIR = os.getenv("DET_MODEL_DIR", "models/ch_PP-OCRv4_det_infer")
|
||||
REC_MODEL_DIR = os.getenv("REC_MODEL_DIR", "models/ch_PP-OCRv4_rec_infer")
|
||||
CLS_MODEL_DIR = os.getenv("CLS_MODEL_DIR", "models/ch_ppocr_mobile_v2.0_cls_infer")
|
||||
lang = os.environ.get("LANGUAGE", "zh")
|
||||
WITH_PREPROCESSING = os.getenv("WITH_PREPROCESSING", "False").lower() == "true"
|
||||
PORT = int(os.getenv("PORT", 80))
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['DEBUG'] = True
|
||||
app.logger.setLevel(logging.INFO)
|
||||
|
||||
def convert_lang(lang):
|
||||
if lang == "zh":
|
||||
return "ch"
|
||||
else:
|
||||
return lang
|
||||
|
||||
logging.info(f"DET_MODEL_DIR: {DET_MODEL_DIR}, REC_MODEL_DIR: {REC_MODEL_DIR}, CLS_MODEL_DIR: {CLS_MODEL_DIR}, lang: {lang}, WITH_PREPROCESSING: {WITH_PREPROCESSING}")
|
||||
import torch
|
||||
logging.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
|
||||
subprocess.run("ixsmi", shell=True, text=True)
|
||||
subprocess.run("ls -l /dev", shell=True, text=True)
|
||||
|
||||
ocr_engine = PaddleOCR(show_log=True, mix=False, lang=convert_lang(lang),
|
||||
det_model_dir=DET_MODEL_DIR,
|
||||
rec_model_dir=REC_MODEL_DIR,
|
||||
cls_model_dir=CLS_MODEL_DIR)
|
||||
|
||||
# 调用模型的函数
|
||||
def ppocr_infer(img):
|
||||
logging.info(lang)
|
||||
result = ocr_engine.ocr(img)
|
||||
return result
|
||||
|
||||
def scale_bounding_box(points, scaling_size):
|
||||
# logging.warning(f"初始检测框:{points}")
|
||||
# 计算原始检测框的宽度和高度
|
||||
x_min = points[0][0]
|
||||
y_min = points[0][1]
|
||||
x_max = points[2][0]
|
||||
y_max = points[2][1]
|
||||
|
||||
# FIXME(zhanghao): no scale
|
||||
return [x_min, y_min, x_max, y_max]
|
||||
|
||||
original_width = x_max - x_min
|
||||
original_height = y_max - y_min
|
||||
|
||||
# 计算中心点
|
||||
center_x = (x_min + x_max) / 2
|
||||
center_y = (y_min + y_max) / 2
|
||||
|
||||
# 应用缩放因子
|
||||
new_width = original_width * scaling_size
|
||||
new_height = original_height * scaling_size
|
||||
|
||||
# 计算新的边界坐标
|
||||
new_x_min = center_x - new_width / 2
|
||||
new_y_min = center_y - new_height / 2
|
||||
new_x_max = center_x + new_width / 2
|
||||
new_y_max = center_y + new_height / 2
|
||||
|
||||
bbox = [new_x_min, new_y_min, new_x_max, new_y_max]
|
||||
# logging.warning(f"缩放检测框:{bbox}")
|
||||
# 返回新的检测框坐标
|
||||
return bbox
|
||||
|
||||
def do_predict_img(img_path):
|
||||
start_time = time.time()
|
||||
image = cv2.imread(img_path)
|
||||
|
||||
if WITH_PREPROCESSING:
|
||||
processed_img = pre_process(image)
|
||||
logging.info(f"Preprocessing takes {time.time() - start_time} s")
|
||||
else:
|
||||
processed_img = image
|
||||
logging.info("Skip Preprocessing")
|
||||
|
||||
start_time = time.time()
|
||||
result = ppocr_infer(processed_img)[0]
|
||||
logging.info(f"ppocr_infer takes {time.time() - start_time} s")
|
||||
# logging.info(f"result: {result}")
|
||||
ans = []
|
||||
boxs = []
|
||||
if result:
|
||||
box_num = len(result)
|
||||
for i in range(box_num):
|
||||
text = result[i][1][0]
|
||||
score = result[i][1][1]
|
||||
|
||||
bbox = scale_bounding_box(result[i][0], 0.80)
|
||||
if score > 0.8:
|
||||
data = {
|
||||
"bbox": bbox,
|
||||
"type": "Text",
|
||||
"content": text,
|
||||
"page": 1,
|
||||
"score": score
|
||||
}
|
||||
|
||||
# if i == 0:
|
||||
# logging.info(f"data sample is {data}.")
|
||||
|
||||
ans.append(data)
|
||||
boxs.append(bbox)
|
||||
|
||||
# logging.info("ans:", ans)
|
||||
logging.info(str(len(ans)))
|
||||
|
||||
if len(ans) == 0:
|
||||
img = Image.open(img_path)
|
||||
data = {
|
||||
"bbox": [0, 0, img.width, img.height],
|
||||
"type": 'Text',
|
||||
"content": "",
|
||||
"page": 1,
|
||||
"score": 1.0
|
||||
}
|
||||
ans.append(data)
|
||||
boxs.append([0, 0, img.width, img.height])
|
||||
return ans, boxs
|
||||
|
||||
def draw_bboxes(img_path, bboxes):
|
||||
img = cv2.imread(img_path)
|
||||
for bbox in bboxes:
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2) # 绘制绿色矩形,线宽为 2
|
||||
return img
|
||||
|
||||
@app.route("/predict", methods=["POST"])
|
||||
def predict():
|
||||
logging.info(str(request.files.keys()))
|
||||
if "pdf" in request.files:
|
||||
return {"success": True, "result": []}
|
||||
else:
|
||||
file = request.files['image']
|
||||
name = request.form.get("image_name") # 带后缀的
|
||||
path = f'./temp_docs/{name}' # 存储路径
|
||||
start_time = time.time()
|
||||
file.save(path) # 保存
|
||||
logging.info(f"Saving file takes {time.time() - start_time} s")
|
||||
# 处理收到的pdf,获取预测结果并返回
|
||||
start_time = time.time()
|
||||
res, boxss = do_predict_img(path)
|
||||
logging.info(f"Predict takes {time.time() - start_time} s")
|
||||
# img_with_bboxes = draw_bboxes(path, boxss)
|
||||
# cv2.imwrite('./outputest_rapid.jpg', img_with_bboxes) # 保存绘制好 bbox 的图像
|
||||
|
||||
return {"success": True, "result": res if res is not None else []}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run("0.0.0.0", PORT, use_reloader=False)
|
||||
23
pre_processor.py
Normal file
23
pre_processor.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
NOISE_THRESHOLD = int(os.getenv("NOISE_THRESHOLD", "10")) # 示例阈值,单位为噪声的标准差
|
||||
|
||||
def pre_process(image):
|
||||
if len(image.shape) == 3: # 如果是彩色图像 (具有三个通道)
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = image
|
||||
|
||||
# 噪声水平估计(简单地使用标准差作为代理)
|
||||
noise_level = np.std(gray) # 在灰度图像上评估噪声水平
|
||||
if noise_level > NOISE_THRESHOLD:
|
||||
if len(image.shape) == 3: # 如果是彩色图像
|
||||
processed_img = cv2.fastNlMeansDenoisingColored(image)
|
||||
else: # 如果是灰度图像
|
||||
processed_img = cv2.fastNlMeansDenoising(image)
|
||||
else:
|
||||
processed_img = image
|
||||
|
||||
return processed_img
|
||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
paddleocr==2.6.1
|
||||
pillow==10.2.0
|
||||
opencv-python==4.6.0.66
|
||||
Flask==3.0.2
|
||||
loguru==0.7.2
|
||||
Reference in New Issue
Block a user