feat: add glm-ocr core

This commit is contained in:
liuyuanchuang
2026-03-09 16:51:06 +08:00
parent d74130914c
commit 6dfaf9668b
17 changed files with 1687 additions and 140 deletions

View File

@@ -0,0 +1,14 @@
{
"permissions": {
"allow": [
"WebFetch(domain:deepwiki.com)",
"WebFetch(domain:github.com)",
"Read(//private/tmp/**)",
"Bash(gh api repos/zai-org/GLM-OCR/contents/glmocr --jq '.[].name')",
"WebFetch(domain:raw.githubusercontent.com)",
"Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")",
"Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)",
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)"
]
}
}

View File

@@ -2,26 +2,17 @@
import time
import uuid
import cv2
from io import BytesIO
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from app.core.dependencies import (
get_image_processor,
get_layout_detector,
get_ocr_service,
get_mineru_ocr_service,
get_glmocr_service,
get_glmocr_endtoend_service,
)
from app.core.config import get_settings
from app.core.logging_config import get_logger, RequestIDAdapter
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService
settings = get_settings()
from app.services.ocr_service import GLMOCREndToEndService
router = APIRouter()
logger = get_logger()
@@ -33,100 +24,38 @@ async def process_image_ocr(
http_request: Request,
response: Response,
image_processor: ImageProcessor = Depends(get_image_processor),
layout_detector: LayoutDetector = Depends(get_layout_detector),
mineru_service: MineruOCRService = Depends(get_mineru_ocr_service),
paddle_service: OCRService = Depends(get_ocr_service),
glmocr_service: GLMOCRService = Depends(get_glmocr_service),
glmocr_service: GLMOCREndToEndService = Depends(get_glmocr_endtoend_service),
) -> ImageOCRResponse:
"""Process an image and extract content as LaTeX, Markdown, and MathML.
The processing pipeline:
1. Load and preprocess image (add 30% whitespace padding)
2. Detect layout using DocLayout-YOLO
3. Based on layout:
- If plain text exists: use PP-DocLayoutV2 for mixed recognition
- Otherwise: use PaddleOCR-VL with formula prompt
4. Convert output to LaTeX, Markdown, and MathML formats
1. Load and preprocess image
2. Detect layout regions using PP-DocLayoutV3
3. Crop each region and recognize with GLM-OCR via vLLM (task-specific prompts)
4. Aggregate region results into Markdown
5. Convert to LaTeX, Markdown, and MathML formats
Note: OMML conversion is not included due to performance overhead.
Use the /convert/latex-to-omml endpoint to convert LaTeX to OMML separately.
"""
# Get or generate request ID
request_id = http_request.headers.get("x-request-id", str(uuid.uuid4()))
response.headers["x-request-id"] = request_id
# Create logger adapter with request_id
log = RequestIDAdapter(logger, {"request_id": request_id})
log.request_id = request_id
try:
log.info("Starting image OCR processing")
start = time.time()
# Preprocess image (load only, no padding yet)
preprocess_start = time.time()
image = image_processor.preprocess(
image_url=request.image_url,
image_base64=request.image_base64,
)
# Apply padding only for layout detection
processed_image = image
if image_processor and settings.is_padding:
processed_image = image_processor.add_padding(image)
ocr_result = glmocr_service.recognize(image)
preprocess_time = time.time() - preprocess_start
log.debug(f"Image loading completed in {preprocess_time:.3f}s")
# Layout detection (using padded image if padding is enabled)
layout_start = time.time()
layout_info = layout_detector.detect(processed_image)
layout_time = time.time() - layout_start
log.info(f"Layout detection completed in {layout_time:.3f}s")
# OCR recognition (use original image without padding)
ocr_start = time.time()
if layout_info.MixedRecognition:
recognition_method = "MixedRecognition (MinerU)"
log.info(f"Using {recognition_method}")
# Convert original image (without padding) to bytes
success, encoded_image = cv2.imencode(".png", image)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
image_bytes.seek(0) # Ensure position is at the beginning
ocr_result = mineru_service.recognize(image_bytes)
else:
recognition_method = "FormulaOnly (GLMOCR)"
log.info(f"Using {recognition_method}")
# Try GLM-OCR first, fallback to MinerU if token limit exceeded
try:
ocr_result = glmocr_service.recognize(image)
except Exception as e:
error_msg = str(e)
# Check if error is due to token limit (max_model_len exceeded)
if "max_model_len" in error_msg or "decoder prompt" in error_msg or "BadRequestError" in error_msg:
log.warning(f"GLM-OCR failed due to token limit: {error_msg}")
log.info("Falling back to MinerU for recognition")
recognition_method = "FormulaOnly (MinerU fallback)"
# Convert original image to bytes for MinerU
success, encoded_image = cv2.imencode(".png", image)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
image_bytes.seek(0)
ocr_result = mineru_service.recognize(image_bytes)
else:
# Re-raise other errors
raise
ocr_time = time.time() - ocr_start
total_time = time.time() - preprocess_start
log.info(f"OCR processing completed - Method: {recognition_method}, " f"Layout time: {layout_time:.3f}s, OCR time: {ocr_time:.3f}s, " f"Total time: {total_time:.3f}s")
log.info(f"OCR completed in {time.time() - start:.3f}s")
except RuntimeError as e:
log.error(f"OCR processing failed: {str(e)}", exc_info=True)

View File

@@ -3,9 +3,8 @@
from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
import torch
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
@@ -48,21 +47,25 @@ class Settings(BaseSettings):
is_padding: bool = True
padding_ratio: float = 0.1
max_tokens: int = 4096
# Model Paths
pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
pp_doclayout_model_dir: str | None = (
"/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
)
# Image Processing
max_image_size_mb: int = 10
image_padding_ratio: float = 0.1 # 10% on each side = 20% total expansion
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # cuda:0 or cpu
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Server Settings
host: str = "0.0.0.0"
port: int = 8053
# Logging Settings
log_dir: Optional[str] = None # Defaults to /app/logs in container or ./logs locally
log_dir: str | None = None # Defaults to /app/logs in container or ./logs locally
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
@property

View File

@@ -2,7 +2,7 @@
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService
from app.services.ocr_service import GLMOCREndToEndService
from app.services.converter import Converter
from app.core.config import get_settings
@@ -31,40 +31,17 @@ def get_image_processor() -> ImageProcessor:
return ImageProcessor()
def get_ocr_service() -> OCRService:
"""Get an OCR service instance."""
return OCRService(
vl_server_url=get_settings().paddleocr_vl_url,
layout_detector=get_layout_detector(),
image_processor=get_image_processor(),
converter=get_converter(),
)
def get_converter() -> Converter:
"""Get a DOCX converter instance."""
return Converter()
def get_mineru_ocr_service() -> MineruOCRService:
"""Get a MinerOCR service instance."""
def get_glmocr_endtoend_service() -> GLMOCREndToEndService:
"""Get end-to-end GLM-OCR service (layout detection + per-region OCR)."""
settings = get_settings()
api_url = getattr(settings, "miner_ocr_api_url", "http://127.0.0.1:8000/file_parse")
glm_ocr_url = getattr(settings, "glm_ocr_url", "http://localhost:8002/v1")
return MineruOCRService(
api_url=api_url,
converter=get_converter(),
image_processor=get_image_processor(),
glm_ocr_url=glm_ocr_url,
)
def get_glmocr_service() -> GLMOCRService:
"""Get a GLM OCR service instance."""
settings = get_settings()
glm_ocr_url = getattr(settings, "glm_ocr_url", "http://127.0.0.1:8002/v1")
return GLMOCRService(
vl_server_url=glm_ocr_url,
return GLMOCREndToEndService(
vl_server_url=settings.glm_ocr_url,
image_processor=get_image_processor(),
converter=get_converter(),
layout_detector=get_layout_detector(),
)

View File

@@ -7,6 +7,7 @@ class LayoutRegion(BaseModel):
"""A detected layout region in the document."""
type: str = Field(..., description="Region type: text, formula, table, figure")
native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)")
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
confidence: float = Field(..., description="Detection confidence score")
score: float = Field(..., description="Detection score")

View File

@@ -0,0 +1,412 @@
"""GLM-OCR postprocessing logic adapted for this project.
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
glm-ocr/glmocr/utils/result_postprocess_utils.py.
Covers:
- Repeated-content / hallucination detection
- Per-region content cleaning and formatting (titles, bullets, formulas)
- formula_number merging (→ \\tag{})
- Hyphenated text-block merging (via wordfreq)
- Missing bullet-point detection
"""
from __future__ import annotations
import re
import json
from collections import Counter
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
try:
from wordfreq import zipf_frequency
_WORDFREQ_AVAILABLE = True
except ImportError:
_WORDFREQ_AVAILABLE = False
# ---------------------------------------------------------------------------
# result_postprocess_utils (ported)
# ---------------------------------------------------------------------------
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
    """Detect a consecutively-repeated pattern and collapse it to one copy.

    Returns the prefix of *s* up to the repeat followed by a single copy of
    the repeated unit, or None when no qualifying repeat exists.
    """
    total = len(s)
    max_unit_len = total // min_repeats
    # Too short to hold min_repeats copies of a min_unit_len-sized unit.
    if total < min_unit_len * min_repeats or max_unit_len < min_unit_len:
        return None
    # Lazily match the shortest unit that still repeats min_repeats times.
    repeat_re = re.compile(
        rf"(.{{{min_unit_len},{max_unit_len}}}?)\1{{{min_repeats - 1},}}",
        re.DOTALL,
    )
    hit = repeat_re.search(s)
    if hit is None:
        return None
    return s[: hit.start()] + hit.group(1)
def clean_repeated_content(
    content: str,
    min_len: int = 10,
    min_repeats: int = 10,
    line_threshold: int = 10,
) -> str:
    """Remove hallucination-style repeated content (consecutive or line-level)."""
    stripped = content.strip()
    if not stripped:
        return content
    # 1. Consecutive character-level repeat (multi-line aware).
    if len(stripped) > min_len * min_repeats:
        collapsed = find_consecutive_repeat(stripped, min_unit_len=min_len, min_repeats=min_repeats)
        if collapsed is not None:
            return collapsed
    # 2. Line-level repeat: a single line dominating the block.
    trimmed = [ln.strip() for ln in content.split("\n") if ln.strip()]
    total = len(trimmed)
    if total >= line_threshold and trimmed:
        dominant, occurrences = Counter(trimmed).most_common(1)[0]
        if occurrences >= line_threshold and (occurrences / total) >= 0.8:
            # Only the first occurrence of the dominant line is examined.
            for pos, ln in enumerate(trimmed):
                if ln != dominant:
                    continue
                run = sum(1 for k in range(pos, min(pos + 3, total)) if trimmed[k] == dominant)
                if run >= 3:
                    # Truncate the original text right after that first line.
                    raw_lines = content.split("\n")
                    seen = 0
                    for raw_idx, raw in enumerate(raw_lines):
                        if raw.strip():
                            seen += 1
                            if seen == pos + 1:
                                return "\n".join(raw_lines[: raw_idx + 1])
                break
    return content
def clean_formula_number(number_content: str) -> str:
    """Strip enclosing parentheses from a formula number, e.g. '(1)' -> '1'.

    Handles both ASCII and full-width (CJK) parentheses; any other string is
    returned whitespace-stripped but otherwise unchanged.
    """
    s = number_content.strip()
    # ASCII parentheses: (1) -> 1
    if s.startswith("(") and s.endswith(")"):
        return s[1:-1]
    # Full-width parentheses, common in CJK documents: （1） -> 1
    # (The original check compared against empty strings — a mojibake artefact —
    # which made startswith("") always true and stripped a character from
    # every other input.)
    if s.startswith("（") and s.endswith("）"):
        return s[1:-1]
    return s
# ---------------------------------------------------------------------------
# GLMResultFormatter
# ---------------------------------------------------------------------------
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
# Keys are raw PP-DocLayout native labels; values are the canonical categories
# (text / table / formula / image) that GLMResultFormatter branches on.
_LABEL_TO_CATEGORY: Dict[str, str] = {
    # text-like regions: OCR'd and emitted as plain Markdown text
    "abstract": "text",
    "algorithm": "text",
    "content": "text",
    "doc_title": "text",
    "figure_title": "text",
    "paragraph_title": "text",
    "reference_content": "text",
    "text": "text",
    "vertical_text": "text",
    "vision_footnote": "text",
    "seal": "text",
    "formula_number": "text",
    # table regions
    "table": "table",
    # formula regions: wrapped in $$ ... $$ during formatting
    "display_formula": "formula",
    "inline_formula": "formula",
    # image regions (OCR skipped; rendered as a bbox placeholder)
    "chart": "image",
    "image": "image",
}
class GLMResultFormatter:
    """Port of GLM-OCR's ResultFormatter for use in our pipeline.

    Accepts a list of region dicts (each with label, native_label, content,
    bbox_2d) and returns a final Markdown string.
    """

    # ------------------------------------------------------------------ #
    # Public entry-point
    # ------------------------------------------------------------------ #
    def process(self, regions: List[Dict[str, Any]]) -> str:
        """Run the full postprocessing pipeline and return Markdown.

        Args:
            regions: List of dicts with keys:
                - index (int) reading order from layout detection
                - label (str) mapped category: text/formula/table/figure
                - native_label (str) raw PP-DocLayout label (e.g. doc_title)
                - content (str) raw OCR output from vLLM
                - bbox_2d (list) [x1, y1, x2, y2] in 0-1000 normalised coords

        Returns:
            Markdown string.
        """
        # Sort by reading order
        items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
        # Per-region cleaning + formatting; regions left empty are dropped.
        processed: List[Dict] = []
        for item in items:
            item["native_label"] = item.get("native_label", item.get("label", "text"))
            item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
            item["content"] = self._format_content(
                item.get("content") or "",
                item["label"],
                item["native_label"],
            )
            # NOTE(review): image regions carry no OCR content, so this filter
            # also drops them and the bbox-placeholder branch below rarely
            # fires — confirm that is intended.
            if not (item.get("content") or "").strip():
                continue
            processed.append(item)
        # Re-index after dropping empties
        for i, item in enumerate(processed):
            item["index"] = i
        # Structural merges
        processed = self._merge_formula_numbers(processed)
        processed = self._merge_text_blocks(processed)
        processed = self._format_bullet_points(processed)
        # Assemble Markdown
        parts: List[str] = []
        for item in processed:
            content = item.get("content") or ""
            if item["label"] == "image":
                parts.append(f"![](bbox={item.get('bbox_2d', [])})")
            elif content.strip():
                parts.append(content)
        return "\n\n".join(parts)

    # ------------------------------------------------------------------ #
    # Label mapping
    # ------------------------------------------------------------------ #
    def _map_label(self, label: str, native_label: str) -> str:
        """Map to a canonical category, preferring the raw native label."""
        return _LABEL_TO_CATEGORY.get(native_label, _LABEL_TO_CATEGORY.get(label, "text"))

    # ------------------------------------------------------------------ #
    # Content cleaning
    # ------------------------------------------------------------------ #
    def _clean_content(self, content: str) -> str:
        """Remove artefacts: leading/trailing \\t, repeated punctuation, long repeats."""
        if content is None:
            return ""
        content = re.sub(r"^(\\t)+", "", content).lstrip()
        content = re.sub(r"(\\t)+$", "", content).rstrip()
        # Collapse filler-punctuation runs to at most three characters.
        content = re.sub(r"(\.)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(·)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(_)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(\\_)\1{2,}", r"\1\1\1", content)
        # Long outputs are screened for hallucinated repeats.
        if len(content) >= 2048:
            content = clean_repeated_content(content)
        return content.strip()

    # ------------------------------------------------------------------ #
    # Per-region content formatting
    # ------------------------------------------------------------------ #
    def _format_content(self, content: Any, label: str, native_label: str) -> str:
        """Clean and format a single region's content.

        Fixes vs. the previous revision: the full-width characters （ ） ）
        and the bullet • had been lost to mojibake, which (a) made the first
        enumeration regex syntactically invalid (unterminated group →
        re.error at runtime) and (b) turned the bullet check into
        startswith(""), which is true for every string.
        """
        if content is None:
            return ""
        content = self._clean_content(str(content))
        # Heading formatting
        if native_label == "doc_title":
            content = re.sub(r"^#+\s*", "", content)
            content = "# " + content
        elif native_label == "paragraph_title":
            if content.startswith("- ") or content.startswith("* "):
                content = content[2:].lstrip()
            content = re.sub(r"^#+\s*", "", content)
            content = "## " + content.lstrip()
        # Formula wrapping: normalise any delimiter style to $$ ... $$
        if label == "formula":
            content = content.strip()
            for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
                if content.startswith(s) and content.endswith(e):
                    content = content[len(s) : -len(e)].strip()
                    break
            content = "$$\n" + content + "\n$$"
        # Text formatting
        if label == "text":
            # Normalise bullet markers (·, •, *) to "- ".
            if content.startswith("·") or content.startswith("•") or content.startswith("* "):
                content = "- " + content[1:].lstrip()
            # Enumerations like "(1) ..." / "（a）..." → ASCII "(x) ..."
            match = re.match(r"^(\(|（)(\d+|[A-Za-z])(\)|）)(.*)$", content)
            if match:
                _, symbol, _, rest = match.groups()
                content = f"({symbol}) {rest.lstrip()}"
            # Enumerations like "1. ..." / "a) ..." / "b）..." → ASCII separator
            match = re.match(r"^(\d+|[A-Za-z])(\.|\)|）)(.*)$", content)
            if match:
                symbol, sep, rest = match.groups()
                sep = ")" if sep == "）" else sep
                content = f"{symbol}{sep} {rest.lstrip()}"
        # Single newline → double newline (Markdown paragraph break)
        content = re.sub(r"(?<!\n)\n(?!\n)", "\n\n", content)
        return content

    # ------------------------------------------------------------------ #
    # Structural merges
    # ------------------------------------------------------------------ #
    def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
        """Merge a formula_number region into the adjacent formula as \\tag{}."""
        if not items:
            return items
        merged: List[Dict] = []
        skip: set = set()
        for i, block in enumerate(items):
            if i in skip:
                continue
            native = block.get("native_label", "")
            # Case 1: formula_number then formula
            if native == "formula_number":
                if i + 1 < len(items) and items[i + 1].get("label") == "formula":
                    num_clean = clean_formula_number(block.get("content", "").strip())
                    formula_content = items[i + 1].get("content", "")
                    merged_block = deepcopy(items[i + 1])
                    # Splice the tag inside the closing $$ delimiter.
                    if formula_content.endswith("\n$$"):
                        merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
                    merged.append(merged_block)
                    skip.add(i + 1)
                continue  # always skip the formula_number block itself
            # Case 2: formula then formula_number
            if block.get("label") == "formula":
                if i + 1 < len(items) and items[i + 1].get("native_label") == "formula_number":
                    num_clean = clean_formula_number(items[i + 1].get("content", "").strip())
                    formula_content = block.get("content", "")
                    merged_block = deepcopy(block)
                    if formula_content.endswith("\n$$"):
                        merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
                    merged.append(merged_block)
                    skip.add(i + 1)
                    continue
            merged.append(block)
        for i, block in enumerate(merged):
            block["index"] = i
        return merged

    def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
        """Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
        if not items or not _WORDFREQ_AVAILABLE:
            return items
        merged: List[Dict] = []
        skip: set = set()
        for i, block in enumerate(items):
            if i in skip:
                continue
            if block.get("label") != "text":
                merged.append(block)
                continue
            content = block.get("content", "")
            # Only blocks ending in a hyphen are merge candidates.
            if not isinstance(content, str) or not content.rstrip().endswith("-"):
                merged.append(block)
                continue
            content_stripped = content.rstrip()
            did_merge = False
            for j in range(i + 1, len(items)):
                if items[j].get("label") != "text":
                    continue
                next_content = items[j].get("content", "")
                if not isinstance(next_content, str):
                    continue
                next_stripped = next_content.lstrip()
                # A continuation starts lowercase; validate the joined word.
                if next_stripped and next_stripped[0].islower():
                    words_before = content_stripped[:-1].split()
                    next_words = next_stripped.split()
                    if words_before and next_words:
                        merged_word = words_before[-1] + next_words[0]
                        # zipf >= 2.5 ≈ reasonably common English word.
                        if zipf_frequency(merged_word.lower(), "en") >= 2.5:
                            merged_block = deepcopy(block)
                            merged_block["content"] = content_stripped[:-1] + next_content.lstrip()
                            merged.append(merged_block)
                            skip.add(j)
                            did_merge = True
                break
            if not did_merge:
                merged.append(block)
        for i, block in enumerate(merged):
            block["index"] = i
        return merged

    def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
        """Add a missing bullet prefix when a text block sits between two bullets."""
        if len(items) < 3:
            return items
        for i in range(1, len(items) - 1):
            cur = items[i]
            prev = items[i - 1]
            nxt = items[i + 1]
            if cur.get("native_label") != "text":
                continue
            if prev.get("native_label") != "text" or nxt.get("native_label") != "text":
                continue
            cur_content = cur.get("content", "")
            if cur_content.startswith("- "):
                continue
            prev_content = prev.get("content", "")
            nxt_content = nxt.get("content", "")
            # Both neighbours must already be bullet items.
            if not (prev_content.startswith("- ") and nxt_content.startswith("- ")):
                continue
            cur_bbox = cur.get("bbox_2d", [])
            prev_bbox = prev.get("bbox_2d", [])
            nxt_bbox = nxt.get("bbox_2d", [])
            if not (cur_bbox and prev_bbox and nxt_bbox):
                continue
            # Require left-edge alignment with both neighbours.
            if (
                abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold
                and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold
            ):
                cur["content"] = "- " + cur_content
        return items

View File

@@ -1,9 +1,10 @@
"""PP-DocLayoutV2 wrapper for document layout detection."""
"""PP-DocLayoutV3 wrapper for document layout detection."""
import numpy as np
from app.schemas.image import LayoutInfo, LayoutRegion
from app.core.config import get_settings
from app.services.layout_postprocess import apply_layout_postprocess
from paddleocr import LayoutDetection
from typing import Optional
@@ -116,6 +117,17 @@ class LayoutDetector:
else:
boxes = []
# Apply GLM-OCR layout post-processing (NMS, containment, unclip, clamp)
if boxes:
h, w = image.shape[:2]
boxes = apply_layout_postprocess(
boxes,
img_size=(w, h),
layout_nms=True,
layout_unclip_ratio=None,
layout_merge_bboxes_mode="large",
)
for box in boxes:
cls_id = box.get("cls_id")
label = box.get("label") or self.CLS_ID_TO_LABEL.get(cls_id, "other")
@@ -128,6 +140,7 @@ class LayoutDetector:
regions.append(
LayoutRegion(
type=region_type,
native_label=label,
bbox=coordinate,
confidence=score,
score=score,

View File

@@ -0,0 +1,343 @@
"""Layout post-processing utilities ported from GLM-OCR.
Source: glm-ocr/glmocr/utils/layout_postprocess_utils.py
Algorithms applied after PaddleOCR LayoutDetection.predict():
1. NMS with dual IoU thresholds (same-class vs cross-class)
2. Large-image-region filtering (remove image boxes that fill most of the page)
3. Containment analysis (merge_bboxes_mode: keep large parent, remove contained child)
4. Unclip ratio (optional bbox expansion)
5. Invalid bbox skipping
These steps run on top of PaddleOCR's built-in detection to replicate
the quality of the GLM-OCR SDK's layout pipeline.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
# ---------------------------------------------------------------------------
# Primitive geometry helpers
# ---------------------------------------------------------------------------
def iou(box1: List[float], box2: List[float]) -> float:
    """Compute IoU of two bounding boxes [x1, y1, x2, y2].

    Uses the inclusive-pixel (+1) convention for widths and heights.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    # Intersection rectangle, clamped to zero when the boxes are disjoint.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1) + 1)
    inter_h = max(0, min(ay2, by2) - max(ay1, by1) + 1)
    inter = inter_w * inter_h
    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    return inter / float(area_a + area_b - inter)
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
    """Return True if box1 is contained within box2 (overlap ratio >= threshold).

    box format: [cls_id, score, x1, y1, x2, y2]; the ratio is the fraction of
    box1's area covered by the intersection.
    """
    _cls1, _s1, ax1, ay1, ax2, ay2 = box1
    _cls2, _s2, bx1, by1, bx2, by2 = box2
    area = (ax2 - ax1) * (ay2 - ay1)
    # A degenerate (zero/negative-area) box is never considered contained.
    if area <= 0:
        return False
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    return (overlap_w * overlap_h) / area >= overlap_threshold
# ---------------------------------------------------------------------------
# NMS
# ---------------------------------------------------------------------------
def nms(
    boxes: np.ndarray,
    iou_same: float = 0.6,
    iou_diff: float = 0.98,
) -> List[int]:
    """NMS with separate IoU thresholds for same-class and cross-class overlaps.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        iou_same: Suppression threshold for boxes of the same class.
        iou_diff: Suppression threshold for boxes of different classes.

    Returns:
        List of kept row indices, highest-score first.
    """
    # Candidates ordered by descending confidence.
    order = np.argsort(boxes[:, 1])[::-1].tolist()
    keep: List[int] = []
    while order:
        best, rest = order[0], order[1:]
        keep.append(best)
        best_cls = int(boxes[best, 0])
        best_xyxy = boxes[best, 2:6].tolist()
        survivors = []
        for idx in rest:
            # Same-class pairs suppress aggressively; cross-class barely.
            limit = iou_same if int(boxes[idx, 0]) == best_cls else iou_diff
            if iou(best_xyxy, boxes[idx, 2:6].tolist()) < limit:
                survivors.append(idx)
        order = survivors
    return keep
# ---------------------------------------------------------------------------
# Containment analysis
# ---------------------------------------------------------------------------
# Labels whose regions should never be removed even when contained in another box
_PRESERVE_LABELS = {"image", "seal", "chart"}
def check_containment(
    boxes: np.ndarray,
    preserve_cls_ids: Optional[set] = None,
    category_index: Optional[int] = None,
    mode: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute containment flags for each box.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        preserve_cls_ids: Class IDs that must never be marked as contained.
        category_index: If set, apply mode only relative to this class.
        mode: 'large' or 'small' (only used with category_index).

    Returns:
        (contains_other, contained_by_other): 0/1 int arrays of length N.
    """
    n = len(boxes)
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)

    def _mark(inner: int, outer: int) -> None:
        # Flag the pair when the inner box lies inside the outer box.
        if is_contained(boxes[inner].tolist(), boxes[outer].tolist()):
            contained_by_other[inner] = 1
            contains_other[outer] = 1

    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            # Preserved classes are exempt from containment marking.
            if preserve_cls_ids and int(boxes[i, 0]) in preserve_cls_ids:
                continue
            if category_index is not None and mode is not None:
                # Restricted modes: only pairs involving category_index count.
                if mode == "large" and int(boxes[j, 0]) == category_index:
                    _mark(i, j)
                elif mode == "small" and int(boxes[i, 0]) == category_index:
                    _mark(i, j)
            else:
                _mark(i, j)
    return contains_other, contained_by_other
# ---------------------------------------------------------------------------
# Box expansion (unclip)
# ---------------------------------------------------------------------------
def unclip_boxes(
    boxes: np.ndarray,
    unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
) -> np.ndarray:
    """Expand bounding boxes about their centres by the given ratio.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        unclip_ratio: Scalar, (w_ratio, h_ratio) tuple, or dict mapping
            cls_id to a (w_ratio, h_ratio) pair.

    Returns:
        Expanded boxes array; the input is returned unchanged when the
        ratio is None.
    """
    if unclip_ratio is None:
        return boxes
    # Per-class ratios: only rows whose cls_id appears in the dict expand.
    if isinstance(unclip_ratio, dict):
        out_rows = []
        for row in boxes:
            cls_id = int(row[0])
            if cls_id not in unclip_ratio:
                out_rows.append(list(row))
                continue
            w_ratio, h_ratio = unclip_ratio[cls_id]
            x1, y1, x2, y2 = row[2], row[3], row[4], row[5]
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            half_w = (x2 - x1) * w_ratio / 2
            half_h = (y2 - y1) * h_ratio / 2
            new_row = list(row)
            new_row[2], new_row[3] = cx - half_w, cy - half_h
            new_row[4], new_row[5] = cx + half_w, cy + half_h
            out_rows.append(new_row)
        return np.array(out_rows)
    # Uniform ratio: vectorised expansion of every row.
    if isinstance(unclip_ratio, (int, float)):
        unclip_ratio = (float(unclip_ratio), float(unclip_ratio))
    w_ratio, h_ratio = unclip_ratio[0], unclip_ratio[1]
    widths = boxes[:, 4] - boxes[:, 2]
    heights = boxes[:, 5] - boxes[:, 3]
    cx = boxes[:, 2] + widths / 2
    cy = boxes[:, 3] + heights / 2
    half_w = widths * w_ratio / 2
    half_h = heights * h_ratio / 2
    out = boxes.copy().astype(float)
    out[:, 2] = cx - half_w
    out[:, 3] = cy - half_h
    out[:, 4] = cx + half_w
    out[:, 5] = cy + half_h
    return out
# ---------------------------------------------------------------------------
# Main entry-point
# ---------------------------------------------------------------------------
def apply_layout_postprocess(
    boxes: List[Dict],
    img_size: Tuple[int, int],
    layout_nms: bool = True,
    layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
    layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
) -> List[Dict]:
    """Apply GLM-OCR layout post-processing to PaddleOCR detection results.

    Pipeline: (1) dual-threshold NMS, (2) drop near-page-sized "image" boxes,
    (3) containment filtering per merge mode, (4) optional bbox expansion,
    (5) clamp to image bounds and drop degenerate boxes.

    Args:
        boxes: PaddleOCR output — list of dicts with keys:
            cls_id, label, score, coordinate ([x1, y1, x2, y2]).
        img_size: (width, height) of the image.
        layout_nms: Apply dual-threshold NMS.
        layout_unclip_ratio: Optional bbox expansion ratio.
        layout_merge_bboxes_mode: Containment mode — 'large' (default), 'small',
            'union', or per-class dict.

    Returns:
        Filtered and ordered list of box dicts in the same PaddleOCR format.
    """
    if not boxes:
        return boxes
    img_width, img_height = img_size

    # --- Build working array [cls_id, score, x1, y1, x2, y2] -------------- #
    arr_rows = []
    for b in boxes:
        cls_id = b.get("cls_id", 0)
        score = b.get("score", 0.0)
        x1, y1, x2, y2 = b.get("coordinate", [0, 0, 0, 0])
        arr_rows.append([cls_id, score, x1, y1, x2, y2])
    boxes_array = np.array(arr_rows, dtype=float)
    # Labels are tracked in a parallel list kept index-aligned with boxes_array
    # through every filtering step below.
    all_labels: List[str] = [b.get("label", "") for b in boxes]

    # 1. NMS ---------------------------------------------------------------- #
    if layout_nms and len(boxes_array) > 1:
        kept = nms(boxes_array, iou_same=0.6, iou_diff=0.98)
        boxes_array = boxes_array[kept]
        all_labels = [all_labels[k] for k in kept]

    # 2. Filter large image regions ---------------------------------------- #
    # An "image" box covering nearly the whole page is almost always a false
    # detection of the page itself.  The threshold is stricter (0.93) for
    # portrait pages, looser (0.82) for landscape.
    # NOTE: an unused `image_cls_ids` set comprehension was removed here — it
    # was computed but never read.
    if len(boxes_array) > 1:
        img_area = img_width * img_height
        area_thres = 0.82 if img_width > img_height else 0.93
        keep_mask = np.ones(len(boxes_array), dtype=bool)
        for i, lbl in enumerate(all_labels):
            if lbl == "image":
                x1, y1, x2, y2 = boxes_array[i, 2:6]
                # Clip to the page before measuring covered area.
                x1 = max(0.0, x1); y1 = max(0.0, y1)
                x2 = min(float(img_width), x2); y2 = min(float(img_height), y2)
                if (x2 - x1) * (y2 - y1) > area_thres * img_area:
                    keep_mask[i] = False
        boxes_array = boxes_array[keep_mask]
        all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    # 3. Containment analysis (merge_bboxes_mode) -------------------------- #
    if layout_merge_bboxes_mode and len(boxes_array) > 1:
        # Classes whose labels must never be merged away (e.g. images/seals).
        preserve_cls_ids = {
            int(boxes_array[i, 0])
            for i, lbl in enumerate(all_labels)
            if lbl in _PRESERVE_LABELS
        }
        if isinstance(layout_merge_bboxes_mode, str):
            mode = layout_merge_bboxes_mode
            if mode in ("large", "small"):
                contains_other, contained_by_other = check_containment(
                    boxes_array, preserve_cls_ids
                )
                if mode == "large":
                    # Keep outer boxes; drop boxes nested inside another.
                    keep_mask = contained_by_other == 0
                else:
                    # Keep inner boxes; an outer box survives only when it is
                    # itself contained by something else.
                    keep_mask = (contains_other == 0) | (contained_by_other == 1)
                boxes_array = boxes_array[keep_mask]
                all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]
        elif isinstance(layout_merge_bboxes_mode, dict):
            # Per-class modes: intersect the keep masks of all class rules.
            keep_mask = np.ones(len(boxes_array), dtype=bool)
            for category_index, mode in layout_merge_bboxes_mode.items():
                if mode in ("large", "small"):
                    contains_other, contained_by_other = check_containment(
                        boxes_array, preserve_cls_ids, int(category_index), mode
                    )
                    if mode == "large":
                        keep_mask &= contained_by_other == 0
                    else:
                        keep_mask &= (contains_other == 0) | (contained_by_other == 1)
            boxes_array = boxes_array[keep_mask]
            all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    if len(boxes_array) == 0:
        return []

    # 4. Unclip (bbox expansion) ------------------------------------------- #
    if layout_unclip_ratio is not None:
        boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)

    # 5. Clamp to image boundaries + skip invalid -------------------------- #
    result: List[Dict] = []
    for i, row in enumerate(boxes_array):
        cls_id = int(row[0])
        score = float(row[1])
        x1 = max(0.0, min(float(row[2]), img_width))
        y1 = max(0.0, min(float(row[3]), img_height))
        x2 = max(0.0, min(float(row[4]), img_width))
        y2 = max(0.0, min(float(row[5]), img_height))
        if x1 >= x2 or y1 >= y2:
            # Degenerate after clamping — drop it.
            continue
        result.append({
            "cls_id": cls_id,
            "label": all_labels[i],
            "score": score,
            "coordinate": [int(x1), int(y1), int(x2), int(y2)],
        })
    return result

View File

@@ -1,19 +1,23 @@
"""PaddleOCR-VL client service for text and formula recognition."""
import re
import numpy as np
import cv2
import requests
from io import BytesIO
import base64
from app.core.config import get_settings
from paddleocr import PaddleOCRVL
from typing import Optional
from app.services.layout_detector import LayoutDetector
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
import re
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
import cv2
import numpy as np
import requests
from openai import OpenAI
from paddleocr import PaddleOCRVL
from PIL import Image as PILImage
from app.core.config import get_settings
from app.services.converter import Converter
from app.services.glm_postprocess import GLMResultFormatter
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
settings = get_settings()
@@ -144,7 +148,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
# Strategy: remove spaces before \ and between non-command chars,
# but preserve the space after \command when followed by a non-\ char
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
cleaned = re.sub(
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
) # remove space after non-letter non-\
return f"{operator}{{{cleaned}}}"
# Match _{ ... } or ^{ ... }
@@ -383,8 +389,8 @@ class OCRServiceBase(ABC):
class OCRService(OCRServiceBase):
"""Service for OCR using PaddleOCR-VL."""
_pipeline: Optional[PaddleOCRVL] = None
_layout_detector: Optional[LayoutDetector] = None
_pipeline: PaddleOCRVL | None = None
_layout_detector: LayoutDetector | None = None
def __init__(
self,
@@ -549,7 +555,15 @@ class GLMOCRService(OCRServiceBase):
# Call OpenAI-compatible API with formula recognition prompt
prompt = "Formula Recognition:"
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": prompt},
],
}
]
# Don't catch exceptions here - let them propagate for fallback handling
response = self.openai_client.chat.completions.create(
@@ -596,10 +610,10 @@ class MineruOCRService(OCRServiceBase):
def __init__(
self,
api_url: str = "http://127.0.0.1:8000/file_parse",
image_processor: Optional[ImageProcessor] = None,
converter: Optional[Converter] = None,
image_processor: ImageProcessor | None = None,
converter: Converter | None = None,
glm_ocr_url: str = "http://localhost:8002/v1",
layout_detector: Optional[LayoutDetector] = None,
layout_detector: LayoutDetector | None = None,
):
"""Initialize Local API service.
@@ -614,7 +628,9 @@ class MineruOCRService(OCRServiceBase):
self.glm_ocr_url = glm_ocr_url
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
def _recognize_formula_with_paddleocr_vl(
self, image: np.ndarray, prompt: str = "Formula Recognition:"
) -> str:
"""Recognize formula using PaddleOCR-VL API.
Args:
@@ -634,7 +650,15 @@ class MineruOCRService(OCRServiceBase):
image_url = f"data:image/png;base64,{image_base64}"
# Call OpenAI-compatible API
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": prompt},
],
}
]
response = self.openai_client.chat.completions.create(
model="glm-ocr",
@@ -647,7 +671,9 @@ class MineruOCRService(OCRServiceBase):
except Exception as e:
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
def _extract_and_recognize_formulas(
self, markdown_content: str, original_image: np.ndarray
) -> str:
"""Extract image references from markdown and recognize formulas.
Args:
@@ -712,7 +738,13 @@ class MineruOCRService(OCRServiceBase):
}
# Make API request
response = requests.post(self.api_url, files=files, data=data, headers={"accept": "application/json"}, timeout=30)
response = requests.post(
self.api_url,
files=files,
data=data,
headers={"accept": "application/json"},
timeout=30,
)
response.raise_for_status()
result = response.json()
@@ -723,7 +755,9 @@ class MineruOCRService(OCRServiceBase):
markdown_content = result["results"]["image"].get("md_content", "")
if "![](images/" in markdown_content:
markdown_content = self._extract_and_recognize_formulas(markdown_content, original_image)
markdown_content = self._extract_and_recognize_formulas(
markdown_content, original_image
)
# Apply postprocessing to fix OCR errors
markdown_content = _postprocess_markdown(markdown_content)
@@ -751,6 +785,167 @@ class MineruOCRService(OCRServiceBase):
raise RuntimeError(f"Recognition failed: {e}") from e
# Task-specific prompts (from GLM-OCR SDK config.yaml).
# Keys match LayoutRegion.type values produced by the layout detector;
# any other region type falls back to _DEFAULT_PROMPT.
_TASK_PROMPTS: dict[str, str] = {
    "text": "Text Recognition:",
    "formula": "Formula Recognition:",
    "table": "Table Recognition:",
}

# Full-page fallback prompt, used when layout detection yields no usable
# regions and the whole padded image is sent to the model in one call.
_DEFAULT_PROMPT = (
    "Recognize the text in the image and output in Markdown format. "
    "Preserve the original layout (headings/paragraphs/tables/formulas). "
    "Do not fabricate content that does not exist in the image."
)
class GLMOCREndToEndService(OCRServiceBase):
    """End-to-end OCR using GLM-OCR pipeline: layout detection → per-region OCR.

    Pipeline:
        1. Add padding (ImageProcessor)
        2. Detect layout regions (LayoutDetector → PP-DocLayoutV3)
        3. Crop each region and call vLLM with a task-specific prompt (parallel)
        4. GLMResultFormatter: clean, format titles/bullets/formulas, merge tags
        5. _postprocess_markdown: LaTeX math error correction
        6. Converter: markdown → latex/mathml/mml

    This replaces both GLMOCRService (formula-only) and MineruOCRService (mixed).
    """

    def __init__(
        self,
        vl_server_url: str,
        image_processor: ImageProcessor,
        converter: Converter,
        layout_detector: LayoutDetector,
        max_workers: int = 8,
    ):
        """Wire up the pipeline collaborators.

        Args:
            vl_server_url: Base URL of the vLLM OpenAI-compatible endpoint;
                falls back to ``settings.glm_ocr_url`` when empty/falsy.
            image_processor: Adds padding to the input image.
            converter: Converts the final markdown to latex/mathml/mml.
            layout_detector: Produces layout regions for per-region OCR.
            max_workers: Upper bound on concurrent vLLM calls.
        """
        self.vl_server_url = vl_server_url or settings.glm_ocr_url
        self.image_processor = image_processor
        self.converter = converter
        self.layout_detector = layout_detector
        self.max_workers = max_workers
        # Long (1 h) timeout: a single request can be slow on a busy server.
        self.openai_client = OpenAI(api_key="EMPTY", base_url=self.vl_server_url, timeout=3600)
        self._formatter = GLMResultFormatter()

    def _encode_region(self, image: np.ndarray) -> str:
        """Convert BGR numpy array to base64 JPEG string."""
        # OpenCV arrays are BGR; PIL expects RGB, so swap channels first.
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_img = PILImage.fromarray(rgb)
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _call_vllm(self, image: np.ndarray, prompt: str) -> str:
        """Send image + prompt to vLLM and return the stripped content string.

        Exceptions from the OpenAI client propagate to the caller; recognize()
        decides per call site whether to swallow them (per-region path) or not.
        """
        img_b64 = self._encode_region(image)
        data_url = f"data:image/jpeg;base64,{img_b64}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Near-zero temperature: OCR output should be effectively deterministic.
        response = self.openai_client.chat.completions.create(
            model="glm-ocr",
            messages=messages,
            temperature=0.01,
            max_tokens=settings.max_tokens,
        )
        return response.choices[0].message.content.strip()

    def _normalize_bbox(self, bbox: list[float], img_w: int, img_h: int) -> list[int]:
        """Convert pixel bbox [x1,y1,x2,y2] to 0-1000 normalised coords."""
        x1, y1, x2, y2 = bbox
        return [
            int(x1 / img_w * 1000),
            int(y1 / img_h * 1000),
            int(x2 / img_w * 1000),
            int(y2 / img_h * 1000),
        ]

    def recognize(self, image: np.ndarray) -> dict:
        """Full pipeline: padding → layout → per-region OCR → postprocess → markdown.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
        """
        # 1. Padding
        padded = self.image_processor.add_padding(image)
        img_h, img_w = padded.shape[:2]

        # 2. Layout detection
        layout_info = self.layout_detector.detect(padded)

        # 3. OCR: per-region (parallel) or full-image fallback
        if not layout_info.regions:
            raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
            markdown_content = self._formatter._clean_content(raw_content)
        else:
            # Build task list for non-figure regions
            tasks = []
            for idx, region in enumerate(layout_info.regions):
                if region.type == "figure":
                    # Figures are not OCR'd; the formatter emits a placeholder.
                    continue
                x1, y1, x2, y2 = (int(c) for c in region.bbox)
                cropped = padded[y1:y2, x1:x2]
                if cropped.size == 0:
                    # Degenerate crop (empty slice) — nothing to recognize.
                    continue
                prompt = _TASK_PROMPTS.get(region.type, _DEFAULT_PROMPT)
                tasks.append((idx, region, cropped, prompt))
            if not tasks:
                raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
                markdown_content = self._formatter._clean_content(raw_content)
            else:
                # Parallel OCR calls
                raw_results: dict[int, str] = {}
                with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex:
                    future_map = {
                        ex.submit(self._call_vllm, cropped, prompt): idx
                        for idx, region, cropped, prompt in tasks
                    }
                    for future in as_completed(future_map):
                        idx = future_map[future]
                        try:
                            raw_results[idx] = future.result()
                        except Exception:
                            # Best-effort: a failed region yields empty content
                            # instead of failing the whole page.
                            raw_results[idx] = ""
                # Build structured region dicts for GLMResultFormatter
                # (tasks is in layout order, so region_dicts stays ordered).
                region_dicts = []
                for idx, region, _cropped, _prompt in tasks:
                    region_dicts.append(
                        {
                            "index": idx,
                            "label": region.type,
                            "native_label": region.native_label,
                            "content": raw_results.get(idx, ""),
                            "bbox_2d": self._normalize_bbox(region.bbox, img_w, img_h),
                        }
                    )
                # 4. GLM-OCR postprocessing: clean, format, merge, bullets
                markdown_content = self._formatter.process(region_dicts)

        # 5. LaTeX math error correction (our existing pipeline)
        markdown_content = _postprocess_markdown(markdown_content)

        # 6. Format conversion
        latex, mathml, mml = "", "", ""
        if markdown_content and self.converter:
            fmt = self.converter.convert_to_formats(markdown_content)
            latex, mathml, mml = fmt.latex, fmt.mathml, fmt.mml
        return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
if __name__ == "__main__":
    # Ad-hoc smoke run.  NOTE(review): `image` is loaded but never passed to
    # the service in the visible code — this snippet looks truncated/incomplete;
    # confirm the intended call (e.g. mineru_service.recognize(image)).
    mineru_service = MineruOCRService()
    image = cv2.imread("test/formula2.jpg")

View File

@@ -29,6 +29,7 @@ dependencies = [
"safetensors",
"lxml>=5.0.0",
"openai",
"wordfreq",
]
# [tool.uv.sources]

View File

@@ -0,0 +1,98 @@
import numpy as np
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.api.v1.endpoints.image import router
from app.core.dependencies import get_glmocr_endtoend_service, get_image_processor
class _FakeImageProcessor:
def preprocess(self, image_url=None, image_base64=None):
return np.zeros((8, 8, 3), dtype=np.uint8)
class _FakeOCRService:
def __init__(self, result=None, error=None):
self._result = result or {"markdown": "md", "latex": "tex", "mathml": "mml", "mml": "xml"}
self._error = error
def recognize(self, image):
if self._error:
raise self._error
return self._result
def _build_client(image_processor=None, ocr_service=None):
app = FastAPI()
app.include_router(router)
app.dependency_overrides[get_image_processor] = lambda: image_processor or _FakeImageProcessor()
app.dependency_overrides[get_glmocr_endtoend_service] = lambda: ocr_service or _FakeOCRService()
return TestClient(app)
def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
client = _build_client()
missing = client.post("/ocr", json={})
both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"})
assert missing.status_code == 422
assert both.status_code == 422
def test_image_endpoint_returns_503_for_runtime_error():
client = _build_client(ocr_service=_FakeOCRService(error=RuntimeError("backend unavailable")))
response = client.post("/ocr", json={"image_url": "https://example.com/a.png"})
assert response.status_code == 503
assert response.json()["detail"] == "backend unavailable"
def test_image_endpoint_returns_500_for_unexpected_error():
client = _build_client(ocr_service=_FakeOCRService(error=ValueError("boom")))
response = client.post("/ocr", json={"image_url": "https://example.com/a.png"})
assert response.status_code == 500
assert response.json()["detail"] == "Internal server error"
def test_image_endpoint_returns_ocr_payload():
client = _build_client()
response = client.post("/ocr", json={"image_base64": "ZmFrZQ=="})
assert response.status_code == 200
assert response.json() == {
"latex": "tex",
"markdown": "md",
"mathml": "mml",
"mml": "xml",
"layout_info": {"regions": [], "MixedRecognition": False},
"recognition_mode": "",
}
def test_image_endpoint_real_e2e_with_env_services():
from app.main import app
image_url = (
"https://static.texpixel.com/formula/012dab3e-fb31-4ecd-90fc-6957458ee309.png"
"?Expires=1773049821&OSSAccessKeyId=TMP.3KnrJUz7aXHoU9rLTAih4MAyPGd9zyGRHiqg9AyH6TY6NKtzqT2yr4qo7Vwf8fMRFCBrWXiCFrbBwC3vn7U6mspV2NeU1K"
"&Signature=oynhP0OLIgFI0Sv3z2CWeHPT2Ck%3D"
)
with TestClient(app) as client:
response = client.post(
"/doc_process/v1/image/ocr",
json={"image_url": image_url},
headers={"x-request-id": "test-e2e"},
)
assert response.status_code == 200, response.text
payload = response.json()
assert isinstance(payload["markdown"], str)
assert payload["markdown"].strip()
assert set(payload) >= {"markdown", "latex", "mathml", "mml"}

View File

@@ -0,0 +1,10 @@
import pytest
from app.core import dependencies
def test_get_glmocr_endtoend_service_raises_when_layout_detector_missing(monkeypatch):
monkeypatch.setattr(dependencies, "_layout_detector", None)
with pytest.raises(RuntimeError, match="Layout detector not initialized"):
dependencies.get_glmocr_endtoend_service()

View File

@@ -0,0 +1,31 @@
from app.schemas.image import ImageOCRRequest, LayoutRegion
def test_layout_region_native_label_defaults_to_empty_string():
region = LayoutRegion(
type="text",
bbox=[0, 0, 10, 10],
confidence=0.9,
score=0.9,
)
assert region.native_label == ""
def test_layout_region_exposes_native_label_when_provided():
region = LayoutRegion(
type="text",
native_label="doc_title",
bbox=[0, 0, 10, 10],
confidence=0.9,
score=0.9,
)
assert region.native_label == "doc_title"
def test_image_ocr_request_requires_exactly_one_input():
request = ImageOCRRequest(image_url="https://example.com/test.png")
assert request.image_url == "https://example.com/test.png"
assert request.image_base64 is None

View File

@@ -0,0 +1,199 @@
from app.services.glm_postprocess import (
GLMResultFormatter,
clean_formula_number,
clean_repeated_content,
find_consecutive_repeat,
)
def test_find_consecutive_repeat_truncates_when_threshold_met():
repeated = "abcdefghij" * 10 + "tail"
assert find_consecutive_repeat(repeated) == "abcdefghij"
def test_find_consecutive_repeat_returns_none_when_below_threshold():
assert find_consecutive_repeat("abcdefghij" * 9) is None
def test_clean_repeated_content_handles_consecutive_and_line_level_repeats():
assert clean_repeated_content("abcdefghij" * 10 + "tail") == "abcdefghij"
line_repeated = "\n".join(["same line"] * 10 + ["other"])
assert clean_repeated_content(line_repeated, line_threshold=10) == "same line\n"
assert clean_repeated_content("normal text") == "normal text"
def test_clean_formula_number_strips_wrapping_parentheses():
assert clean_formula_number("(1)") == "1"
assert clean_formula_number("2.1") == "2.1"
assert clean_formula_number("3") == "3"
def test_clean_content_removes_literal_tabs_and_long_repeat_noise():
formatter = GLMResultFormatter()
noisy = r"\t\t" + ("·" * 5) + ("abcdefghij" * 205) + r"\t"
cleaned = formatter._clean_content(noisy)
assert cleaned.startswith("···")
assert cleaned.endswith("abcdefghij")
assert r"\t" not in cleaned
def test_format_content_handles_titles_formula_text_and_newlines():
formatter = GLMResultFormatter()
assert formatter._format_content("Intro", "text", "doc_title") == "# Intro"
assert formatter._format_content("- Section", "text", "paragraph_title") == "## Section"
assert formatter._format_content(r"\[x+y\]", "formula", "display_formula") == "$$\nx+y\n$$"
assert formatter._format_content("· item\nnext", "text", "text") == "- item\n\nnext"
def test_merge_formula_numbers_merges_before_and_after_formula():
formatter = GLMResultFormatter()
before = formatter._merge_formula_numbers(
[
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
{"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
]
)
after = formatter._merge_formula_numbers(
[
{"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
]
)
untouched = formatter._merge_formula_numbers(
[{"index": 0, "label": "text", "native_label": "formula_number", "content": "(3)"}]
)
assert before == [
{
"index": 0,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y \\tag{1}\n$$",
}
]
assert after == [
{
"index": 0,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y \\tag{2}\n$$",
}
]
assert untouched == []
def test_merge_text_blocks_joins_hyphenated_words_when_wordfreq_accepts(monkeypatch):
formatter = GLMResultFormatter()
monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 3.0)
merged = formatter._merge_text_blocks(
[
{"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
{"index": 1, "label": "text", "native_label": "text", "content": "national"},
]
)
assert merged == [
{"index": 0, "label": "text", "native_label": "text", "content": "international"}
]
def test_merge_text_blocks_skips_invalid_merge(monkeypatch):
formatter = GLMResultFormatter()
monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 1.0)
merged = formatter._merge_text_blocks(
[
{"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
{"index": 1, "label": "text", "native_label": "text", "content": "National"},
]
)
assert merged == [
{"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
{"index": 1, "label": "text", "native_label": "text", "content": "National"},
]
def test_format_bullet_points_infers_missing_middle_bullet():
formatter = GLMResultFormatter()
items = [
{"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
{"native_label": "text", "content": "second", "bbox_2d": [12, 12, 52, 22]},
{"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
]
formatted = formatter._format_bullet_points(items)
assert formatted[1]["content"] == "- second"
def test_format_bullet_points_skips_when_bbox_missing():
formatter = GLMResultFormatter()
items = [
{"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
{"native_label": "text", "content": "second", "bbox_2d": []},
{"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
]
formatted = formatter._format_bullet_points(items)
assert formatted[1]["content"] == "second"
def test_process_runs_full_pipeline_and_skips_empty_content():
formatter = GLMResultFormatter()
regions = [
{
"index": 0,
"label": "text",
"native_label": "doc_title",
"content": "Doc Title",
"bbox_2d": [0, 0, 100, 30],
},
{
"index": 1,
"label": "text",
"native_label": "formula_number",
"content": "(1)",
"bbox_2d": [80, 50, 100, 60],
},
{
"index": 2,
"label": "formula",
"native_label": "display_formula",
"content": "x+y",
"bbox_2d": [0, 40, 100, 80],
},
{
"index": 3,
"label": "figure",
"native_label": "image",
"content": "figure placeholder",
"bbox_2d": [0, 80, 100, 120],
},
{
"index": 4,
"label": "text",
"native_label": "text",
"content": "",
"bbox_2d": [0, 120, 100, 150],
},
]
output = formatter.process(regions)
assert "# Doc Title" in output
assert "$$\nx+y \\tag{1}\n$$" in output
assert "![](bbox=[0, 80, 100, 120])" in output

View File

@@ -0,0 +1,46 @@
import numpy as np
from app.services.layout_detector import LayoutDetector
class _FakePredictor:
def __init__(self, boxes):
self._boxes = boxes
def predict(self, image):
return [{"boxes": self._boxes}]
def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
raw_boxes = [
{"cls_id": 22, "label": "text", "score": 0.95, "coordinate": [0, 0, 100, 100]},
{"cls_id": 22, "label": "text", "score": 0.90, "coordinate": [10, 10, 20, 20]},
{"cls_id": 6, "label": "doc_title", "score": 0.99, "coordinate": [0, 0, 80, 20]},
]
detector = LayoutDetector.__new__(LayoutDetector)
detector._get_layout_detector = lambda: _FakePredictor(raw_boxes)
calls = {}
def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode):
calls["args"] = {
"boxes": boxes,
"img_size": img_size,
"layout_nms": layout_nms,
"layout_unclip_ratio": layout_unclip_ratio,
"layout_merge_bboxes_mode": layout_merge_bboxes_mode,
}
return [boxes[0], boxes[2]]
monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess)
image = np.zeros((200, 100, 3), dtype=np.uint8)
info = detector.detect(image)
assert calls["args"]["img_size"] == (100, 200)
assert calls["args"]["layout_nms"] is True
assert calls["args"]["layout_merge_bboxes_mode"] == "large"
assert [region.native_label for region in info.regions] == ["text", "doc_title"]
assert [region.type for region in info.regions] == ["text", "text"]
assert info.MixedRecognition is True

View File

@@ -0,0 +1,151 @@
import math
import numpy as np
from app.services.layout_postprocess import (
apply_layout_postprocess,
check_containment,
iou,
is_contained,
nms,
unclip_boxes,
)
def _raw_box(cls_id, score, x1, y1, x2, y2, label="text"):
return {
"cls_id": cls_id,
"label": label,
"score": score,
"coordinate": [x1, y1, x2, y2],
}
def test_iou_handles_full_none_and_partial_overlap():
assert iou([0, 0, 9, 9], [0, 0, 9, 9]) == 1.0
assert iou([0, 0, 9, 9], [20, 20, 29, 29]) == 0.0
assert math.isclose(iou([0, 0, 9, 9], [5, 5, 14, 14]), 1 / 7, rel_tol=1e-6)
def test_nms_keeps_highest_score_for_same_class_overlap():
boxes = np.array(
[
[0, 0.95, 0, 0, 10, 10],
[0, 0.80, 1, 1, 11, 11],
],
dtype=float,
)
kept = nms(boxes, iou_same=0.6, iou_diff=0.98)
assert kept == [0]
def test_nms_keeps_cross_class_overlap_boxes_below_diff_threshold():
boxes = np.array(
[
[0, 0.95, 0, 0, 10, 10],
[1, 0.90, 1, 1, 11, 11],
],
dtype=float,
)
kept = nms(boxes, iou_same=0.6, iou_diff=0.98)
assert kept == [0, 1]
def test_nms_returns_single_box_index():
boxes = np.array([[0, 0.95, 0, 0, 10, 10]], dtype=float)
assert nms(boxes) == [0]
def test_is_contained_uses_overlap_threshold():
outer = [0, 0.9, 0, 0, 10, 10]
inner = [0, 0.9, 2, 2, 8, 8]
partial = [0, 0.9, 6, 6, 12, 12]
assert is_contained(inner, outer) is True
assert is_contained(partial, outer) is False
assert is_contained(partial, outer, overlap_threshold=0.3) is True
def test_check_containment_respects_preserve_class_ids():
boxes = np.array(
[
[0, 0.9, 0, 0, 100, 100],
[1, 0.8, 10, 10, 30, 30],
[2, 0.7, 15, 15, 25, 25],
],
dtype=float,
)
contains_other, contained_by_other = check_containment(boxes, preserve_cls_ids={1})
assert contains_other.tolist() == [1, 1, 0]
assert contained_by_other.tolist() == [0, 0, 1]
def test_unclip_boxes_supports_scalar_tuple_dict_and_none():
boxes = np.array(
[
[0, 0.9, 10, 10, 20, 20],
[1, 0.8, 30, 30, 50, 40],
],
dtype=float,
)
scalar = unclip_boxes(boxes, 2.0)
assert scalar[:, 2:6].tolist() == [[5.0, 5.0, 25.0, 25.0], [20.0, 25.0, 60.0, 45.0]]
tuple_ratio = unclip_boxes(boxes, (2.0, 3.0))
assert tuple_ratio[:, 2:6].tolist() == [[5.0, 0.0, 25.0, 30.0], [20.0, 20.0, 60.0, 50.0]]
per_class = unclip_boxes(boxes, {1: (1.5, 2.0)})
assert per_class[:, 2:6].tolist() == [[10.0, 10.0, 20.0, 20.0], [25.0, 25.0, 55.0, 45.0]]
assert np.array_equal(unclip_boxes(boxes, None), boxes)
def test_apply_layout_postprocess_large_mode_removes_contained_small_box():
boxes = [
_raw_box(0, 0.95, 0, 0, 100, 100, "text"),
_raw_box(0, 0.90, 10, 10, 20, 20, "text"),
]
result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")
assert [box["coordinate"] for box in result] == [[0, 0, 100, 100]]
def test_apply_layout_postprocess_preserves_contained_image_like_boxes():
boxes = [
_raw_box(0, 0.95, 0, 0, 100, 100, "text"),
_raw_box(1, 0.90, 10, 10, 20, 20, "image"),
_raw_box(2, 0.90, 25, 25, 35, 35, "seal"),
_raw_box(3, 0.90, 40, 40, 50, 50, "chart"),
]
result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")
assert {box["label"] for box in result} == {"text", "image", "seal", "chart"}
def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image():
boxes = [
_raw_box(0, 0.95, -10, -5, 40, 50, "text"),
_raw_box(1, 0.90, 10, 10, 10, 50, "text"),
_raw_box(2, 0.85, 0, 0, 100, 90, "image"),
]
result = apply_layout_postprocess(
boxes,
img_size=(100, 90),
layout_nms=False,
layout_merge_bboxes_mode=None,
)
assert result == [
{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
]

View File

@@ -0,0 +1,124 @@
import base64
from types import SimpleNamespace
import cv2
import numpy as np
from app.schemas.image import LayoutInfo, LayoutRegion
from app.services.ocr_service import GLMOCREndToEndService
class _FakeConverter:
def convert_to_formats(self, markdown):
return SimpleNamespace(
latex=f"LATEX::{markdown}",
mathml=f"MATHML::{markdown}",
mml=f"MML::{markdown}",
)
class _FakeImageProcessor:
def add_padding(self, image):
return image
class _FakeLayoutDetector:
def __init__(self, regions):
self._regions = regions
def detect(self, image):
return LayoutInfo(regions=self._regions, MixedRecognition=bool(self._regions))
def _build_service(regions=None):
return GLMOCREndToEndService(
vl_server_url="http://127.0.0.1:8002/v1",
image_processor=_FakeImageProcessor(),
converter=_FakeConverter(),
layout_detector=_FakeLayoutDetector(regions or []),
max_workers=2,
)
def test_encode_region_returns_decodable_base64_jpeg():
service = _build_service()
image = np.zeros((8, 12, 3), dtype=np.uint8)
image[:, :] = [0, 128, 255]
encoded = service._encode_region(image)
decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR)
assert decoded.shape[:2] == image.shape[:2]
def test_call_vllm_builds_messages_and_returns_content():
service = _build_service()
captured = {}
def create(**kwargs):
captured.update(kwargs)
return SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content=" recognized content \n"))]
)
service.openai_client = SimpleNamespace(
chat=SimpleNamespace(completions=SimpleNamespace(create=create))
)
result = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")
assert result == "recognized content"
assert captured["model"] == "glm-ocr"
assert captured["max_tokens"] == 1024
assert captured["messages"][0]["content"][0]["type"] == "image_url"
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
def test_normalize_bbox_scales_coordinates_to_1000():
service = _build_service()
assert service._normalize_bbox([0, 0, 200, 100], 200, 100) == [0, 0, 1000, 1000]
assert service._normalize_bbox([50, 25, 150, 75], 200, 100) == [250, 250, 750, 750]
def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
service = _build_service(regions=[])
image = np.zeros((20, 30, 3), dtype=np.uint8)
monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")
result = service.recognize(image)
assert result["markdown"] == "raw text"
assert result["latex"] == "LATEX::raw text"
assert result["mathml"] == "MATHML::raw text"
assert result["mml"] == "MML::raw text"
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
regions = [
LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
]
service = _build_service(regions=regions)
image = np.zeros((40, 40, 3), dtype=np.uint8)
calls = []
def fake_call_vllm(cropped, prompt):
calls.append(prompt)
if prompt == "Text Recognition:":
return "Title"
return "x + y"
monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)
result = service.recognize(image)
assert calls == ["Text Recognition:", "Formula Recognition:"]
assert result["markdown"] == "# Title\n\n$$\nx + y\n$$"
assert result["latex"] == "LATEX::# Title\n\n$$\nx + y\n$$"
assert result["mathml"] == "MATHML::# Title\n\n$$\nx + y\n$$"
assert result["mml"] == "MML::# Title\n\n$$\nx + y\n$$"