Add formula detection service

This commit is contained in:
三洋三洋
2024-06-17 21:03:08 +08:00
parent c849728ee7
commit d8659cd3a9
2 changed files with 53 additions and 12 deletions

View File

@@ -1,10 +1,12 @@
# Minimal client for the TexTeller Ray Serve endpoints.
# POST an image file to /fdet for formula *detection* or /frec for
# formula *recognition*; the server answers with plain text
# (KaTeX for /frec, bounding boxes for /fdet).
import requests

rec_server_url = "http://127.0.0.1:8000/frec"
det_server_url = "http://127.0.0.1:8000/fdet"

# Replace with the path of the image you want to send.
img_path = "/your/image/path/"
with open(img_path, 'rb') as img:
    files = {'img': img}
    # Detection by default; switch to rec_server_url for recognition.
    response = requests.post(det_server_url, files=files)
    # response = requests.post(rec_server_url, files=files)
    print(response.text)

View File

@@ -1,14 +1,19 @@
import argparse
import tempfile
import time
import numpy as np
import cv2
from pathlib import Path
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from onnxruntime import InferenceSession
from models.ocr_model.utils.inference import inference
from models.ocr_model.utils.inference import inference as rec_inference
from models.det_model.inference import predict as det_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
from models.ocr_model.utils.to_katex import to_katex
@@ -39,7 +44,7 @@ if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
"num_gpus": args.ngpu_per_replica
}
)
class TexTellerServer:
class TexTellerRecServer:
def __init__(
self,
checkpoint_path: str,
@@ -55,26 +60,56 @@ class TexTellerServer:
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
def predict(self, image_nparray) -> str:
return to_katex(inference(
return to_katex(rec_inference(
self.model, self.tokenizer, [image_nparray],
accelerator=self.inf_mode, num_beams=self.num_beams
)[0])
@serve.deployment(num_replicas=args.num_replicas)
class TexTellerDetServer:
    """Ray Serve deployment wrapping the RT-DETR formula-detection model.

    Loads the ONNX detection model and its inference config once per
    replica; `predict` runs detection on a single decoded image array.
    """

    def __init__(
        self
    ):
        # Detection config and ONNX session are loaded from fixed paths
        # relative to the server's working directory.
        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        self.latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

    async def predict(self, image_nparray) -> str:
        # det_inference takes a file path, so the in-memory image is
        # round-tripped through a temporary JPEG on disk.
        # NOTE(review): the call itself is synchronous and will block the
        # event loop while it runs — confirm whether that is acceptable.
        with tempfile.TemporaryDirectory() as temp_dir:
            img_path = f"{temp_dir}/temp_image.jpg"
            cv2.imwrite(img_path, image_nparray)
            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
            # NOTE(review): annotated -> str but returns det_inference's
            # bounding boxes, which look like a collection — verify the
            # annotation against det_inference's actual return type.
            return latex_bboxes
@serve.deployment()
class Ingress:
    """HTTP entry point that routes requests to the model deployments.

    `/fdet` goes to the formula-detection server (if it was deployed),
    `/frec` goes to the formula-recognition server; anything else gets
    an error string.
    """

    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
        # det_server is None when the detection ONNX model file was not
        # found at startup; /fdet then answers with an error message.
        self.det_server = det_server
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> str:
        request_path = request.url.path

        # Decode the uploaded image bytes into an RGB ndarray.
        form = await request.form()
        img_rb = await form['img'].read()
        img_nparray = np.frombuffer(img_rb, np.uint8)
        img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)

        if request_path.startswith("/fdet"):
            if self.det_server is None:  # fixed: identity check, not ==
                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
            pred = await self.det_server.predict.remote(img_nparray)
            return pred
        elif request_path.startswith("/frec"):
            pred = await self.texteller_server.predict.remote(img_nparray)
            return pred
        else:
            return "[ERROR] Invalid request path"
if __name__ == '__main__':
@@ -82,14 +117,18 @@ if __name__ == '__main__':
tknz_dir = args.tokenizer_dir
serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
texteller_server = TexTellerServer.bind(
rec_server = TexTellerRecServer.bind(
ckpt_dir, tknz_dir,
inf_mode=args.inference_mode,
num_beams=args.num_beams
)
ingress = Ingress.bind(texteller_server)
det_server = None
if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
det_server = TexTellerDetServer.bind()
ingress = Ingress.bind(det_server, rec_server)
ingress_handle = serve.run(ingress, route_prefix="/predict")
# ingress_handle = serve.run(ingress, route_prefix="/predict")
ingress_handle = serve.run(ingress, route_prefix="/")
while True:
time.sleep(1)