Files
TexTeller/texteller/cli/commands/launch/server.py
yoge ba0968b2da
Some checks failed
Sphinx: Render docs / build (push) Has been cancelled
Python Linting / lint (push) Has been cancelled
Run Tests with Pytest / test (push) Has been cancelled
feat: add dockerfile
2025-12-15 22:31:13 +08:00

103 lines
3.3 KiB
Python

import asyncio
import base64
import logging
from io import BytesIO
from typing import Literal

import cv2
import numpy as np
import requests
from ray import serve
from ray.serve.handle import DeploymentHandle
from starlette.requests import Request
from starlette.responses import JSONResponse

from texteller.api import load_model, load_tokenizer, img2latex
from texteller.globals import Globals
from texteller.utils import get_device
@serve.deployment(
    num_replicas=Globals().num_replicas,
    ray_actor_options={
        "num_cpus": Globals().ncpu_per_replica,
        # NOTE(review): only half of the configured per-replica GPU count is
        # requested here; the reason is not visible in this file -- confirm intent.
        "num_gpus": Globals().ngpu_per_replica * 1.0 / 2,
    },
)
class TexTellerServer:
    """Ray Serve deployment that owns a TexTeller model/tokenizer pair and
    converts one image at a time into a formula string via ``img2latex``.

    Args:
        checkpoint_dir: Model directory handed to ``load_model``.
        tokenizer_dir: Tokenizer directory handed to ``load_tokenizer``.
        use_onnx: Load the ONNX variant of the model. When False the loaded
            model is moved to ``get_device()``; an ONNX model is kept as loaded.
        out_format: Output flavour forwarded to ``img2latex`` ("latex" or "katex").
        keep_style: Forwarded unchanged to ``img2latex``.
        num_beams: Beam count forwarded unchanged to ``img2latex``.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        tokenizer_dir: str,
        use_onnx: bool = False,
        out_format: Literal["latex", "katex"] = "katex",
        keep_style: bool = False,
        num_beams: int = 1,
    ) -> None:
        # Output options, replayed on every predict() call.
        self.out_format = out_format
        self.keep_style = keep_style
        self.num_beams = num_beams

        loaded_model = load_model(model_dir=checkpoint_dir, use_onnx=use_onnx)
        self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
        # Only the non-ONNX model is moved onto a device.
        self.model = loaded_model if use_onnx else loaded_model.to(get_device())

    def predict(self, image_nparray: np.ndarray) -> str:
        """Recognise a single image and return its formula string.

        The image is wrapped in a one-element batch, and the first (only)
        entry of ``img2latex``'s result is returned. The ``Ingress``
        deployment in this module passes an RGB array decoded with OpenCV.
        """
        decode_options = dict(
            out_format=self.out_format,
            keep_style=self.keep_style,
            num_beams=self.num_beams,
        )
        batch_results = img2latex(
            model=self.model,
            tokenizer=self.tokenizer,
            images=[image_nparray],
            device=get_device(),
            **decode_options,
        )
        return batch_results[0]
@serve.deployment()
class Ingress:
    """HTTP entry point that decodes an image from a JSON request and forwards
    it to the ``TexTellerServer`` deployment for recognition.

    The request body must be a JSON object containing one of:
      * ``image_base64`` -- base64-encoded image bytes, optionally preceded by
        a data-URL prefix such as ``data:image/png;base64,``.
      * ``image_url`` -- a URL the server downloads the image from.

    Responses:
      * 200 ``{"result": <prediction>}`` on success.
      * 400 ``{"error": ...}`` for malformed requests or undecodable images.
      * 500 ``{"error": ...}`` for any other (server-side) failure.
    """

    def __init__(self, rec_server: DeploymentHandle) -> None:
        # Handle to the recognition deployment whose ``predict`` is called.
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> JSONResponse:
        try:
            try:
                img_bytes = await self._read_image_bytes(request)
            except ValueError as e:
                # Problems with client-supplied input are the client's fault.
                return JSONResponse({"error": str(e)}, status_code=400)

            img_nparray = self._decode_image(img_bytes)
            if img_nparray is None:
                return JSONResponse({"error": "Failed to decode image"}, status_code=400)

            pred = await self.texteller_server.predict.remote(img_nparray)
            return JSONResponse({"result": pred})
        except Exception as e:
            # Record the traceback server-side; the client only sees the message.
            logging.getLogger("ray.serve").exception("Unhandled error while serving request")
            return JSONResponse({"error": str(e)}, status_code=500)

    async def _read_image_bytes(self, request: Request) -> bytes:
        """Parse the JSON request body and return the encoded image bytes it references.

        Raises:
            ValueError: if the body is not valid JSON or not an object, names
                neither ``image_base64`` nor ``image_url``, holds a non-string
                value for the chosen field, carries invalid base64, or the
                image URL cannot be fetched.
        """
        try:
            body = await request.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            raise ValueError(f"Invalid JSON body: {e}") from e
        if not isinstance(body, dict):
            raise ValueError("Request body must be a JSON object")

        if "image_base64" in body:
            image_data = body["image_base64"]
            if not isinstance(image_data, str):
                raise ValueError("'image_base64' must be a string")
            # Remove data URL prefix if present (e.g., "data:image/png;base64,")
            if "," in image_data:
                image_data = image_data.split(",", 1)[1]
            try:
                return base64.b64decode(image_data)
            except ValueError as e:  # binascii.Error, or non-ASCII input
                raise ValueError(f"Invalid base64 image data: {e}") from e

        if "image_url" in body:
            image_url = body["image_url"]
            if not isinstance(image_url, str):
                raise ValueError("'image_url' must be a string")
            loop = asyncio.get_running_loop()
            try:
                # requests is blocking: run it in the default thread pool so the
                # replica's event loop keeps serving other requests meanwhile.
                return await loop.run_in_executor(None, self._download_image, image_url)
            except requests.RequestException as e:
                raise ValueError(f"Failed to download image: {e}") from e

        raise ValueError("Either 'image_base64' or 'image_url' must be provided")

    @staticmethod
    def _download_image(image_url: str) -> bytes:
        """Synchronously fetch ``image_url`` and return the response body.

        Raises ``requests.RequestException`` (including ``HTTPError`` for a
        non-2xx status and ``InvalidSchema``/``MissingSchema`` for non-HTTP URLs).
        """
        # NOTE(security): image_url is user-controlled, so this request can reach
        # any host the replica can (including internal services), and the body
        # size is not capped. Add an allow-list / size limit if exposed publicly.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        return response.content

    @staticmethod
    def _decode_image(img_bytes: bytes) -> "np.ndarray | None":
        """Decode encoded image bytes into an RGB array; return None if undecodable."""
        if not img_bytes:
            # cv2.imdecode raises on an empty buffer instead of returning None.
            return None
        decoded = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
        if decoded is None:
            return None
        # OpenCV decodes to BGR; the recognition model is fed RGB.
        return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)