import sys
import argparse
import tempfile
import time

import numpy as np
import cv2

from pathlib import Path
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from onnxruntime import InferenceSession

from models.ocr_model.utils.inference import inference as rec_inference
from models.det_model.inference import predict as det_inference, PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.to_katex import to_katex


PYTHON_VERSION = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages'
# cuDNN shared libraries shipped with the pip-installed nvidia-cudnn package;
# exported via LD_LIBRARY_PATH for the detection deployment below.
CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib'
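
# As a worked example (an illustrative virtualenv at /opt/venv on Python 3.10),
# CUDNNPATH resolves to:
#   /opt/venv/lib/python3.10/site-packages/nvidia/cudnn/lib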

parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str)
parser.add_argument('-tknz', '--tokenizer_dir', type=str)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)

parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('-onnx', action='store_true', help='use the ONNX Runtime backend')

args = parser.parse_args()
if args.ngpu_per_replica > 0 and args.inference_mode not in ('cuda', 'mps'):
    raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")
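
# Example invocation (the file name and paths are illustrative, adjust to your
# checkout; the GPU-related flags are optional):
#
#   python server.py -ckpt ./ckpt/texteller -tknz ./ckpt/tokenizer \
#       -port 8000 --inference-mode cuda --ngpu_per_replica 1 -onnx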


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        # Half of the per-replica GPU quota; the detection deployment below
        # requests the other half.
        "num_gpus": args.ngpu_per_replica / 2
    }
)
class TexTellerRecServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        inf_mode: str = 'cpu',
        use_onnx: bool = False,
        num_beams: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.inf_mode = inf_mode
        self.num_beams = num_beams

        if not use_onnx:
            # The PyTorch model must be moved to the target device explicitly;
            # an ONNX session already picked its provider at load time.
            self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model

    def predict(self, image_nparray) -> str:
        return to_katex(rec_inference(
            self.model, self.tokenizer, [image_nparray],
            accelerator=self.inf_mode, num_beams=self.num_beams
        )[0])


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica / 2,
        "runtime_env": {
            "env_vars": {
                # Let onnxruntime's CUDAExecutionProvider locate the pip-installed
                # cuDNN libraries (see CUDNNPATH above).
                "LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"
            }
        }
    },
)
class TexTellerDetServer:
    def __init__(self, inf_mode='cpu'):
        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        self.latex_det_model = InferenceSession(
            "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
            providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider']
        )

    async def predict(self, image_nparray):
        # det_inference reads from a file path, so round-trip the array through
        # a temporary image on disk.
        with tempfile.TemporaryDirectory() as temp_dir:
            img_path = f"{temp_dir}/temp_image.jpg"
            cv2.imwrite(img_path, image_nparray)

            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
            return latex_bboxes


@serve.deployment()
class Ingress:
    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
        self.det_server = det_server
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> str:
        request_path = request.url.path
        form = await request.form()
        img_rb = await form['img'].read()

        # Decode the uploaded bytes into an RGB ndarray.
        img_nparray = np.frombuffer(img_rb, np.uint8)
        img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)

        if request_path.startswith("/fdet"):
            if self.det_server is None:
                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
            pred = await self.det_server.predict.remote(img_nparray)
            return pred

        elif request_path.startswith("/frec"):
            pred = await self.texteller_server.predict.remote(img_nparray)
            return pred

        else:
            return "[ERROR] Invalid request path"
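
# Smoke test once the server is up (assumes an image file formula.png exists;
# /frec returns the recognized formula as KaTeX, /fdet returns the detected
# formula bounding boxes):
#
#   curl -X POST -F "img=@formula.png" http://localhost:8000/frec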


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
    rec_server = TexTellerRecServer.bind(
        ckpt_dir, tknz_dir,
        inf_mode=args.inference_mode,
        use_onnx=args.onnx,
        num_beams=args.num_beams
    )
    # The detection deployment is optional: it is only bound when the RT-DETR
    # ONNX weights are present on disk.
    det_server = None
    if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
        det_server = TexTellerDetServer.bind(args.inference_mode)
    ingress = Ingress.bind(det_server, rec_server)

    # ingress_handle = serve.run(ingress, route_prefix="/predict")
    ingress_handle = serve.run(ingress, route_prefix="/")

    # Keep the driver process alive so the Serve deployments stay up.
    while True:
        time.sleep(1)
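
# A minimal Python client sketch (assumes the `requests` package is installed
# and an image exists at ./formula.png; both are illustrative):
#
#   import requests
#
#   with open('./formula.png', 'rb') as f:
#       resp = requests.post('http://localhost:8000/frec', files={'img': f})
#   print(resp.text)  # KaTeX string produced by TexTellerRecServer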