From d8659cd3a9a5b5b4c17173e98be21f69e1894eeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Mon, 17 Jun 2024 21:03:08 +0800 Subject: [PATCH] Add formula detection service --- src/client_demo.py | 6 +++-- src/server.py | 59 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/src/client_demo.py b/src/client_demo.py index 2b9f033..8e28ebf 100644 --- a/src/client_demo.py +++ b/src/client_demo.py @@ -1,10 +1,12 @@ import requests -url = "http://127.0.0.1:8000/predict" +rec_server_url = "http://127.0.0.1:8000/frec" +det_server_url = "http://127.0.0.1:8000/fdet" img_path = "/your/image/path/" with open(img_path, 'rb') as img: files = {'img': img} - response = requests.post(url, files=files) + response = requests.post(det_server_url, files=files) + # response = requests.post(rec_server_url, files=files) print(response.text) diff --git a/src/server.py b/src/server.py index 372f3e4..b5e83ac 100644 --- a/src/server.py +++ b/src/server.py @@ -1,14 +1,19 @@ import argparse +import tempfile import time import numpy as np import cv2 +from pathlib import Path from starlette.requests import Request from ray import serve from ray.serve.handle import DeploymentHandle +from onnxruntime import InferenceSession -from models.ocr_model.utils.inference import inference +from models.ocr_model.utils.inference import inference as rec_inference +from models.det_model.inference import predict as det_inference from models.ocr_model.model.TexTeller import TexTeller +from models.det_model.inference import PredictConfig from models.ocr_model.utils.to_katex import to_katex @@ -39,7 +44,7 @@ if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda': "num_gpus": args.ngpu_per_replica } ) -class TexTellerServer: +class TexTellerRecServer: def __init__( self, checkpoint_path: str, @@ -55,26 +60,56 @@ class TexTellerServer: self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model def predict(self, image_nparray) -> str: - return to_katex(inference( + return to_katex(rec_inference( self.model, self.tokenizer, [image_nparray], accelerator=self.inf_mode, num_beams=self.num_beams )[0]) +@serve.deployment(num_replicas=args.num_replicas) +class TexTellerDetServer: + def __init__( + self + ): + self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml") + self.latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") + + async def predict(self, image_nparray) -> str: + with tempfile.TemporaryDirectory() as temp_dir: + img_path = f"{temp_dir}/temp_image.jpg" + cv2.imwrite(img_path, image_nparray) + + latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config) + return latex_bboxes + + @serve.deployment() class Ingress: - def __init__(self, texteller_server: DeploymentHandle) -> None: - self.texteller_server = texteller_server + def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None: + self.det_server = det_server + self.texteller_server = rec_server async def __call__(self, request: Request) -> str: + request_path = request.url.path form = await request.form() img_rb = await form['img'].read() img_nparray = np.frombuffer(img_rb, np.uint8) img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR) img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB) - pred = await self.texteller_server.predict.remote(img_nparray) - return pred + + if request_path.startswith("/fdet"): + if self.det_server == None: + return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found." + pred = await self.det_server.predict.remote(img_nparray) + return pred + + elif request_path.startswith("/frec"): + pred = await self.texteller_server.predict.remote(img_nparray) + return pred + + else: + return "[ERROR] Invalid request path" if __name__ == '__main__': @@ -82,14 +117,18 @@ if __name__ == '__main__': tknz_dir = args.tokenizer_dir serve.start(http_options={"host": "0.0.0.0", "port": args.server_port}) - texteller_server = TexTellerServer.bind( + rec_server = TexTellerRecServer.bind( ckpt_dir, tknz_dir, inf_mode=args.inference_mode, num_beams=args.num_beams ) - ingress = Ingress.bind(texteller_server) + det_server = None + if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists(): + det_server = TexTellerDetServer.bind() + ingress = Ingress.bind(det_server, rec_server) - ingress_handle = serve.run(ingress, route_prefix="/predict") + # ingress_handle = serve.run(ingress, route_prefix="/predict") + ingress_handle = serve.run(ingress, route_prefix="/") while True: time.sleep(1)