Update server.py

1. Change the default host address to 0.0.0.0.
2. Convert the output to KaTeX.
This commit is contained in:
三洋三洋
2024-06-07 11:47:53 +00:00
parent aa14674097
commit 624f9531b4

View File

@@ -9,6 +9,7 @@ from ray.serve.handle import DeploymentHandle
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.to_katex import to_katex
parser = argparse.ArgumentParser()
@@ -27,8 +28,8 @@ parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
args = parser.parse_args()
if args.ngpu_per_replica > 0 and not args.use_cuda:
raise ValueError("use_cuda must be True if ngpu_per_replica > 0")
if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")
@serve.deployment(
@@ -54,10 +55,10 @@ class TexTellerServer:
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
def predict(self, image_nparray) -> str:
return inference(
return to_katex(inference(
self.model, self.tokenizer, [image_nparray],
accelerator=self.inf_mode, num_beams=self.num_beams
)[0]
)[0])
@serve.deployment()
@@ -80,7 +81,7 @@ if __name__ == '__main__':
ckpt_dir = args.checkpoint_dir
tknz_dir = args.tokenizer_dir
serve.start(http_options={"port": args.server_port})
serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
texteller_server = TexTellerServer.bind(
ckpt_dir, tknz_dir,
inf_mode=args.inference_mode,