[refactor] Init
This commit is contained in:
106
texteller/cli/commands/launch/__init__.py
Normal file
106
texteller/cli/commands/launch/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
CLI commands for launching server.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import click
|
||||
from ray import serve
|
||||
|
||||
from texteller.globals import Globals
|
||||
from texteller.utils import get_device
|
||||
|
||||
|
||||
@click.command()
@click.option(
    "-ckpt",
    "--checkpoint_dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    default=None,
    help="Path to the checkpoint directory, if not provided, will use model from huggingface repo",
)
@click.option(
    "-tknz",
    "--tokenizer_dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    default=None,
    help="Path to the tokenizer directory, if not provided, will use tokenizer from huggingface repo",
)
@click.option(
    "-p",
    "--port",
    type=int,
    default=8000,
    help="Port to run the server on",
)
@click.option(
    "--num-replicas",
    type=int,
    default=1,
    help="Number of replicas to run the server on",
)
@click.option(
    "--ncpu-per-replica",
    type=float,
    default=1.0,
    help="Number of CPUs per replica",
)
@click.option(
    "--ngpu-per-replica",
    type=float,
    default=1.0,
    help="Number of GPUs per replica",
)
@click.option(
    "--num-beams",
    type=int,
    default=1,
    help="Number of beams to use",
)
@click.option(
    "--use-onnx",
    is_flag=True,
    # NOTE: is_flag=True already implies a boolean option; the previous
    # explicit `type=bool` was redundant and has been dropped.
    default=False,
    help="Use ONNX runtime",
)
def launch(
    checkpoint_dir,
    tokenizer_dir,
    port,
    num_replicas,
    ncpu_per_replica,
    ngpu_per_replica,
    num_beams,
    use_onnx,
):
    """Launch the api server.

    Validates that a CUDA device is present when GPU replicas are requested,
    publishes the replica/resource settings through the ``Globals`` singleton,
    starts Ray Serve bound to the given port, deploys the recognition server
    behind the ``/predict`` route, and then blocks forever to keep the
    process (and the Serve deployments) alive.
    """
    device = get_device()
    # Fail fast with a readable message if the user asked for GPU replicas
    # on a machine where no CUDA device was detected, instead of letting
    # Ray raise a less obvious scheduling error later.
    if ngpu_per_replica > 0 and device.type != "cuda":
        click.echo(
            click.style(
                f"Error: --ngpu-per-replica > 0 but detected device is {device.type}",
                fg="red",
            )
        )
        sys.exit(1)

    # These must be assigned BEFORE the import below: the @serve.deployment
    # decorators in the server module read Globals() at import time.
    Globals().num_replicas = num_replicas
    Globals().ncpu_per_replica = ncpu_per_replica
    Globals().ngpu_per_replica = ngpu_per_replica
    from texteller.cli.commands.launch.server import Ingress, TexTellerServer

    serve.start(http_options={"host": "0.0.0.0", "port": port})
    rec_server = TexTellerServer.bind(
        checkpoint_dir=checkpoint_dir,
        tokenizer_dir=tokenizer_dir,
        use_onnx=use_onnx,
        num_beams=num_beams,
    )
    ingress = Ingress.bind(rec_server)

    serve.run(ingress, route_prefix="/predict")

    # serve.run returns once the deployment is live; idle forever so the
    # CLI process does not exit and tear the server down.
    while True:
        time.sleep(1)
|
||||
69
texteller/cli/commands/launch/server.py
Normal file
69
texteller/cli/commands/launch/server.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from starlette.requests import Request
|
||||
from ray import serve
|
||||
from ray.serve.handle import DeploymentHandle
|
||||
|
||||
from texteller.api import load_model, load_tokenizer, img2latex
|
||||
from texteller.utils import get_device
|
||||
from texteller.globals import Globals
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@serve.deployment(
    num_replicas=Globals().num_replicas,
    ray_actor_options={
        "num_cpus": Globals().ncpu_per_replica,
        # Allocate exactly the per-replica GPU fraction requested on the CLI.
        # The previous expression multiplied by `1.0 / 2`, silently halving
        # the allocation and contradicting the --ngpu-per-replica help text
        # ("Number of GPUs per replica") — likely leftover from a setup with
        # two co-located deployments sharing each GPU.
        "num_gpus": Globals().ngpu_per_replica,
    },
)
class TexTellerServer:
    """Ray Serve deployment that wraps the TexTeller recognition model.

    Loads the model and tokenizer once per replica and exposes a
    ``predict`` method that converts a single image array to a LaTeX/KaTeX
    string.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        tokenizer_dir: str,
        use_onnx: bool = False,
        out_format: Literal["latex", "katex"] = "katex",
        keep_style: bool = False,
        num_beams: int = 1,
    ) -> None:
        """Load model + tokenizer and move the model to the local device.

        Args:
            checkpoint_dir: Model checkpoint directory (or None for the HF repo).
            tokenizer_dir: Tokenizer directory (or None for the HF repo).
            use_onnx: Load the ONNX runtime variant instead of the torch model.
            out_format: Output markup flavor, "latex" or "katex".
            keep_style: Whether to preserve style commands in the output.
            num_beams: Beam-search width used during decoding.
        """
        self.model = load_model(
            model_dir=checkpoint_dir,
            use_onnx=use_onnx,
        )
        self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
        self.num_beams = num_beams
        self.out_format = out_format
        self.keep_style = keep_style
        # Resolve the inference device once per replica instead of on every
        # predict() call. ONNX models manage their own execution provider,
        # so only the torch model is moved.
        self.device = get_device()

        if not use_onnx:
            self.model = self.model.to(self.device)

    def predict(self, image_nparray: np.ndarray) -> str:
        """Return the recognized formula string for one RGB image array."""
        return img2latex(
            model=self.model,
            tokenizer=self.tokenizer,
            images=[image_nparray],
            device=self.device,
            out_format=self.out_format,
            keep_style=self.keep_style,
            num_beams=self.num_beams,
        )[0]
|
||||
|
||||
|
||||
@serve.deployment()
class Ingress:
    """HTTP entry point: decodes an uploaded image and forwards it to the
    recognition deployment, returning the predicted formula string."""

    def __init__(self, rec_server: DeploymentHandle) -> None:
        # Handle to the TexTellerServer deployment doing the actual inference.
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> str:
        # Pull the uploaded file bytes out of the multipart form field "img".
        form = await request.form()
        raw_bytes = await form["img"].read()

        # Decode the compressed bytes to BGR pixels, then flip channel order
        # to the RGB layout the model expects.
        buffer = np.frombuffer(raw_bytes, np.uint8)
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        rgb_image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

        return await self.texteller_server.predict.remote(rgb_image)
|
||||
Reference in New Issue
Block a user