Update mix_inference.py

替换文本OCR模型为paddleocr
This commit is contained in:
TonyLee1256
2024-05-09 00:23:02 +08:00
committed by GitHub
parent bd2aaa3e00
commit 83da4262fd

View File

@@ -7,9 +7,7 @@ from collections import Counter
from typing import List
from PIL import Image
from surya.detection import batch_text_detection
from surya.input.processing import slice_polys_from_image
from surya.recognition import batch_recognition
from paddleocr.ppocr.utils.utility import alpha_to_color
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
@@ -126,6 +124,21 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
return res
def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
sliced_imgs = []
for bbox in ocr_bboxes:
x, y = int(bbox.p.x), int(bbox.p.y)
w, h = int(bbox.w), int(bbox.h)
sliced_img = img[y:y+h, x:x+w]
sliced_imgs.append(sliced_img)
return sliced_imgs
def preprocess_image(_image):
_image = alpha_to_color(_image, (255, 255, 255))
return _image
def mix_inference(
img_path: str,
language: str,
@@ -143,6 +156,7 @@ def mix_inference(
'''
global img
img = cv2.imread(img_path)
img = alpha_to_color(img, (255, 255, 255))
corners = [tuple(img[0, 0]), tuple(img[0, -1]),
tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
@@ -156,50 +170,33 @@ def mix_inference(
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, det_processor, rec_model, rec_processor = lang_ocr_models
images = [Image.fromarray(masked_img)]
det_prediction = batch_text_detection(images, det_model, det_processor)[0]
det_model, rec_model = lang_ocr_models
det_prediction, _ = det_model(masked_img)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
lang = [language]
slice_map = []
all_slices = []
all_langs = []
ocr_bboxes = [
Bbox(
p.bbox[0], p.bbox[1], p.bbox[3] - p.bbox[1], p.bbox[2] - p.bbox[0],
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
label="text",
confidence=p.confidence,
confidence=None,
content=None
)
for p in det_prediction.bboxes
for p in det_prediction
]
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
polygons = [
[
[bbox.ul_point.x, bbox.ul_point.y],
[bbox.ur_point.x, bbox.ur_point.y],
[bbox.lr_point.x, bbox.lr_point.y],
[bbox.ll_point.x, bbox.ll_point.y]
]
for bbox in ocr_bboxes
]
slices = slice_polys_from_image(images[0], polygons)
slice_map.append(len(slices))
all_slices.extend(slices)
all_langs.extend([lang] * len(slices))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
rec_predictions, _ = rec_model(sliced_imgs)
rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor)
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content
bbox.content = content[0]
latex_imgs =[]
for bbox in latex_bboxes: