feat add glm-ocr core

This commit is contained in:
liuyuanchuang
2026-03-09 16:51:06 +08:00
parent d74130914c
commit 6dfaf9668b
17 changed files with 1687 additions and 140 deletions

View File

@@ -0,0 +1,98 @@
import numpy as np
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.api.v1.endpoints.image import router
from app.core.dependencies import get_glmocr_endtoend_service, get_image_processor
class _FakeImageProcessor:
def preprocess(self, image_url=None, image_base64=None):
return np.zeros((8, 8, 3), dtype=np.uint8)
class _FakeOCRService:
def __init__(self, result=None, error=None):
self._result = result or {"markdown": "md", "latex": "tex", "mathml": "mml", "mml": "xml"}
self._error = error
def recognize(self, image):
if self._error:
raise self._error
return self._result
def _build_client(image_processor=None, ocr_service=None):
app = FastAPI()
app.include_router(router)
app.dependency_overrides[get_image_processor] = lambda: image_processor or _FakeImageProcessor()
app.dependency_overrides[get_glmocr_endtoend_service] = lambda: ocr_service or _FakeOCRService()
return TestClient(app)
def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
client = _build_client()
missing = client.post("/ocr", json={})
both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"})
assert missing.status_code == 422
assert both.status_code == 422
def test_image_endpoint_returns_503_for_runtime_error():
client = _build_client(ocr_service=_FakeOCRService(error=RuntimeError("backend unavailable")))
response = client.post("/ocr", json={"image_url": "https://example.com/a.png"})
assert response.status_code == 503
assert response.json()["detail"] == "backend unavailable"
def test_image_endpoint_returns_500_for_unexpected_error():
client = _build_client(ocr_service=_FakeOCRService(error=ValueError("boom")))
response = client.post("/ocr", json={"image_url": "https://example.com/a.png"})
assert response.status_code == 500
assert response.json()["detail"] == "Internal server error"
def test_image_endpoint_returns_ocr_payload():
client = _build_client()
response = client.post("/ocr", json={"image_base64": "ZmFrZQ=="})
assert response.status_code == 200
assert response.json() == {
"latex": "tex",
"markdown": "md",
"mathml": "mml",
"mml": "xml",
"layout_info": {"regions": [], "MixedRecognition": False},
"recognition_mode": "",
}
def test_image_endpoint_real_e2e_with_env_services():
from app.main import app
image_url = (
"https://static.texpixel.com/formula/012dab3e-fb31-4ecd-90fc-6957458ee309.png"
"?Expires=1773049821&OSSAccessKeyId=TMP.3KnrJUz7aXHoU9rLTAih4MAyPGd9zyGRHiqg9AyH6TY6NKtzqT2yr4qo7Vwf8fMRFCBrWXiCFrbBwC3vn7U6mspV2NeU1K"
"&Signature=oynhP0OLIgFI0Sv3z2CWeHPT2Ck%3D"
)
with TestClient(app) as client:
response = client.post(
"/doc_process/v1/image/ocr",
json={"image_url": image_url},
headers={"x-request-id": "test-e2e"},
)
assert response.status_code == 200, response.text
payload = response.json()
assert isinstance(payload["markdown"], str)
assert payload["markdown"].strip()
assert set(payload) >= {"markdown", "latex", "mathml", "mml"}

View File

@@ -0,0 +1,10 @@
import pytest
from app.core import dependencies
def test_get_glmocr_endtoend_service_raises_when_layout_detector_missing(monkeypatch):
monkeypatch.setattr(dependencies, "_layout_detector", None)
with pytest.raises(RuntimeError, match="Layout detector not initialized"):
dependencies.get_glmocr_endtoend_service()

View File

@@ -0,0 +1,31 @@
from app.schemas.image import ImageOCRRequest, LayoutRegion
def test_layout_region_native_label_defaults_to_empty_string():
region = LayoutRegion(
type="text",
bbox=[0, 0, 10, 10],
confidence=0.9,
score=0.9,
)
assert region.native_label == ""
def test_layout_region_exposes_native_label_when_provided():
region = LayoutRegion(
type="text",
native_label="doc_title",
bbox=[0, 0, 10, 10],
confidence=0.9,
score=0.9,
)
assert region.native_label == "doc_title"
def test_image_ocr_request_requires_exactly_one_input():
request = ImageOCRRequest(image_url="https://example.com/test.png")
assert request.image_url == "https://example.com/test.png"
assert request.image_base64 is None

View File

@@ -0,0 +1,199 @@
from app.services.glm_postprocess import (
GLMResultFormatter,
clean_formula_number,
clean_repeated_content,
find_consecutive_repeat,
)
def test_find_consecutive_repeat_truncates_when_threshold_met():
repeated = "abcdefghij" * 10 + "tail"
assert find_consecutive_repeat(repeated) == "abcdefghij"
def test_find_consecutive_repeat_returns_none_when_below_threshold():
assert find_consecutive_repeat("abcdefghij" * 9) is None
def test_clean_repeated_content_handles_consecutive_and_line_level_repeats():
assert clean_repeated_content("abcdefghij" * 10 + "tail") == "abcdefghij"
line_repeated = "\n".join(["same line"] * 10 + ["other"])
assert clean_repeated_content(line_repeated, line_threshold=10) == "same line\n"
assert clean_repeated_content("normal text") == "normal text"
def test_clean_formula_number_strips_wrapping_parentheses():
assert clean_formula_number("(1)") == "1"
assert clean_formula_number("2.1") == "2.1"
assert clean_formula_number("3") == "3"
def test_clean_content_removes_literal_tabs_and_long_repeat_noise():
formatter = GLMResultFormatter()
noisy = r"\t\t" + ("·" * 5) + ("abcdefghij" * 205) + r"\t"
cleaned = formatter._clean_content(noisy)
assert cleaned.startswith("···")
assert cleaned.endswith("abcdefghij")
assert r"\t" not in cleaned
def test_format_content_handles_titles_formula_text_and_newlines():
formatter = GLMResultFormatter()
assert formatter._format_content("Intro", "text", "doc_title") == "# Intro"
assert formatter._format_content("- Section", "text", "paragraph_title") == "## Section"
assert formatter._format_content(r"\[x+y\]", "formula", "display_formula") == "$$\nx+y\n$$"
assert formatter._format_content("· item\nnext", "text", "text") == "- item\n\nnext"
def test_merge_formula_numbers_merges_before_and_after_formula():
formatter = GLMResultFormatter()
before = formatter._merge_formula_numbers(
[
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
{"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
]
)
after = formatter._merge_formula_numbers(
[
{"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
]
)
untouched = formatter._merge_formula_numbers(
[{"index": 0, "label": "text", "native_label": "formula_number", "content": "(3)"}]
)
assert before == [
{
"index": 0,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y \\tag{1}\n$$",
}
]
assert after == [
{
"index": 0,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y \\tag{2}\n$$",
}
]
assert untouched == []
def test_merge_text_blocks_joins_hyphenated_words_when_wordfreq_accepts(monkeypatch):
formatter = GLMResultFormatter()
monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 3.0)
merged = formatter._merge_text_blocks(
[
{"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
{"index": 1, "label": "text", "native_label": "text", "content": "national"},
]
)
assert merged == [
{"index": 0, "label": "text", "native_label": "text", "content": "international"}
]
def test_merge_text_blocks_skips_invalid_merge(monkeypatch):
formatter = GLMResultFormatter()
monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 1.0)
merged = formatter._merge_text_blocks(
[
{"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
{"index": 1, "label": "text", "native_label": "text", "content": "National"},
]
)
assert merged == [
{"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
{"index": 1, "label": "text", "native_label": "text", "content": "National"},
]
def test_format_bullet_points_infers_missing_middle_bullet():
formatter = GLMResultFormatter()
items = [
{"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
{"native_label": "text", "content": "second", "bbox_2d": [12, 12, 52, 22]},
{"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
]
formatted = formatter._format_bullet_points(items)
assert formatted[1]["content"] == "- second"
def test_format_bullet_points_skips_when_bbox_missing():
formatter = GLMResultFormatter()
items = [
{"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
{"native_label": "text", "content": "second", "bbox_2d": []},
{"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
]
formatted = formatter._format_bullet_points(items)
assert formatted[1]["content"] == "second"
def test_process_runs_full_pipeline_and_skips_empty_content():
formatter = GLMResultFormatter()
regions = [
{
"index": 0,
"label": "text",
"native_label": "doc_title",
"content": "Doc Title",
"bbox_2d": [0, 0, 100, 30],
},
{
"index": 1,
"label": "text",
"native_label": "formula_number",
"content": "(1)",
"bbox_2d": [80, 50, 100, 60],
},
{
"index": 2,
"label": "formula",
"native_label": "display_formula",
"content": "x+y",
"bbox_2d": [0, 40, 100, 80],
},
{
"index": 3,
"label": "figure",
"native_label": "image",
"content": "figure placeholder",
"bbox_2d": [0, 80, 100, 120],
},
{
"index": 4,
"label": "text",
"native_label": "text",
"content": "",
"bbox_2d": [0, 120, 100, 150],
},
]
output = formatter.process(regions)
assert "# Doc Title" in output
assert "$$\nx+y \\tag{1}\n$$" in output
assert "![](bbox=[0, 80, 100, 120])" in output

View File

@@ -0,0 +1,46 @@
import numpy as np
from app.services.layout_detector import LayoutDetector
class _FakePredictor:
def __init__(self, boxes):
self._boxes = boxes
def predict(self, image):
return [{"boxes": self._boxes}]
def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
raw_boxes = [
{"cls_id": 22, "label": "text", "score": 0.95, "coordinate": [0, 0, 100, 100]},
{"cls_id": 22, "label": "text", "score": 0.90, "coordinate": [10, 10, 20, 20]},
{"cls_id": 6, "label": "doc_title", "score": 0.99, "coordinate": [0, 0, 80, 20]},
]
detector = LayoutDetector.__new__(LayoutDetector)
detector._get_layout_detector = lambda: _FakePredictor(raw_boxes)
calls = {}
def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode):
calls["args"] = {
"boxes": boxes,
"img_size": img_size,
"layout_nms": layout_nms,
"layout_unclip_ratio": layout_unclip_ratio,
"layout_merge_bboxes_mode": layout_merge_bboxes_mode,
}
return [boxes[0], boxes[2]]
monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess)
image = np.zeros((200, 100, 3), dtype=np.uint8)
info = detector.detect(image)
assert calls["args"]["img_size"] == (100, 200)
assert calls["args"]["layout_nms"] is True
assert calls["args"]["layout_merge_bboxes_mode"] == "large"
assert [region.native_label for region in info.regions] == ["text", "doc_title"]
assert [region.type for region in info.regions] == ["text", "text"]
assert info.MixedRecognition is True

View File

@@ -0,0 +1,151 @@
import math
import numpy as np
from app.services.layout_postprocess import (
apply_layout_postprocess,
check_containment,
iou,
is_contained,
nms,
unclip_boxes,
)
def _raw_box(cls_id, score, x1, y1, x2, y2, label="text"):
return {
"cls_id": cls_id,
"label": label,
"score": score,
"coordinate": [x1, y1, x2, y2],
}
def test_iou_handles_full_none_and_partial_overlap():
assert iou([0, 0, 9, 9], [0, 0, 9, 9]) == 1.0
assert iou([0, 0, 9, 9], [20, 20, 29, 29]) == 0.0
assert math.isclose(iou([0, 0, 9, 9], [5, 5, 14, 14]), 1 / 7, rel_tol=1e-6)
def test_nms_keeps_highest_score_for_same_class_overlap():
boxes = np.array(
[
[0, 0.95, 0, 0, 10, 10],
[0, 0.80, 1, 1, 11, 11],
],
dtype=float,
)
kept = nms(boxes, iou_same=0.6, iou_diff=0.98)
assert kept == [0]
def test_nms_keeps_cross_class_overlap_boxes_below_diff_threshold():
boxes = np.array(
[
[0, 0.95, 0, 0, 10, 10],
[1, 0.90, 1, 1, 11, 11],
],
dtype=float,
)
kept = nms(boxes, iou_same=0.6, iou_diff=0.98)
assert kept == [0, 1]
def test_nms_returns_single_box_index():
boxes = np.array([[0, 0.95, 0, 0, 10, 10]], dtype=float)
assert nms(boxes) == [0]
def test_is_contained_uses_overlap_threshold():
outer = [0, 0.9, 0, 0, 10, 10]
inner = [0, 0.9, 2, 2, 8, 8]
partial = [0, 0.9, 6, 6, 12, 12]
assert is_contained(inner, outer) is True
assert is_contained(partial, outer) is False
assert is_contained(partial, outer, overlap_threshold=0.3) is True
def test_check_containment_respects_preserve_class_ids():
boxes = np.array(
[
[0, 0.9, 0, 0, 100, 100],
[1, 0.8, 10, 10, 30, 30],
[2, 0.7, 15, 15, 25, 25],
],
dtype=float,
)
contains_other, contained_by_other = check_containment(boxes, preserve_cls_ids={1})
assert contains_other.tolist() == [1, 1, 0]
assert contained_by_other.tolist() == [0, 0, 1]
def test_unclip_boxes_supports_scalar_tuple_dict_and_none():
boxes = np.array(
[
[0, 0.9, 10, 10, 20, 20],
[1, 0.8, 30, 30, 50, 40],
],
dtype=float,
)
scalar = unclip_boxes(boxes, 2.0)
assert scalar[:, 2:6].tolist() == [[5.0, 5.0, 25.0, 25.0], [20.0, 25.0, 60.0, 45.0]]
tuple_ratio = unclip_boxes(boxes, (2.0, 3.0))
assert tuple_ratio[:, 2:6].tolist() == [[5.0, 0.0, 25.0, 30.0], [20.0, 20.0, 60.0, 50.0]]
per_class = unclip_boxes(boxes, {1: (1.5, 2.0)})
assert per_class[:, 2:6].tolist() == [[10.0, 10.0, 20.0, 20.0], [25.0, 25.0, 55.0, 45.0]]
assert np.array_equal(unclip_boxes(boxes, None), boxes)
def test_apply_layout_postprocess_large_mode_removes_contained_small_box():
boxes = [
_raw_box(0, 0.95, 0, 0, 100, 100, "text"),
_raw_box(0, 0.90, 10, 10, 20, 20, "text"),
]
result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")
assert [box["coordinate"] for box in result] == [[0, 0, 100, 100]]
def test_apply_layout_postprocess_preserves_contained_image_like_boxes():
boxes = [
_raw_box(0, 0.95, 0, 0, 100, 100, "text"),
_raw_box(1, 0.90, 10, 10, 20, 20, "image"),
_raw_box(2, 0.90, 25, 25, 35, 35, "seal"),
_raw_box(3, 0.90, 40, 40, 50, 50, "chart"),
]
result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")
assert {box["label"] for box in result} == {"text", "image", "seal", "chart"}
def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image():
boxes = [
_raw_box(0, 0.95, -10, -5, 40, 50, "text"),
_raw_box(1, 0.90, 10, 10, 10, 50, "text"),
_raw_box(2, 0.85, 0, 0, 100, 90, "image"),
]
result = apply_layout_postprocess(
boxes,
img_size=(100, 90),
layout_nms=False,
layout_merge_bboxes_mode=None,
)
assert result == [
{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
]

View File

@@ -0,0 +1,124 @@
import base64
from types import SimpleNamespace
import cv2
import numpy as np
from app.schemas.image import LayoutInfo, LayoutRegion
from app.services.ocr_service import GLMOCREndToEndService
class _FakeConverter:
def convert_to_formats(self, markdown):
return SimpleNamespace(
latex=f"LATEX::{markdown}",
mathml=f"MATHML::{markdown}",
mml=f"MML::{markdown}",
)
class _FakeImageProcessor:
def add_padding(self, image):
return image
class _FakeLayoutDetector:
def __init__(self, regions):
self._regions = regions
def detect(self, image):
return LayoutInfo(regions=self._regions, MixedRecognition=bool(self._regions))
def _build_service(regions=None):
return GLMOCREndToEndService(
vl_server_url="http://127.0.0.1:8002/v1",
image_processor=_FakeImageProcessor(),
converter=_FakeConverter(),
layout_detector=_FakeLayoutDetector(regions or []),
max_workers=2,
)
def test_encode_region_returns_decodable_base64_jpeg():
service = _build_service()
image = np.zeros((8, 12, 3), dtype=np.uint8)
image[:, :] = [0, 128, 255]
encoded = service._encode_region(image)
decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR)
assert decoded.shape[:2] == image.shape[:2]
def test_call_vllm_builds_messages_and_returns_content():
service = _build_service()
captured = {}
def create(**kwargs):
captured.update(kwargs)
return SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content=" recognized content \n"))]
)
service.openai_client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create))
)
result = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")
assert result == "recognized content"
assert captured["model"] == "glm-ocr"
assert captured["max_tokens"] == 1024
assert captured["messages"][0]["content"][0]["type"] == "image_url"
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
def test_normalize_bbox_scales_coordinates_to_1000():
service = _build_service()
assert service._normalize_bbox([0, 0, 200, 100], 200, 100) == [0, 0, 1000, 1000]
assert service._normalize_bbox([50, 25, 150, 75], 200, 100) == [250, 250, 750, 750]
def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
service = _build_service(regions=[])
image = np.zeros((20, 30, 3), dtype=np.uint8)
monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")
result = service.recognize(image)
assert result["markdown"] == "raw text"
assert result["latex"] == "LATEX::raw text"
assert result["mathml"] == "MATHML::raw text"
assert result["mml"] == "MML::raw text"
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
regions = [
LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
]
service = _build_service(regions=regions)
image = np.zeros((40, 40, 3), dtype=np.uint8)
calls = []
def fake_call_vllm(cropped, prompt):
calls.append(prompt)
if prompt == "Text Recognition:":
return "Title"
return "x + y"
monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)
result = service.recognize(image)
assert calls == ["Text Recognition:", "Formula Recognition:"]
assert result["markdown"] == "# Title\n\n$$\nx + y\n$$"
assert result["latex"] == "LATEX::# Title\n\n$$\nx + y\n$$"
assert result["mathml"] == "MATHML::# Title\n\n$$\nx + y\n$$"
assert result["mml"] == "MML::# Title\n\n$$\nx + y\n$$"