Update server.py
1. Change the default host address to 0.0.0.0. 2. Convert the output to KaTeX.
This commit is contained in:
@@ -9,6 +9,7 @@ from ray.serve.handle import DeploymentHandle
|
|||||||
|
|
||||||
from models.ocr_model.utils.inference import inference
|
from models.ocr_model.utils.inference import inference
|
||||||
from models.ocr_model.model.TexTeller import TexTeller
|
from models.ocr_model.model.TexTeller import TexTeller
|
||||||
|
from models.ocr_model.utils.to_katex import to_katex
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -27,8 +28,8 @@ parser.add_argument('--inference-mode', type=str, default='cpu')
|
|||||||
parser.add_argument('--num_beams', type=int, default=1)
|
parser.add_argument('--num_beams', type=int, default=1)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.ngpu_per_replica > 0 and not args.use_cuda:
|
if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
|
||||||
raise ValueError("use_cuda must be True if ngpu_per_replica > 0")
|
raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")
|
||||||
|
|
||||||
|
|
||||||
@serve.deployment(
|
@serve.deployment(
|
||||||
@@ -54,10 +55,10 @@ class TexTellerServer:
|
|||||||
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
|
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
|
||||||
|
|
||||||
def predict(self, image_nparray) -> str:
|
def predict(self, image_nparray) -> str:
|
||||||
return inference(
|
return to_katex(inference(
|
||||||
self.model, self.tokenizer, [image_nparray],
|
self.model, self.tokenizer, [image_nparray],
|
||||||
accelerator=self.inf_mode, num_beams=self.num_beams
|
accelerator=self.inf_mode, num_beams=self.num_beams
|
||||||
)[0]
|
)[0])
|
||||||
|
|
||||||
|
|
||||||
@serve.deployment()
|
@serve.deployment()
|
||||||
@@ -80,7 +81,7 @@ if __name__ == '__main__':
|
|||||||
ckpt_dir = args.checkpoint_dir
|
ckpt_dir = args.checkpoint_dir
|
||||||
tknz_dir = args.tokenizer_dir
|
tknz_dir = args.tokenizer_dir
|
||||||
|
|
||||||
serve.start(http_options={"port": args.server_port})
|
serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
|
||||||
texteller_server = TexTellerServer.bind(
|
texteller_server = TexTellerServer.bind(
|
||||||
ckpt_dir, tknz_dir,
|
ckpt_dir, tknz_dir,
|
||||||
inf_mode=args.inference_mode,
|
inf_mode=args.inference_mode,
|
||||||
|
|||||||
Reference in New Issue
Block a user