From a1c2b5b1ef294caca7c1c53085e4d7981d25ee8a Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Fri, 7 Jun 2024 11:47:53 +0000
Subject: [PATCH] Update server.py

1. Change the default host address to 0.0.0.0 so the server is reachable
   from other machines, not only from localhost.
2. Convert the inference output to KaTeX-compatible LaTeX via to_katex.
3. Validate --inference-mode for multi-GPU replicas instead of the removed
   use_cuda flag.
---
 src/server.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/server.py b/src/server.py
index 52e0068..372f3e4 100644
--- a/src/server.py
+++ b/src/server.py
@@ -9,6 +9,7 @@ from ray.serve.handle import DeploymentHandle
 
 from models.ocr_model.utils.inference import inference
 from models.ocr_model.model.TexTeller import TexTeller
+from models.ocr_model.utils.to_katex import to_katex
 
 
 parser = argparse.ArgumentParser()
@@ -27,8 +28,8 @@ parser.add_argument('--inference-mode', type=str, default='cpu')
 parser.add_argument('--num_beams', type=int, default=1)
 
 args = parser.parse_args()
-if args.ngpu_per_replica > 0 and not args.use_cuda:
-    raise ValueError("use_cuda must be True if ngpu_per_replica > 0")
+if args.ngpu_per_replica > 0 and args.inference_mode not in ('cuda', 'mps'):
+    raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")
 
 
 @serve.deployment(
@@ -54,10 +55,10 @@ class TexTellerServer:
         self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
 
     def predict(self, image_nparray) -> str:
-        return inference(
+        return to_katex(inference(
             self.model, self.tokenizer, [image_nparray],
             accelerator=self.inf_mode, num_beams=self.num_beams
-        )[0]
+        )[0])
 
 
 @serve.deployment()
@@ -80,7 +81,7 @@ if __name__ == '__main__':
     ckpt_dir = args.checkpoint_dir
     tknz_dir = args.tokenizer_dir
 
-    serve.start(http_options={"port": args.server_port})
+    serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
     texteller_server = TexTellerServer.bind(
         ckpt_dir, tknz_dir,
         inf_mode=args.inference_mode,
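
Note: a quick way to exercise both changes after applying this patch is to
hit the server from another machine (now possible with the 0.0.0.0 bind)
and confirm the returned string renders in KaTeX. The sketch below is a
minimal client, not part of this patch: the route ("/predict"), the upload
field name ("img"), and port 8000 are assumptions for illustration, since
the ingress deployment's HTTP signature is not shown here; adjust them to
match the actual ingress.

    import requests

    # Assumed route, field name, and port; the patch does not show the
    # ingress deployment's HTTP signature, so these are placeholders.
    url = "http://<server-ip>:8000/predict"

    with open("formula.png", "rb") as f:
        resp = requests.post(url, files={"img": f})

    resp.raise_for_status()
    # With this patch applied, the body is a KaTeX-compatible LaTeX string.
    print(resp.text)

One caveat on the new default: binding 0.0.0.0 exposes the endpoint on every
network interface, so anyone on the local network can reach the service
unless access is restricted elsewhere.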