Merge pull request #58 from OleehyO/pre_release

Add formula detection service
This commit is contained in:
OleehyO
2024-06-17 21:26:35 +08:00
committed by GitHub
2 changed files with 53 additions and 12 deletions

View File

@@ -1,10 +1,12 @@
 import requests
 
-url = "http://127.0.0.1:8000/predict"
+rec_server_url = "http://127.0.0.1:8000/frec"
+det_server_url = "http://127.0.0.1:8000/fdet"
 
 img_path = "/your/image/path/"
 with open(img_path, 'rb') as img:
     files = {'img': img}
-    response = requests.post(url, files=files)
+    response = requests.post(det_server_url, files=files)
+    # response = requests.post(rec_server_url, files=files)
 
 print(response.text)
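
For quick verification, a minimal client-side smoke test against both routes might look like the sketch below. The endpoint paths and the 'img' form field come from the change above; the host, port, and the query helper are illustrative assumptions, not part of the commit.

import requests

REC_URL = "http://127.0.0.1:8000/frec"  # formula recognition endpoint (from the diff)
DET_URL = "http://127.0.0.1:8000/fdet"  # formula detection endpoint (from the diff)

def query(url: str, img_path: str) -> str:
    # The server reads the uploaded image from the 'img' form field.
    with open(img_path, 'rb') as img:
        response = requests.post(url, files={'img': img})
    response.raise_for_status()
    return response.text

if __name__ == '__main__':
    img_path = "/your/image/path/"          # placeholder, as in the demo above
    print(query(DET_URL, img_path))         # detected formula bounding boxes
    print(query(REC_URL, img_path))         # recognized formula as KaTeX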

View File

@@ -1,14 +1,19 @@
 import argparse
+import tempfile
 import time
 
 import numpy as np
 import cv2
 
+from pathlib import Path
 from starlette.requests import Request
 from ray import serve
 from ray.serve.handle import DeploymentHandle
+from onnxruntime import InferenceSession
 
-from models.ocr_model.utils.inference import inference
+from models.ocr_model.utils.inference import inference as rec_inference
+from models.det_model.inference import predict as det_inference
 from models.ocr_model.model.TexTeller import TexTeller
+from models.det_model.inference import PredictConfig
 from models.ocr_model.utils.to_katex import to_katex
@@ -39,7 +44,7 @@ if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
         "num_gpus": args.ngpu_per_replica
     }
 )
-class TexTellerServer:
+class TexTellerRecServer:
     def __init__(
         self,
         checkpoint_path: str,
@@ -55,41 +60,75 @@ class TexTellerServer:
         self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
 
     def predict(self, image_nparray) -> str:
-        return to_katex(inference(
+        return to_katex(rec_inference(
             self.model, self.tokenizer, [image_nparray],
             accelerator=self.inf_mode, num_beams=self.num_beams
         )[0])
 
 
+@serve.deployment(num_replicas=args.num_replicas)
+class TexTellerDetServer:
+    def __init__(
+        self
+    ):
+        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+        self.latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+
+    async def predict(self, image_nparray) -> str:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            img_path = f"{temp_dir}/temp_image.jpg"
+            cv2.imwrite(img_path, image_nparray)
+            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
+        return latex_bboxes
+
+
 @serve.deployment()
 class Ingress:
-    def __init__(self, texteller_server: DeploymentHandle) -> None:
-        self.texteller_server = texteller_server
+    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
+        self.det_server = det_server
+        self.texteller_server = rec_server
 
     async def __call__(self, request: Request) -> str:
+        request_path = request.url.path
         form = await request.form()
         img_rb = await form['img'].read()
 
         img_nparray = np.frombuffer(img_rb, np.uint8)
         img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
         img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
 
-        pred = await self.texteller_server.predict.remote(img_nparray)
-        return pred
+        if request_path.startswith("/fdet"):
+            if self.det_server is None:
+                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
+            pred = await self.det_server.predict.remote(img_nparray)
+            return pred
+        elif request_path.startswith("/frec"):
+            pred = await self.texteller_server.predict.remote(img_nparray)
+            return pred
+        else:
+            return "[ERROR] Invalid request path"
 
 
 if __name__ == '__main__':
     ckpt_dir = args.checkpoint_dir
     tknz_dir = args.tokenizer_dir
 
     serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
-    texteller_server = TexTellerServer.bind(
+    rec_server = TexTellerRecServer.bind(
         ckpt_dir, tknz_dir,
         inf_mode=args.inference_mode,
         num_beams=args.num_beams
     )
+    det_server = None
+    if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
+        det_server = TexTellerDetServer.bind()
 
-    ingress = Ingress.bind(texteller_server)
-    ingress_handle = serve.run(ingress, route_prefix="/predict")
+    ingress = Ingress.bind(det_server, rec_server)
+    # ingress_handle = serve.run(ingress, route_prefix="/predict")
+    ingress_handle = serve.run(ingress, route_prefix="/")
 
     while True:
         time.sleep(1)
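
A note on the route_prefix change at the end of the diff: under the old prefix "/predict", requests to "/fdet" or "/frec" would never reach Ingress at all, so the path-based dispatch in __call__ only works when the app is served at "/". A standalone sketch of that dispatch rule, for illustration only (the route function is not part of the commit):

def route(path: str) -> str:
    # Mirrors the prefix checks in Ingress.__call__: the handler sees the
    # full request path only because route_prefix is now "/".
    if path.startswith("/fdet"):
        return "detection"
    elif path.startswith("/frec"):
        return "recognition"
    return "error"

assert route("/fdet") == "detection"
assert route("/frec") == "recognition"
assert route("/predict") == "error"  # the old single route no longer exists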