feat: add glm-ocr core
This commit is contained in:
124
tests/services/test_ocr_service.py
Normal file
124
tests/services/test_ocr_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
|
||||
|
||||
class _FakeConverter:
|
||||
def convert_to_formats(self, markdown):
|
||||
return SimpleNamespace(
|
||||
latex=f"LATEX::{markdown}",
|
||||
mathml=f"MATHML::{markdown}",
|
||||
mml=f"MML::{markdown}",
|
||||
)
|
||||
|
||||
|
||||
class _FakeImageProcessor:
|
||||
def add_padding(self, image):
|
||||
return image
|
||||
|
||||
|
||||
class _FakeLayoutDetector:
    """Stub layout detector that replays a canned list of regions."""

    def __init__(self, regions):
        # Stored once at construction; detect() ignores the image entirely.
        self._regions = regions

    def detect(self, image):
        """Wrap the canned regions in a LayoutInfo.

        MixedRecognition is truthy exactly when regions were supplied.
        """
        has_regions = bool(self._regions)
        return LayoutInfo(regions=self._regions, MixedRecognition=has_regions)
|
||||
|
||||
|
||||
def _build_service(regions=None):
    """Build a GLMOCREndToEndService wired entirely with fake collaborators."""
    collaborators = dict(
        vl_server_url="http://127.0.0.1:8002/v1",
        image_processor=_FakeImageProcessor(),
        converter=_FakeConverter(),
        layout_detector=_FakeLayoutDetector(regions or []),
        max_workers=2,
    )
    return GLMOCREndToEndService(**collaborators)
|
||||
|
||||
|
||||
def test_encode_region_returns_decodable_base64_jpeg():
    """Round-trip a solid-colour image through _encode_region (base64 + JPEG)."""
    service = _build_service()
    image = np.zeros((8, 12, 3), dtype=np.uint8)
    image[:, :] = [0, 128, 255]

    encoded = service._encode_region(image)
    raw = base64.b64decode(encoded)
    decoded = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)

    # JPEG is lossy, so only the geometry is asserted, not pixel values.
    assert decoded.shape[:2] == image.shape[:2]
|
||||
|
||||
|
||||
def test_call_vllm_builds_messages_and_returns_content():
    """_call_vllm sends an image+prompt chat message and strips the reply."""
    service = _build_service()
    captured = {}

    def fake_create(**kwargs):
        captured.update(kwargs)
        message = SimpleNamespace(content=" recognized content \n")
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

    # Swap in a stub OpenAI client that records the request it receives.
    service.openai_client = SimpleNamespace(
        chat=SimpleNamespace(completions=SimpleNamespace(create=fake_create))
    )

    result = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")

    assert result == "recognized content"
    assert captured["model"] == "glm-ocr"
    assert captured["max_tokens"] == 1024
    parts = captured["messages"][0]["content"]
    assert parts[0]["type"] == "image_url"
    assert parts[0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
    assert parts[1] == {"type": "text", "text": "Formula Recognition:"}
|
||||
|
||||
|
||||
def test_normalize_bbox_scales_coordinates_to_1000():
    """Pixel bboxes are rescaled into the 0-1000 coordinate space."""
    service = _build_service()

    full_frame = service._normalize_bbox([0, 0, 200, 100], 200, 100)
    centred = service._normalize_bbox([50, 25, 150, 75], 200, 100)

    assert full_frame == [0, 0, 1000, 1000]
    assert centred == [250, 250, 750, 750]
|
||||
|
||||
|
||||
def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
    """With no layout regions, recognize() OCRs the whole page in one call."""
    service = _build_service(regions=[])
    image = np.zeros((20, 30, 3), dtype=np.uint8)

    monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")

    result = service.recognize(image)

    expected = {
        "markdown": "raw text",
        "latex": "LATEX::raw text",
        "mathml": "MATHML::raw text",
        "mml": "MML::raw text",
    }
    for key, value in expected.items():
        assert result[key] == value
|
||||
|
||||
|
||||
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
    """Figure regions are skipped; text/formula regions run in reading order."""
    regions = [
        LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
        LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
        LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
    ]
    service = _build_service(regions=regions)
    image = np.zeros((40, 40, 3), dtype=np.uint8)

    prompts_seen = []

    def fake_call_vllm(cropped, prompt):
        prompts_seen.append(prompt)
        return "Title" if prompt == "Text Recognition:" else "x + y"

    monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)

    result = service.recognize(image)

    # Exactly two VLM calls, in region order, with the figure skipped.
    assert prompts_seen == ["Text Recognition:", "Formula Recognition:"]
    combined = "# Title\n\n$$\nx + y\n$$"
    assert result["markdown"] == combined
    assert result["latex"] == f"LATEX::{combined}"
    assert result["mathml"] == f"MATHML::{combined}"
    assert result["mml"] == f"MML::{combined}"
|
||||
Reference in New Issue
Block a user