merge dev后调整了项目结构
This commit is contained in:
@@ -3,24 +3,17 @@ import heapq
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
from collections import Counter
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
from surya.ocr import run_ocr
|
||||
from surya.detection import batch_text_detection
|
||||
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
|
||||
from surya.recognition import batch_recognition
|
||||
from surya.model.detection import segformer
|
||||
from surya.model.recognition.model import load_model
|
||||
from surya.model.recognition.processor import load_processor
|
||||
|
||||
from ..det_model.inference import PredictConfig
|
||||
from surya.detection import batch_text_detection
|
||||
from surya.input.processing import slice_polys_from_image
|
||||
from surya.recognition import batch_recognition
|
||||
|
||||
from ..det_model.inference import predict as latex_det_predict
|
||||
from ..det_model.Bbox import Bbox, draw_bboxes
|
||||
|
||||
from ..ocr_model.model.TexTeller import TexTeller
|
||||
from ..ocr_model.utils.inference import inference as latex_rec_predict
|
||||
from ..ocr_model.utils.to_katex import to_katex
|
||||
|
||||
@@ -59,11 +52,10 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
|
||||
|
||||
bboxes = sorted(ocr_bboxes + latex_bboxes)
|
||||
|
||||
######## debug #########
|
||||
# log results
|
||||
for idx, bbox in enumerate(bboxes):
|
||||
bbox.content = str(idx)
|
||||
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
|
||||
######## debug ###########
|
||||
|
||||
assert len(bboxes) > 1
|
||||
|
||||
@@ -125,11 +117,12 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
|
||||
assert False
|
||||
res.append(candidate)
|
||||
res.append(curr)
|
||||
######## debug #########
|
||||
|
||||
# log results
|
||||
for idx, bbox in enumerate(res):
|
||||
bbox.content = str(idx)
|
||||
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
|
||||
######## debug ###########
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -156,14 +149,17 @@ def mix_inference(
|
||||
|
||||
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
|
||||
latex_bboxes = sorted(latex_bboxes)
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
|
||||
latex_bboxes = bbox_merge(latex_bboxes)
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
|
||||
masked_img = mask_img(img, latex_bboxes, bg_color)
|
||||
|
||||
det_model, det_processor, rec_model, rec_processor = lang_ocr_models
|
||||
images = [Image.fromarray(masked_img)]
|
||||
det_prediction = batch_text_detection(images, det_model, det_processor)[0]
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
|
||||
|
||||
lang = [language]
|
||||
@@ -222,7 +218,6 @@ def mix_inference(
|
||||
|
||||
md = ""
|
||||
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
|
||||
# prev = bboxes[0]
|
||||
for curr in bboxes:
|
||||
if not prev.same_row(curr):
|
||||
md += "\n"
|
||||
@@ -235,23 +230,3 @@ def mix_inference(
|
||||
md += '\n'
|
||||
prev = curr
|
||||
return md
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img_path = "/Users/Leehy/Code/TexTeller/test3.png"
|
||||
|
||||
# latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx")
|
||||
latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
||||
infer_config = PredictConfig("/Users/Leehy/Code/TexTeller/src/models/det_model/model/infer_cfg.yml")
|
||||
|
||||
det_processor, det_model = segformer.load_processor(), segformer.load_model()
|
||||
rec_model, rec_processor = load_model(), load_processor()
|
||||
lang_ocr_models = (det_model, det_processor, rec_model, rec_processor)
|
||||
|
||||
texteller = TexTeller.from_pretrained()
|
||||
tokenizer = TexTeller.get_tokenizer()
|
||||
latex_rec_models = (texteller, tokenizer)
|
||||
|
||||
res = mix_inference(img_path, "zh", infer_config, latex_det_model, lang_ocr_models, latex_rec_models)
|
||||
print(res)
|
||||
pause = 1
|
||||
|
||||
Reference in New Issue
Block a user