TexTeller/src/server.py

import argparse
import time
import numpy as np
import cv2
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

parser = argparse.ArgumentParser()
parser.add_argument(
    '-ckpt', '--checkpoint_dir', type=str
)
parser.add_argument(
    '-tknz', '--tokenizer_dir', type=str
)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
args = parser.parse_args()
if args.ngpu_per_replica > 0 and args.inference_mode == 'cpu':
    raise ValueError("--inference-mode must be a GPU device (e.g. 'cuda') if ngpu_per_replica > 0")
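# Illustrative launch command (a sketch, not part of the original file; the
# checkpoint and tokenizer paths below are placeholders):
#
#   python server.py -ckpt ./path/to/checkpoint -tknz ./path/to/tokenizer \
#       --inference-mode cuda --ngpu_per_replica 1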
@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica
    }
)
class TexTellerServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        inf_mode: str = 'cpu',
        num_beams: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.inf_mode = inf_mode
        self.num_beams = num_beams

        # Move the model to the requested device (e.g. 'cuda'); leave it on the CPU otherwise.
        self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model

    def predict(self, image_nparray) -> str:
        # Run OCR inference on a single image and return the recognized LaTeX string.
        return inference(
            self.model, self.tokenizer, [image_nparray],
            inf_mode=self.inf_mode, num_beams=self.num_beams
        )[0]


@serve.deployment()
class Ingress:
    def __init__(self, texteller_server: DeploymentHandle) -> None:
        self.texteller_server = texteller_server

    async def __call__(self, request: Request) -> str:
        form = await request.form()
        img_rb = await form['img'].read()
        img_nparray = np.frombuffer(img_rb, np.uint8)

        # Decode the uploaded bytes into a BGR image, then convert to RGB for the model.
        img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
        pred = await self.texteller_server.predict.remote(img_nparray)

        return pred


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"port": args.server_port})
    texteller_server = TexTellerServer.bind(
        ckpt_dir, tknz_dir,
        inf_mode=args.inference_mode,
        num_beams=args.num_beams
    )

    ingress = Ingress.bind(texteller_server)
    ingress_handle = serve.run(ingress, route_prefix="/predict")

    # Keep the driver process alive so the Serve deployments stay up.
    while True:
        time.sleep(1)
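# Example client usage (a minimal sketch, assuming the server runs locally on the
# default port and that the `requests` package is installed; the image path is a
# placeholder). The multipart field name 'img' matches the key read in Ingress.__call__:
#
#   import requests
#   with open('formula.png', 'rb') as f:
#       resp = requests.post('http://localhost:8000/predict', files={'img': f})
#   print(resp.text)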