From 83da4262fdc56d68c3960c5b5f625892e016b46f Mon Sep 17 00:00:00 2001 From: TonyLee1256 <163754792+TonyLee1256@users.noreply.github.com> Date: Thu, 9 May 2024 00:23:02 +0800 Subject: [PATCH] Update mix_inference.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 替换文本OCR模型为paddleocr --- src/models/utils/mix_inference.py | 53 +++++++++++++++---------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/src/models/utils/mix_inference.py b/src/models/utils/mix_inference.py index 8a96d0c..398bf6c 100644 --- a/src/models/utils/mix_inference.py +++ b/src/models/utils/mix_inference.py @@ -7,9 +7,7 @@ from collections import Counter from typing import List from PIL import Image -from surya.detection import batch_text_detection -from surya.input.processing import slice_polys_from_image -from surya.recognition import batch_recognition +from paddleocr.ppocr.utils.utility import alpha_to_color from ..det_model.inference import predict as latex_det_predict from ..det_model.Bbox import Bbox, draw_bboxes @@ -126,6 +124,21 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo return res +def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]: + sliced_imgs = [] + for bbox in ocr_bboxes: + x, y = int(bbox.p.x), int(bbox.p.y) + w, h = int(bbox.w), int(bbox.h) + sliced_img = img[y:y+h, x:x+w] + sliced_imgs.append(sliced_img) + return sliced_imgs + + +def preprocess_image(_image): + _image = alpha_to_color(_image, (255, 255, 255)) + return _image + + def mix_inference( img_path: str, language: str, @@ -143,6 +156,7 @@ def mix_inference( ''' global img img = cv2.imread(img_path) + img = alpha_to_color(img, (255, 255, 255)) corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])] bg_color = np.array(Counter(corners).most_common(1)[0][0]) @@ -156,50 +170,33 @@ def mix_inference( draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png") masked_img = mask_img(img, latex_bboxes, bg_color) - det_model, det_processor, rec_model, rec_processor = lang_ocr_models - images = [Image.fromarray(masked_img)] - det_prediction = batch_text_detection(images, det_model, det_processor)[0] + det_model, rec_model = lang_ocr_models + det_prediction, _ = det_model(masked_img) # log results draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png") - lang = [language] - slice_map = [] - all_slices = [] - all_langs = [] ocr_bboxes = [ Bbox( - p.bbox[0], p.bbox[1], p.bbox[3] - p.bbox[1], p.bbox[2] - p.bbox[0], + p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0], label="text", - confidence=p.confidence, + confidence=None, content=None ) - for p in det_prediction.bboxes + for p in det_prediction ] ocr_bboxes = sorted(ocr_bboxes) ocr_bboxes = bbox_merge(ocr_bboxes) draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png") ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes) ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) - polygons = [ - [ - [bbox.ul_point.x, bbox.ul_point.y], - [bbox.ur_point.x, bbox.ur_point.y], - [bbox.lr_point.x, bbox.lr_point.y], - [bbox.ll_point.x, bbox.ll_point.y] - ] - for bbox in ocr_bboxes - ] - slices = slice_polys_from_image(images[0], polygons) - slice_map.append(len(slices)) - all_slices.extend(slices) - all_langs.extend([lang] * len(slices)) + sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes) + rec_predictions, _ = rec_model(sliced_imgs) - rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor) assert len(rec_predictions) == len(ocr_bboxes) for content, bbox in zip(rec_predictions, ocr_bboxes): - bbox.content = content + bbox.content = content[0] latex_imgs =[] for bbox in latex_bboxes: