import numpy as np import cv2 import base64 import requests from io import BytesIO from starlette.requests import Request from starlette.responses import JSONResponse from ray import serve from ray.serve.handle import DeploymentHandle from texteller.api import load_model, load_tokenizer, img2latex from texteller.utils import get_device from texteller.globals import Globals from typing import Literal @serve.deployment( num_replicas=Globals().num_replicas, ray_actor_options={ "num_cpus": Globals().ncpu_per_replica, "num_gpus": Globals().ngpu_per_replica * 1.0 / 2, }, ) class TexTellerServer: def __init__( self, checkpoint_dir: str, tokenizer_dir: str, use_onnx: bool = False, out_format: Literal["latex", "katex"] = "katex", keep_style: bool = False, num_beams: int = 1, ) -> None: self.model = load_model( model_dir=checkpoint_dir, use_onnx=use_onnx, ) self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir) self.num_beams = num_beams self.out_format = out_format self.keep_style = keep_style if not use_onnx: self.model = self.model.to(get_device()) def predict(self, image_nparray: np.ndarray) -> str: return img2latex( model=self.model, tokenizer=self.tokenizer, images=[image_nparray], device=get_device(), out_format=self.out_format, keep_style=self.keep_style, num_beams=self.num_beams, )[0] @serve.deployment() class Ingress: def __init__(self, rec_server: DeploymentHandle) -> None: self.texteller_server = rec_server async def __call__(self, request: Request): try: # Parse JSON body body = await request.json() # Get image data from either base64 or URL if "image_base64" in body: # Decode base64 image image_data = body["image_base64"] # Remove data URL prefix if present (e.g., "data:image/png;base64,") if "," in image_data: image_data = image_data.split(",", 1)[1] img_bytes = base64.b64decode(image_data) img_nparray = np.frombuffer(img_bytes, np.uint8) elif "image_url" in body: # Download image from URL image_url = body["image_url"] response = requests.get(image_url, timeout=30) response.raise_for_status() img_bytes = response.content img_nparray = np.frombuffer(img_bytes, np.uint8) else: return JSONResponse({"error": "Either 'image_base64' or 'image_url' must be provided"}, status_code=400) # Decode and convert image img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR) if img_nparray is None: return JSONResponse({"error": "Failed to decode image"}, status_code=400) img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB) # Get prediction pred = await self.texteller_server.predict.remote(img_nparray) return JSONResponse({"result": pred}) except Exception as e: return JSONResponse({"error": str(e)}, status_code=500)