# 2026-03-09 16:51:06 +08:00
|
|
|
import base64
|
|
|
|
|
from types import SimpleNamespace
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from app.schemas.image import LayoutInfo, LayoutRegion
|
|
|
|
|
from app.services.ocr_service import GLMOCREndToEndService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _FakeConverter:
|
|
|
|
|
def convert_to_formats(self, markdown):
|
|
|
|
|
return SimpleNamespace(
|
|
|
|
|
latex=f"LATEX::{markdown}",
|
|
|
|
|
mathml=f"MATHML::{markdown}",
|
|
|
|
|
mml=f"MML::{markdown}",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _FakeImageProcessor:
|
|
|
|
|
def add_padding(self, image):
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _FakeLayoutDetector:
|
|
|
|
|
def __init__(self, regions):
|
|
|
|
|
self._regions = regions
|
|
|
|
|
|
|
|
|
|
def detect(self, image):
|
|
|
|
|
return LayoutInfo(regions=self._regions, MixedRecognition=bool(self._regions))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_service(regions=None):
    """Assemble a GLMOCREndToEndService wired entirely with in-test fakes.

    *regions* (optional) is the layout-region list the fake detector will
    report; omitting it yields a detector that finds no regions.
    """
    layout_regions = [] if regions is None else regions
    return GLMOCREndToEndService(
        vl_server_url="http://127.0.0.1:8002/v1",
        image_processor=_FakeImageProcessor(),
        converter=_FakeConverter(),
        layout_detector=_FakeLayoutDetector(layout_regions),
        max_workers=2,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_encode_region_returns_decodable_base64_jpeg():
    """_encode_region must emit base64 that round-trips through cv2.imdecode."""
    service = _build_service()
    # Solid non-gray color so the JPEG payload has real content.
    image = np.zeros((8, 12, 3), dtype=np.uint8)
    image[:, :] = [0, 128, 255]

    encoded = service._encode_region(image)

    decoded = cv2.imdecode(
        np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR
    )

    # JPEG is lossy, so only the spatial dimensions are guaranteed to survive.
    assert decoded.shape[:2] == image.shape[:2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_call_vllm_builds_messages_and_returns_content():
    """_call_vllm should send an image_url + text message pair and strip the reply."""
    service = _build_service()
    captured = {}

    def create(**kwargs):
        # Capture the request so the message structure can be asserted below.
        captured.update(kwargs)
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=" recognized content \n"))]
        )

    # Replace the OpenAI client with a capture-only stub; no network involved.
    service.openai_client = SimpleNamespace(
        chat=SimpleNamespace(completions=SimpleNamespace(create=create))
    )

    result = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")

    # Surrounding whitespace must be stripped from the model reply.
    assert result == "recognized content"
    assert captured["model"] == "glm-ocr"
    assert captured["max_tokens"] == 1024
    # First content part is the base64 JPEG data URL of the region image.
    assert captured["messages"][0]["content"][0]["type"] == "image_url"
    assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith(
        "data:image/jpeg;base64,"
    )
    # Second content part carries the prompt verbatim.
    assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_normalize_bbox_scales_coordinates_to_1000():
    """Pixel bboxes should be rescaled into the 0-1000 coordinate space."""
    service = _build_service()

    full_frame = service._normalize_bbox([0, 0, 200, 100], 200, 100)
    centered = service._normalize_bbox([50, 25, 150, 75], 200, 100)

    assert full_frame == [0, 0, 1000, 1000]
    assert centered == [250, 250, 750, 750]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
    """With zero layout regions, recognize() OCRs the whole image in one pass."""
    service = _build_service(regions=[])
    image = np.zeros((20, 30, 3), dtype=np.uint8)

    # Stub the VLM call: whatever is recognized comes back as "raw text".
    monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")

    result = service.recognize(image)

    assert result["markdown"] == "raw text"
    # Every downstream format is the fake converter's tagged markdown.
    for key, tag in (("latex", "LATEX"), ("mathml", "MATHML"), ("mml", "MML")):
        assert result[key] == f"{tag}::raw text"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
    """recognize() must skip figure regions, preserve layout order, and apply
    type-specific markdown post-processing (title heading, $$ formula fences)."""
    regions = [
        LayoutRegion(
            type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9
        ),
        LayoutRegion(
            type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8
        ),
        LayoutRegion(
            type="formula",
            native_label="display_formula",
            bbox=[20, 20, 40, 40],
            confidence=0.95,
            score=0.95,
        ),
    ]
    service = _build_service(regions=regions)
    image = np.zeros((40, 40, 3), dtype=np.uint8)

    calls = []

    def fake_call_vllm(cropped, prompt):
        # Record prompts to assert both the call order and the figure skip.
        calls.append(prompt)
        if prompt == "Text Recognition:":
            return "Title"
        return "x + y"

    monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)

    result = service.recognize(image)

    # The figure region must not trigger a VLM call; order follows the layout.
    assert calls == ["Text Recognition:", "Formula Recognition:"]
    expected_markdown = "# Title\n\n$$\nx + y\n$$"
    assert result["markdown"] == expected_markdown
    assert result["latex"] == f"LATEX::{expected_markdown}"
    assert result["mathml"] == f"MATHML::{expected_markdown}"
    assert result["mml"] == f"MML::{expected_markdown}"
|