feat add glm-ocr core
This commit is contained in:
199
tests/services/test_glm_postprocess.py
Normal file
199
tests/services/test_glm_postprocess.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from app.services.glm_postprocess import (
|
||||
GLMResultFormatter,
|
||||
clean_formula_number,
|
||||
clean_repeated_content,
|
||||
find_consecutive_repeat,
|
||||
)
|
||||
|
||||
|
||||
def test_find_consecutive_repeat_truncates_when_threshold_met():
    """Ten back-to-back copies of a 10-char unit plus a tail collapse to one unit."""
    repeated = "abcdefghij" * 10 + "tail"

    assert find_consecutive_repeat(repeated) == "abcdefghij"
|
||||
|
||||
|
||||
def test_find_consecutive_repeat_returns_none_when_below_threshold():
    """Nine copies of the unit are expected to yield ``None``."""
    assert find_consecutive_repeat("abcdefghij" * 9) is None
|
||||
|
||||
|
||||
def test_clean_repeated_content_handles_consecutive_and_line_level_repeats():
    """Cover the three expected outcomes of ``clean_repeated_content``.

    * consecutive repetition inside one string is reduced to a single unit,
    * a line repeated ten times (with ``line_threshold=10``) is reduced to
      one copy followed by a newline,
    * text without repetition is returned unchanged.
    """
    assert clean_repeated_content("abcdefghij" * 10 + "tail") == "abcdefghij"

    line_repeated = "\n".join(["same line"] * 10 + ["other"])
    assert clean_repeated_content(line_repeated, line_threshold=10) == "same line\n"

    assert clean_repeated_content("normal text") == "normal text"
|
||||
|
||||
|
||||
def test_clean_formula_number_strips_wrapping_parentheses():
    """Wrapping parentheses are removed; an unwrapped number is left as is."""
    assert clean_formula_number("(1)") == "1"
    assert clean_formula_number("(2.1)") == "2.1"
    assert clean_formula_number("3") == "3"
|
||||
|
||||
|
||||
def test_clean_content_removes_literal_tabs_and_long_repeat_noise():
    """Check ``_clean_content`` on input with literal backslash-t markers and noise.

    The input is built from raw strings, so the tab markers are the two
    characters backslash + ``t``, not real tab characters.
    """
    formatter = GLMResultFormatter()
    noisy = r"\t\t" + ("·" * 5) + ("abcdefghij" * 205) + r"\t"

    cleaned = formatter._clean_content(noisy)

    assert cleaned.startswith("···")
    assert cleaned.endswith("abcdefghij")
    assert r"\t" not in cleaned
|
||||
|
||||
|
||||
def test_format_content_handles_titles_formula_text_and_newlines():
    """Check the Markdown produced by ``_format_content`` per native label.

    Covers a document title, a paragraph title with a leading dash, a display
    formula given in bracket delimiters, and plain text that starts with a
    middle-dot bullet and contains a single newline.
    """
    formatter = GLMResultFormatter()

    assert formatter._format_content("Intro", "text", "doc_title") == "# Intro"
    assert formatter._format_content("- Section", "text", "paragraph_title") == "## Section"
    assert formatter._format_content(r"\[x+y\]", "formula", "display_formula") == "$$\nx+y\n$$"
    assert formatter._format_content("· item\nnext", "text", "text") == "- item\n\nnext"
|
||||
|
||||
|
||||
def test_merge_formula_numbers_merges_before_and_after_formula():
    """Check ``_merge_formula_numbers`` for three arrangements.

    * number block placed before the formula,
    * number block placed after the formula,
    * a number block with no formula next to it.

    In the first two cases a single formula item with index 0 and a tag
    appended to its content is expected; in the last case the result is
    expected to be empty.
    """
    formatter = GLMResultFormatter()

    before = formatter._merge_formula_numbers(
        [
            {"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
            {"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
        ]
    )
    after = formatter._merge_formula_numbers(
        [
            {"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
            {"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
        ]
    )
    untouched = formatter._merge_formula_numbers(
        [{"index": 0, "label": "text", "native_label": "formula_number", "content": "(3)"}]
    )

    assert before == [
        {
            "index": 0,
            "label": "formula",
            "native_label": "display_formula",
            "content": "$$\nx+y \\tag{1}\n$$",
        }
    ]
    assert after == [
        {
            "index": 0,
            "label": "formula",
            "native_label": "display_formula",
            "content": "$$\nx+y \\tag{2}\n$$",
        }
    ]
    assert untouched == []
|
||||
|
||||
|
||||
def test_merge_text_blocks_joins_hyphenated_words_when_wordfreq_accepts(monkeypatch):
    """A trailing-hyphen block is joined with the next one when the stub scores 3.0.

    ``_WORDFREQ_AVAILABLE`` and ``zipf_frequency`` are patched on the module
    under test so the outcome does not depend on the real word-frequency data.
    """
    formatter = GLMResultFormatter()

    monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
    monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 3.0)

    merged = formatter._merge_text_blocks(
        [
            {"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
            {"index": 1, "label": "text", "native_label": "text", "content": "national"},
        ]
    )

    assert merged == [
        {"index": 0, "label": "text", "native_label": "text", "content": "international"}
    ]
|
||||
|
||||
|
||||
def test_merge_text_blocks_skips_invalid_merge(monkeypatch):
    """Blocks stay separate when the stub scores 1.0 and the next block is capitalised.

    Both conditions differ from the accepting test (score 3.0, lowercase
    continuation); this test does not isolate which one blocks the merge.
    """
    formatter = GLMResultFormatter()

    monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
    monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 1.0)

    merged = formatter._merge_text_blocks(
        [
            {"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
            {"index": 1, "label": "text", "native_label": "text", "content": "National"},
        ]
    )

    assert merged == [
        {"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
        {"index": 1, "label": "text", "native_label": "text", "content": "National"},
    ]
|
||||
|
||||
|
||||
def test_format_bullet_points_infers_missing_middle_bullet():
    """An unbulleted item between two bulleted ones is expected to gain a bullet.

    The three items have nearly identical left edges (10, 12, 11) in
    ``bbox_2d``.
    """
    formatter = GLMResultFormatter()
    items = [
        {"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
        {"native_label": "text", "content": "second", "bbox_2d": [12, 12, 52, 22]},
        {"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
    ]

    formatted = formatter._format_bullet_points(items)

    assert formatted[1]["content"] == "- second"
|
||||
|
||||
|
||||
def test_format_bullet_points_skips_when_bbox_missing():
    """The middle item is left unchanged when its ``bbox_2d`` is an empty list."""
    formatter = GLMResultFormatter()
    items = [
        {"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
        {"native_label": "text", "content": "second", "bbox_2d": []},
        {"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
    ]

    formatted = formatter._format_bullet_points(items)

    assert formatted[1]["content"] == "second"
|
||||
|
||||
|
||||
def test_process_runs_full_pipeline_and_skips_empty_content():
    """Run ``process`` over a title, a numbered formula, a figure and an empty block.

    The output is expected to contain the title as a level-1 heading and the
    formula with its number merged in as a tag.
    """
    formatter = GLMResultFormatter()
    regions = [
        {
            "index": 0,
            "label": "text",
            "native_label": "doc_title",
            "content": "Doc Title",
            "bbox_2d": [0, 0, 100, 30],
        },
        {
            "index": 1,
            "label": "text",
            "native_label": "formula_number",
            "content": "(1)",
            "bbox_2d": [80, 50, 100, 60],
        },
        {
            "index": 2,
            "label": "formula",
            "native_label": "display_formula",
            "content": "x+y",
            "bbox_2d": [0, 40, 100, 80],
        },
        {
            "index": 3,
            "label": "figure",
            "native_label": "image",
            "content": "figure placeholder",
            "bbox_2d": [0, 80, 100, 120],
        },
        {
            "index": 4,
            "label": "text",
            "native_label": "text",
            "content": "",
            "bbox_2d": [0, 120, 100, 150],
        },
    ]

    output = formatter.process(regions)

    assert "# Doc Title" in output
    assert "$$\nx+y \\tag{1}\n$$" in output
    # NOTE(review): if ``output`` is a str, an empty-string membership check is
    # always true and verifies nothing about the figure or the empty block.
    # The intended expected substring may have been lost - confirm and replace.
    assert "" in output
|
||||
46
tests/services/test_layout_detector.py
Normal file
46
tests/services/test_layout_detector.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
|
||||
|
||||
class _FakePredictor:
    """Test double for a layout predictor that returns canned boxes."""

    def __init__(self, boxes):
        # Boxes handed back verbatim by ``predict``.
        self._boxes = boxes

    def predict(self, image):
        """Ignore ``image`` and return one result dict holding the stored boxes."""
        return [{"boxes": self._boxes}]
|
||||
|
||||
|
||||
def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
    """Check how ``LayoutDetector.detect`` calls post-processing and maps labels.

    The detector is created with ``__new__`` so ``__init__`` does not run, and
    its ``_get_layout_detector`` is replaced by a fake predictor. The module's
    ``apply_layout_postprocess`` is patched to record its arguments and keep
    only the first and third raw boxes.
    """
    raw_boxes = [
        {"cls_id": 22, "label": "text", "score": 0.95, "coordinate": [0, 0, 100, 100]},
        {"cls_id": 22, "label": "text", "score": 0.90, "coordinate": [10, 10, 20, 20]},
        {"cls_id": 6, "label": "doc_title", "score": 0.99, "coordinate": [0, 0, 80, 20]},
    ]

    detector = LayoutDetector.__new__(LayoutDetector)
    detector._get_layout_detector = lambda: _FakePredictor(raw_boxes)

    calls = {}

    def fake_apply_layout_postprocess(
        boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode
    ):
        calls["args"] = {
            "boxes": boxes,
            "img_size": img_size,
            "layout_nms": layout_nms,
            "layout_unclip_ratio": layout_unclip_ratio,
            "layout_merge_bboxes_mode": layout_merge_bboxes_mode,
        }
        return [boxes[0], boxes[2]]

    monkeypatch.setattr(
        "app.services.layout_detector.apply_layout_postprocess",
        fake_apply_layout_postprocess,
    )

    # Array shape is (rows=200, cols=100, channels); the expected img_size
    # below is therefore width-first.
    image = np.zeros((200, 100, 3), dtype=np.uint8)
    info = detector.detect(image)

    assert calls["args"]["img_size"] == (100, 200)
    assert calls["args"]["layout_nms"] is True
    assert calls["args"]["layout_merge_bboxes_mode"] == "large"
    assert [region.native_label for region in info.regions] == ["text", "doc_title"]
    assert [region.type for region in info.regions] == ["text", "text"]
    assert info.MixedRecognition is True
|
||||
151
tests/services/test_layout_postprocess.py
Normal file
151
tests/services/test_layout_postprocess.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.services.layout_postprocess import (
|
||||
apply_layout_postprocess,
|
||||
check_containment,
|
||||
iou,
|
||||
is_contained,
|
||||
nms,
|
||||
unclip_boxes,
|
||||
)
|
||||
|
||||
|
||||
def _raw_box(cls_id, score, x1, y1, x2, y2, label="text"):
    """Build a raw detection dict in the shape the tests feed to post-processing.

    Returns a dict with keys ``cls_id``, ``label``, ``score`` and
    ``coordinate`` (the four corner values as a list, in argument order).
    """
    return {
        "cls_id": cls_id,
        "label": label,
        "score": score,
        "coordinate": [x1, y1, x2, y2],
    }
|
||||
|
||||
|
||||
def test_iou_handles_full_none_and_partial_overlap():
    """Identical boxes give 1.0, disjoint boxes 0.0, and the partial case 1/7."""
    assert iou([0, 0, 9, 9], [0, 0, 9, 9]) == 1.0
    assert iou([0, 0, 9, 9], [20, 20, 29, 29]) == 0.0
    # The partial case is a float ratio, so compare with a tolerance.
    assert math.isclose(iou([0, 0, 9, 9], [5, 5, 14, 14]), 1 / 7, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_nms_keeps_highest_score_for_same_class_overlap():
    """Two heavily overlapping boxes of the same class: only index 0 is kept.

    Rows are ``[cls_id, score, x1, y1, x2, y2]``, matching the column order
    used by the other tests in this module.
    """
    boxes = np.array(
        [
            [0, 0.95, 0, 0, 10, 10],
            [0, 0.80, 1, 1, 11, 11],
        ],
        dtype=float,
    )

    kept = nms(boxes, iou_same=0.6, iou_diff=0.98)

    assert kept == [0]
|
||||
|
||||
|
||||
def test_nms_keeps_cross_class_overlap_boxes_below_diff_threshold():
    """The same geometry with different class ids: both indices are kept."""
    boxes = np.array(
        [
            [0, 0.95, 0, 0, 10, 10],
            [1, 0.90, 1, 1, 11, 11],
        ],
        dtype=float,
    )

    kept = nms(boxes, iou_same=0.6, iou_diff=0.98)

    assert kept == [0, 1]
|
||||
|
||||
|
||||
def test_nms_returns_single_box_index():
    """A single box, with default thresholds, is kept as index 0."""
    boxes = np.array([[0, 0.95, 0, 0, 10, 10]], dtype=float)

    assert nms(boxes) == [0]
|
||||
|
||||
|
||||
def test_is_contained_uses_overlap_threshold():
    """Check ``is_contained`` for a fully inner box and a partially overlapping one.

    The partial box is rejected with the default threshold and accepted once
    ``overlap_threshold`` is lowered to 0.3.
    """
    outer = [0, 0.9, 0, 0, 10, 10]
    inner = [0, 0.9, 2, 2, 8, 8]
    partial = [0, 0.9, 6, 6, 12, 12]

    assert is_contained(inner, outer) is True
    assert is_contained(partial, outer) is False
    assert is_contained(partial, outer, overlap_threshold=0.3) is True
|
||||
|
||||
|
||||
def test_check_containment_respects_preserve_class_ids():
    """Three nested boxes with class 1 preserved.

    Expected flags: the class-1 box is not marked as contained even though it
    lies inside the outer box, while the innermost class-2 box is.
    """
    boxes = np.array(
        [
            [0, 0.9, 0, 0, 100, 100],
            [1, 0.8, 10, 10, 30, 30],
            [2, 0.7, 15, 15, 25, 25],
        ],
        dtype=float,
    )

    contains_other, contained_by_other = check_containment(boxes, preserve_cls_ids={1})

    assert contains_other.tolist() == [1, 1, 0]
    assert contained_by_other.tolist() == [0, 0, 1]
|
||||
|
||||
|
||||
def test_unclip_boxes_supports_scalar_tuple_dict_and_none():
    """Check ``unclip_boxes`` for each accepted ratio form.

    * scalar: both dimensions scaled by the same factor about the centre,
    * tuple: separate factors for the two dimensions,
    * dict keyed by class id: only the listed class (1) is scaled,
    * ``None``: boxes come back equal to the input.
    """
    boxes = np.array(
        [
            [0, 0.9, 10, 10, 20, 20],
            [1, 0.8, 30, 30, 50, 40],
        ],
        dtype=float,
    )

    scalar = unclip_boxes(boxes, 2.0)
    assert scalar[:, 2:6].tolist() == [[5.0, 5.0, 25.0, 25.0], [20.0, 25.0, 60.0, 45.0]]

    tuple_ratio = unclip_boxes(boxes, (2.0, 3.0))
    assert tuple_ratio[:, 2:6].tolist() == [[5.0, 0.0, 25.0, 30.0], [20.0, 20.0, 60.0, 50.0]]

    per_class = unclip_boxes(boxes, {1: (1.5, 2.0)})
    assert per_class[:, 2:6].tolist() == [[10.0, 10.0, 20.0, 20.0], [25.0, 25.0, 55.0, 45.0]]

    assert np.array_equal(unclip_boxes(boxes, None), boxes)
|
||||
|
||||
|
||||
def test_apply_layout_postprocess_large_mode_removes_contained_small_box():
    """In ``"large"`` merge mode only the outer of two nested text boxes remains."""
    boxes = [
        _raw_box(0, 0.95, 0, 0, 100, 100, "text"),
        _raw_box(0, 0.90, 10, 10, 20, 20, "text"),
    ]

    result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")

    assert [box["coordinate"] for box in result] == [[0, 0, 100, 100]]
|
||||
|
||||
|
||||
def test_apply_layout_postprocess_preserves_contained_image_like_boxes():
    """Image, seal and chart boxes nested in a text box all survive ``"large"`` mode."""
    boxes = [
        _raw_box(0, 0.95, 0, 0, 100, 100, "text"),
        _raw_box(1, 0.90, 10, 10, 20, 20, "image"),
        _raw_box(2, 0.90, 25, 25, 35, 35, "seal"),
        _raw_box(3, 0.90, 40, 40, 50, 50, "chart"),
    ]

    result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")

    assert {box["label"] for box in result} == {"text", "image", "seal", "chart"}
|
||||
|
||||
|
||||
def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image():
    """With NMS and merging disabled, check three per-box rules.

    * negative coordinates of the first box are clamped to 0,
    * the second box has ``x1 == x2`` and is dropped,
    * the third box is an image covering the whole 100x90 page and is dropped.
    """
    boxes = [
        _raw_box(0, 0.95, -10, -5, 40, 50, "text"),
        _raw_box(1, 0.90, 10, 10, 10, 50, "text"),
        _raw_box(2, 0.85, 0, 0, 100, 90, "image"),
    ]

    result = apply_layout_postprocess(
        boxes,
        img_size=(100, 90),
        layout_nms=False,
        layout_merge_bboxes_mode=None,
    )

    assert result == [
        {"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
    ]
|
||||
124
tests/services/test_ocr_service.py
Normal file
124
tests/services/test_ocr_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
|
||||
|
||||
class _FakeConverter:
    """Test double for the converter; tags each output format with a prefix."""

    def convert_to_formats(self, markdown):
        """Return an object with ``latex``, ``mathml`` and ``mml`` attributes.

        Each attribute is the input prefixed with the upper-cased format name
        and ``::`` so tests can see which markdown reached the converter.
        """
        return SimpleNamespace(
            latex=f"LATEX::{markdown}",
            mathml=f"MATHML::{markdown}",
            mml=f"MML::{markdown}",
        )
|
||||
|
||||
|
||||
class _FakeImageProcessor:
    """Test double for the image processor; padding is a no-op."""

    def add_padding(self, image):
        """Return ``image`` unchanged."""
        return image
|
||||
|
||||
|
||||
class _FakeLayoutDetector:
    """Test double for the layout detector that returns preset regions."""

    def __init__(self, regions):
        # Regions reported for every image passed to ``detect``.
        self._regions = regions

    def detect(self, image):
        """Ignore ``image`` and wrap the stored regions in a ``LayoutInfo``.

        ``MixedRecognition`` is true exactly when the region list is non-empty.
        """
        return LayoutInfo(regions=self._regions, MixedRecognition=bool(self._regions))
|
||||
|
||||
|
||||
def _build_service(regions=None):
    """Create a ``GLMOCREndToEndService`` wired entirely to the fakes above.

    Args:
        regions: Layout regions the fake detector should report; ``None``
            (or any falsy value) means no regions.
    """
    return GLMOCREndToEndService(
        vl_server_url="http://127.0.0.1:8002/v1",
        image_processor=_FakeImageProcessor(),
        converter=_FakeConverter(),
        layout_detector=_FakeLayoutDetector(regions or []),
        max_workers=2,
    )
|
||||
|
||||
|
||||
def test_encode_region_returns_decodable_base64_jpeg():
    """``_encode_region`` output base64-decodes to an image of the same height and width."""
    service = _build_service()
    image = np.zeros((8, 12, 3), dtype=np.uint8)
    image[:, :] = [0, 128, 255]

    encoded = service._encode_region(image)
    decoded = cv2.imdecode(
        np.frombuffer(base64.b64decode(encoded), dtype=np.uint8),
        cv2.IMREAD_COLOR,
    )

    # Only the spatial size is compared; pixel values are not asserted.
    assert decoded.shape[:2] == image.shape[:2]
|
||||
|
||||
|
||||
def test_call_vllm_builds_messages_and_returns_content():
    """Check the request ``_call_vllm`` sends and the value it returns.

    ``service.openai_client`` is replaced by a namespace whose
    ``chat.completions.create`` records its keyword arguments and returns a
    response whose content has surrounding whitespace.
    """
    service = _build_service()
    captured = {}

    def create(**kwargs):
        captured.update(kwargs)
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=" recognized content \n"))]
        )

    service.openai_client = SimpleNamespace(
        chat=SimpleNamespace(completions=SimpleNamespace(create=create))
    )

    result = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")

    # The surrounding whitespace of the fake response is expected to be stripped.
    assert result == "recognized content"
    assert captured["model"] == "glm-ocr"
    assert captured["max_tokens"] == 1024
    assert captured["messages"][0]["content"][0]["type"] == "image_url"
    assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
    assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
|
||||
|
||||
|
||||
def test_normalize_bbox_scales_coordinates_to_1000():
    """Boxes on a 200x100 image are rescaled so each axis spans 0..1000."""
    service = _build_service()

    assert service._normalize_bbox([0, 0, 200, 100], 200, 100) == [0, 0, 1000, 1000]
    assert service._normalize_bbox([50, 25, 150, 75], 200, 100) == [250, 250, 750, 750]
|
||||
|
||||
|
||||
def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
    """With no layout regions, the stubbed model text is returned in every format.

    ``_call_vllm`` is patched on the service instance to return a fixed
    string, and the fake converter prefixes each format.
    """
    service = _build_service(regions=[])
    image = np.zeros((20, 30, 3), dtype=np.uint8)

    monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")

    result = service.recognize(image)

    assert result["markdown"] == "raw text"
    assert result["latex"] == "LATEX::raw text"
    assert result["mathml"] == "MATHML::raw text"
    assert result["mml"] == "MML::raw text"
|
||||
|
||||
|
||||
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
    """Run ``recognize`` over a title, a figure and a formula region.

    Expected: the model is called once for the title and once for the formula
    (never for the figure), and the markdown holds the title heading followed
    by the display formula.
    """
    regions = [
        LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
        LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
        LayoutRegion(
            type="formula",
            native_label="display_formula",
            bbox=[20, 20, 40, 40],
            confidence=0.95,
            score=0.95,
        ),
    ]
    service = _build_service(regions=regions)
    image = np.zeros((40, 40, 3), dtype=np.uint8)

    calls = []

    def fake_call_vllm(cropped, prompt):
        calls.append(prompt)
        if prompt == "Text Recognition:":
            return "Title"
        return "x + y"

    monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)

    result = service.recognize(image)

    # NOTE(review): the service is built with max_workers=2; if regions are
    # recognized concurrently, the order of ``calls`` depends on scheduling.
    # Confirm against the service implementation.
    assert calls == ["Text Recognition:", "Formula Recognition:"]
    assert result["markdown"] == "# Title\n\n$$\nx + y\n$$"
    assert result["latex"] == "LATEX::# Title\n\n$$\nx + y\n$$"
    assert result["mathml"] == "MATHML::# Title\n\n$$\nx + y\n$$"
    assert result["mml"] == "MML::# Title\n\n$$\nx + y\n$$"
|
||||
Reference in New Issue
Block a user