Add formula detection service
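Split the single /predict route into two services behind one Ray Serve ingress: /fdet runs formula detection with an RT-DETR ONNX model (rtdetr_r50vd_6x_coco.onnx) through the new TexTellerDetServer deployment, while /frec serves the existing recognition model (renamed TexTellerRecServer) and returns KaTeX. The detection service is only bound when the ONNX weights exist on disk; otherwise /fdet answers with a plain-text error. The demo client is updated to target the new URLs.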
@@ -1,10 +1,12 @@
 import requests
 
-url = "http://127.0.0.1:8000/predict"
+rec_server_url = "http://127.0.0.1:8000/frec"
+det_server_url = "http://127.0.0.1:8000/fdet"
 
 img_path = "/your/image/path/"
 with open(img_path, 'rb') as img:
     files = {'img': img}
-    response = requests.post(url, files=files)
+    response = requests.post(det_server_url, files=files)
+    # response = requests.post(rec_server_url, files=files)
 
 print(response.text)
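For reference, a minimal client exercising both new routes could look like the sketch below. It assumes the server from this commit is running locally on port 8000; the commit does not pin down the wire format of the /fdet response (the handler returns the detector's bounding boxes directly), so the sketch only prints the raw response text.

    import requests

    # Endpoints added by this commit; Ingress dispatches on the URL path.
    det_server_url = "http://127.0.0.1:8000/fdet"  # formula detection (RT-DETR ONNX)
    rec_server_url = "http://127.0.0.1:8000/frec"  # formula recognition (KaTeX output)

    img_path = "/your/image/path/"  # placeholder path, as in the demo client

    # Read the image once so it can be re-sent to both endpoints as the
    # multipart field 'img' that Ingress.__call__ reads from the form.
    with open(img_path, 'rb') as img:
        img_bytes = img.read()

    det_response = requests.post(det_server_url, files={'img': img_bytes})
    print("detection:", det_response.text)

    rec_response = requests.post(rec_server_url, files={'img': img_bytes})
    print("recognition:", rec_response.text)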
@@ -1,14 +1,19 @@
 import argparse
+import tempfile
 import time
 import numpy as np
 import cv2
 
+from pathlib import Path
 from starlette.requests import Request
 from ray import serve
 from ray.serve.handle import DeploymentHandle
+from onnxruntime import InferenceSession
 
-from models.ocr_model.utils.inference import inference
+from models.ocr_model.utils.inference import inference as rec_inference
+from models.det_model.inference import predict as det_inference
 from models.ocr_model.model.TexTeller import TexTeller
+from models.det_model.inference import PredictConfig
 from models.ocr_model.utils.to_katex import to_katex
 
 
@@ -39,7 +44,7 @@ if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
         "num_gpus": args.ngpu_per_replica
     }
 )
-class TexTellerServer:
+class TexTellerRecServer:
     def __init__(
         self,
         checkpoint_path: str,
@@ -55,26 +60,56 @@ class TexTellerServer:
         self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
 
     def predict(self, image_nparray) -> str:
-        return to_katex(inference(
+        return to_katex(rec_inference(
             self.model, self.tokenizer, [image_nparray],
             accelerator=self.inf_mode, num_beams=self.num_beams
         )[0])
 
 
+@serve.deployment(num_replicas=args.num_replicas)
+class TexTellerDetServer:
+    def __init__(
+        self
+    ):
+        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+        self.latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+
+    async def predict(self, image_nparray) -> str:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            img_path = f"{temp_dir}/temp_image.jpg"
+            cv2.imwrite(img_path, image_nparray)
+
+            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
+            return latex_bboxes
+
+
 @serve.deployment()
 class Ingress:
-    def __init__(self, texteller_server: DeploymentHandle) -> None:
-        self.texteller_server = texteller_server
+    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
+        self.det_server = det_server
+        self.texteller_server = rec_server
 
     async def __call__(self, request: Request) -> str:
+        request_path = request.url.path
         form = await request.form()
         img_rb = await form['img'].read()
 
         img_nparray = np.frombuffer(img_rb, np.uint8)
         img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
         img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
-        pred = await self.texteller_server.predict.remote(img_nparray)
-        return pred
+
+        if request_path.startswith("/fdet"):
+            if self.det_server == None:
+                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
+            pred = await self.det_server.predict.remote(img_nparray)
+            return pred
+
+        elif request_path.startswith("/frec"):
+            pred = await self.texteller_server.predict.remote(img_nparray)
+            return pred
+
+        else:
+            return "[ERROR] Invalid request path"
 
 
 if __name__ == '__main__':
@@ -82,14 +117,18 @@ if __name__ == '__main__':
     tknz_dir = args.tokenizer_dir
 
     serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
-    texteller_server = TexTellerServer.bind(
+    rec_server = TexTellerRecServer.bind(
         ckpt_dir, tknz_dir,
         inf_mode=args.inference_mode,
         num_beams=args.num_beams
     )
-    ingress = Ingress.bind(texteller_server)
+    det_server = None
+    if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
+        det_server = TexTellerDetServer.bind()
+    ingress = Ingress.bind(det_server, rec_server)
 
-    ingress_handle = serve.run(ingress, route_prefix="/predict")
+    # ingress_handle = serve.run(ingress, route_prefix="/predict")
+    ingress_handle = serve.run(ingress, route_prefix="/")
 
     while True:
         time.sleep(1)
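Note that both error paths in Ingress return a plain-text "[ERROR] ..." body with a normal HTTP status, so clients must inspect the response text rather than the status code. A minimal helper sketch, assuming that contract holds:

    import requests

    def post_formula_image(url: str, img_path: str) -> str:
        # Post an image to a TexTeller endpoint and raise if the server
        # signals failure through its plain-text '[ERROR]' convention.
        with open(img_path, 'rb') as img:
            response = requests.post(url, files={'img': img})
        response.raise_for_status()  # transport-level errors only
        if response.text.startswith("[ERROR]"):
            # e.g. rtdetr_r50vd_6x_coco.onnx missing on the server, or a
            # request path other than /fdet and /frec.
            raise RuntimeError(response.text)
        return response.text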