Add formula detection service

三洋三洋
2024-06-17 21:03:08 +08:00
parent c849728ee7
commit d8659cd3a9
2 changed files with 53 additions and 12 deletions


@@ -1,10 +1,12 @@
 import requests
 
-url = "http://127.0.0.1:8000/predict"
+rec_server_url = "http://127.0.0.1:8000/frec"
+det_server_url = "http://127.0.0.1:8000/fdet"
 img_path = "/your/image/path/"
 
 with open(img_path, 'rb') as img:
     files = {'img': img}
-    response = requests.post(url, files=files)
+    response = requests.post(det_server_url, files=files)
+    # response = requests.post(rec_server_url, files=files)
 
 print(response.text)
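A possible client-side follow-up, sketched here under assumptions the commit does not spell out: that the /fdet response can be parsed into [x1, y1, x2, y2] pixel boxes (the serialization of the detection result is not defined in this diff) and that each cropped region is then sent to /frec for recognition.

import requests
import cv2

det_server_url = "http://127.0.0.1:8000/fdet"
rec_server_url = "http://127.0.0.1:8000/frec"
img_path = "/your/image/path/"

with open(img_path, 'rb') as img:
    det_response = requests.post(det_server_url, files={'img': img})

# Hypothetical parsing step: assumes the detection service returns JSON boxes
# as [x1, y1, x2, y2]; this commit does not guarantee that format.
boxes = det_response.json()

image = cv2.imread(img_path)
for x1, y1, x2, y2 in boxes:
    crop = image[int(y1):int(y2), int(x1):int(x2)]
    ok, buf = cv2.imencode('.jpg', crop)            # re-encode the crop for upload
    rec_response = requests.post(rec_server_url, files={'img': buf.tobytes()})
    print(rec_response.text)                        # KaTeX string for this formula region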


@@ -1,14 +1,19 @@
 import argparse
+import tempfile
 import time
 
 import numpy as np
 import cv2
+from pathlib import Path
 
 from starlette.requests import Request
 from ray import serve
 from ray.serve.handle import DeploymentHandle
+from onnxruntime import InferenceSession
 
-from models.ocr_model.utils.inference import inference
+from models.ocr_model.utils.inference import inference as rec_inference
+from models.det_model.inference import predict as det_inference
 from models.ocr_model.model.TexTeller import TexTeller
+from models.det_model.inference import PredictConfig
 from models.ocr_model.utils.to_katex import to_katex
 
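The new imports exist to load and run the RT-DETR formula detector. Below is a minimal standalone sketch of what they are used for, mirroring the calls made in TexTellerDetServer further down; the exact structure of the value returned by det_inference is not shown in this commit, so it is treated as an opaque result here.

from onnxruntime import InferenceSession
from models.det_model.inference import PredictConfig, predict as det_inference

# Preprocessing/label configuration and the ONNX detector, loaded from the
# same paths the deployment uses.
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

# Run detection on an image file; the return value (formula bounding boxes)
# is printed without assuming a particular format.
latex_bboxes = det_inference("/your/image/path/", latex_det_model, infer_config)
print(latex_bboxes)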
@@ -39,7 +44,7 @@ if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
         "num_gpus": args.ngpu_per_replica
     }
 )
-class TexTellerServer:
+class TexTellerRecServer:
     def __init__(
         self,
         checkpoint_path: str,
@@ -55,41 +60,75 @@ class TexTellerServer:
         self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
 
     def predict(self, image_nparray) -> str:
-        return to_katex(inference(
+        return to_katex(rec_inference(
             self.model, self.tokenizer, [image_nparray],
             accelerator=self.inf_mode, num_beams=self.num_beams
         )[0])
 
 
+@serve.deployment(num_replicas=args.num_replicas)
+class TexTellerDetServer:
+    def __init__(self):
+        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+        self.latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+
+    async def predict(self, image_nparray) -> str:
+        # The detection helper expects an image path, so the array is written
+        # to a temporary file before being passed to the ONNX detector.
+        with tempfile.TemporaryDirectory() as temp_dir:
+            img_path = f"{temp_dir}/temp_image.jpg"
+            cv2.imwrite(img_path, image_nparray)
+            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
+            return latex_bboxes
+
+
 @serve.deployment()
 class Ingress:
-    def __init__(self, texteller_server: DeploymentHandle) -> None:
-        self.texteller_server = texteller_server
+    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
+        self.det_server = det_server
+        self.rec_server = rec_server
 
     async def __call__(self, request: Request) -> str:
+        request_path = request.url.path
         form = await request.form()
         img_rb = await form['img'].read()
         img_nparray = np.frombuffer(img_rb, np.uint8)
         img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
         img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
 
-        pred = await self.texteller_server.predict.remote(img_nparray)
-        return pred
+        if request_path.startswith("/fdet"):
+            if self.det_server is None:
+                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
+            pred = await self.det_server.predict.remote(img_nparray)
+            return pred
+        elif request_path.startswith("/frec"):
+            pred = await self.rec_server.predict.remote(img_nparray)
+            return pred
+        else:
+            return "[ERROR] Invalid request path"
 
 
 if __name__ == '__main__':
     ckpt_dir = args.checkpoint_dir
     tknz_dir = args.tokenizer_dir
 
     serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
-    texteller_server = TexTellerServer.bind(
+    rec_server = TexTellerRecServer.bind(
         ckpt_dir, tknz_dir,
         inf_mode=args.inference_mode,
         num_beams=args.num_beams
     )
-    ingress = Ingress.bind(texteller_server)
+    det_server = None
+    if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
+        det_server = TexTellerDetServer.bind()
+
+    ingress = Ingress.bind(det_server, rec_server)
 
-    ingress_handle = serve.run(ingress, route_prefix="/predict")
+    # ingress_handle = serve.run(ingress, route_prefix="/predict")
+    ingress_handle = serve.run(ingress, route_prefix="/")
 
     while True:
         time.sleep(1)
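With route_prefix changed from "/predict" to "/", every path now reaches Ingress and the prefix check selects the backend: /fdet goes to detection (or returns the missing-model error when the ONNX file is absent), /frec goes to recognition, and anything else, including the old /predict route, gets the invalid-path message. A quick sanity check of the routing, assuming the service is running locally on port 8000 as in the example client:

import requests

with open("/your/image/path/", 'rb') as img:
    files = {'img': img}
    print(requests.post("http://127.0.0.1:8000/frec", files=files).text)     # recognition result
    img.seek(0)                                                               # rewind before reusing the file handle
    print(requests.post("http://127.0.0.1:8000/fdet", files=files).text)     # detection result
    img.seek(0)
    print(requests.post("http://127.0.0.1:8000/predict", files=files).text)  # "[ERROR] Invalid request path"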