merge dev后调整了项目结构

This commit is contained in:
三洋三洋
2024-04-21 00:48:24 +08:00
parent e6dca76123
commit 11df230200
66 changed files with 190 additions and 124855 deletions

View File

@@ -3,24 +3,17 @@ import heapq
import cv2
import numpy as np
from onnxruntime import InferenceSession
from collections import Counter
from typing import List
from PIL import Image
from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
from surya.recognition import batch_recognition
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
from ..det_model.inference import PredictConfig
from surya.detection import batch_text_detection
from surya.input.processing import slice_polys_from_image
from surya.recognition import batch_recognition
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.model.TexTeller import TexTeller
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex
@@ -59,11 +52,10 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
bboxes = sorted(ocr_bboxes + latex_bboxes)
######## debug #########
# log results
for idx, bbox in enumerate(bboxes):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
######## debug ###########
assert len(bboxes) > 1
@@ -125,11 +117,12 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
assert False
res.append(candidate)
res.append(curr)
######## debug #########
# log results
for idx, bbox in enumerate(res):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
######## debug ###########
return res
@@ -156,14 +149,17 @@ def mix_inference(
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
latex_bboxes = sorted(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
latex_bboxes = bbox_merge(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, det_processor, rec_model, rec_processor = lang_ocr_models
images = [Image.fromarray(masked_img)]
det_prediction = batch_text_detection(images, det_model, det_processor)[0]
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
lang = [language]
@@ -222,7 +218,6 @@ def mix_inference(
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
# prev = bboxes[0]
for curr in bboxes:
if not prev.same_row(curr):
md += "\n"
@@ -235,23 +230,3 @@ def mix_inference(
md += '\n'
prev = curr
return md
if __name__ == '__main__':
img_path = "/Users/Leehy/Code/TexTeller/test3.png"
# latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx")
latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
infer_config = PredictConfig("/Users/Leehy/Code/TexTeller/src/models/det_model/model/infer_cfg.yml")
det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_model(), load_processor()
lang_ocr_models = (det_model, det_processor, rec_model, rec_processor)
texteller = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
latex_rec_models = (texteller, tokenizer)
res = mix_inference(img_path, "zh", infer_config, latex_det_model, lang_ocr_models, latex_rec_models)
print(res)
pause = 1