import asyncio
import base64
from io import BytesIO
from typing import Literal

import cv2
import numpy as np
import requests
from ray import serve
from ray.serve.handle import DeploymentHandle
from starlette.requests import Request
from starlette.responses import JSONResponse

from texteller.api import img2latex, load_model, load_tokenizer
from texteller.globals import Globals
from texteller.utils import get_device
|
@serve.deployment(
    num_replicas=Globals().num_replicas,
    ray_actor_options={
        "num_cpus": Globals().ncpu_per_replica,
        "num_gpus": Globals().ngpu_per_replica * 1.0 / 2,
    },
)
class TexTellerServer:
    """Ray Serve deployment that wraps a TexTeller image-to-formula model.

    Loads the model and tokenizer once per replica; `predict` turns a single
    image array into a formula string.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        tokenizer_dir: str,
        use_onnx: bool = False,
        out_format: Literal["latex", "katex"] = "katex",
        keep_style: bool = False,
        num_beams: int = 1,
    ) -> None:
        """Load model weights and tokenizer and record inference settings.

        Args:
            checkpoint_dir: Directory holding the model checkpoint.
            tokenizer_dir: Directory holding the tokenizer files.
            use_onnx: Load an ONNX runtime session instead of a torch model.
            out_format: Output flavor produced by `img2latex`.
            keep_style: Whether to preserve style tokens in the output.
            num_beams: Beam-search width used at inference time.
        """
        self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
        self.model = load_model(model_dir=checkpoint_dir, use_onnx=use_onnx)
        # Only torch models are moved explicitly; the ONNX path handles
        # device placement itself.
        if not use_onnx:
            self.model = self.model.to(get_device())
        self.out_format = out_format
        self.keep_style = keep_style
        self.num_beams = num_beams

    def predict(self, image_nparray: np.ndarray) -> str:
        """Run recognition on one image array and return the formula string."""
        results = img2latex(
            model=self.model,
            tokenizer=self.tokenizer,
            images=[image_nparray],
            device=get_device(),
            out_format=self.out_format,
            keep_style=self.keep_style,
            num_beams=self.num_beams,
        )
        # img2latex is batch-oriented; we sent a single image, take its result.
        return results[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@serve.deployment()
class Ingress:
    """HTTP entrypoint: accepts an image (base64 or URL) and returns LaTeX.

    Expects a JSON body with either an `image_base64` or an `image_url` key
    and proxies the decoded image to the TexTeller recognition deployment.
    """

    def __init__(self, rec_server: DeploymentHandle) -> None:
        """Store the handle to the recognition deployment."""
        self.texteller_server = rec_server

    @staticmethod
    def _bytes_from_base64(image_data: str) -> bytes:
        """Decode a base64 payload, tolerating an optional data-URL prefix."""
        # Remove data URL prefix if present (e.g., "data:image/png;base64,")
        if "," in image_data:
            image_data = image_data.split(",", 1)[1]
        return base64.b64decode(image_data)

    @staticmethod
    def _bytes_from_url(image_url: str) -> bytes:
        """Download image bytes from a URL; raises for non-2xx responses."""
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        return response.content

    async def __call__(self, request: Request) -> JSONResponse:
        """Handle one HTTP request and return the recognition result as JSON.

        Returns 400 for missing/undecodable images, 500 for any other failure.
        """
        try:
            # Parse JSON body
            body = await request.json()

            # Get image data from either base64 or URL
            if "image_base64" in body:
                img_bytes = self._bytes_from_base64(body["image_base64"])
            elif "image_url" in body:
                # BUG FIX: requests.get is blocking; running it directly in
                # this async handler would stall the replica's event loop for
                # up to 30s. Offload it to the default thread pool instead.
                loop = asyncio.get_running_loop()
                img_bytes = await loop.run_in_executor(
                    None, self._bytes_from_url, body["image_url"]
                )
            else:
                return JSONResponse({"error": "Either 'image_base64' or 'image_url' must be provided"}, status_code=400)

            # Decode and convert image
            img_nparray = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
            if img_nparray is None:
                return JSONResponse({"error": "Failed to decode image"}, status_code=400)
            # OpenCV decodes to BGR; the model side works on RGB arrays.
            img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)

            # Get prediction
            pred = await self.texteller_server.predict.remote(img_nparray)

            return JSONResponse({"result": pred})

        except Exception as e:
            # Top-level service boundary: surface any unexpected failure
            # as a 500 with the error message rather than crashing the replica.
            return JSONResponse({"error": str(e)}, status_code=500)
|