feat: add glm-ocr core

This commit is contained in:
liuyuanchuang
2026-03-09 16:51:06 +08:00
parent d74130914c
commit 6dfaf9668b
17 changed files with 1687 additions and 140 deletions

View File

@@ -2,26 +2,17 @@
import time
import uuid
import cv2
from io import BytesIO
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from app.core.dependencies import (
get_image_processor,
get_layout_detector,
get_ocr_service,
get_mineru_ocr_service,
get_glmocr_service,
get_glmocr_endtoend_service,
)
from app.core.config import get_settings
from app.core.logging_config import get_logger, RequestIDAdapter
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService
settings = get_settings()
from app.services.ocr_service import GLMOCREndToEndService
router = APIRouter()
logger = get_logger()
@@ -33,100 +24,38 @@ async def process_image_ocr(
http_request: Request,
response: Response,
image_processor: ImageProcessor = Depends(get_image_processor),
layout_detector: LayoutDetector = Depends(get_layout_detector),
mineru_service: MineruOCRService = Depends(get_mineru_ocr_service),
paddle_service: OCRService = Depends(get_ocr_service),
glmocr_service: GLMOCRService = Depends(get_glmocr_service),
glmocr_service: GLMOCREndToEndService = Depends(get_glmocr_endtoend_service),
) -> ImageOCRResponse:
"""Process an image and extract content as LaTeX, Markdown, and MathML.
The processing pipeline:
1. Load and preprocess image (add 30% whitespace padding)
2. Detect layout using DocLayout-YOLO
3. Based on layout:
- If plain text exists: use PP-DocLayoutV2 for mixed recognition
- Otherwise: use PaddleOCR-VL with formula prompt
4. Convert output to LaTeX, Markdown, and MathML formats
1. Load and preprocess image
2. Detect layout regions using PP-DocLayoutV3
3. Crop each region and recognize with GLM-OCR via vLLM (task-specific prompts)
4. Aggregate region results into Markdown
5. Convert to LaTeX, Markdown, and MathML formats
Note: OMML conversion is not included due to performance overhead.
Use the /convert/latex-to-omml endpoint to convert LaTeX to OMML separately.
"""
# Get or generate request ID
request_id = http_request.headers.get("x-request-id", str(uuid.uuid4()))
response.headers["x-request-id"] = request_id
# Create logger adapter with request_id
log = RequestIDAdapter(logger, {"request_id": request_id})
log.request_id = request_id
try:
log.info("Starting image OCR processing")
start = time.time()
# Preprocess image (load only, no padding yet)
preprocess_start = time.time()
image = image_processor.preprocess(
image_url=request.image_url,
image_base64=request.image_base64,
)
# Apply padding only for layout detection
processed_image = image
if image_processor and settings.is_padding:
processed_image = image_processor.add_padding(image)
ocr_result = glmocr_service.recognize(image)
preprocess_time = time.time() - preprocess_start
log.debug(f"Image loading completed in {preprocess_time:.3f}s")
# Layout detection (using padded image if padding is enabled)
layout_start = time.time()
layout_info = layout_detector.detect(processed_image)
layout_time = time.time() - layout_start
log.info(f"Layout detection completed in {layout_time:.3f}s")
# OCR recognition (use original image without padding)
ocr_start = time.time()
if layout_info.MixedRecognition:
recognition_method = "MixedRecognition (MinerU)"
log.info(f"Using {recognition_method}")
# Convert original image (without padding) to bytes
success, encoded_image = cv2.imencode(".png", image)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
image_bytes.seek(0) # Ensure position is at the beginning
ocr_result = mineru_service.recognize(image_bytes)
else:
recognition_method = "FormulaOnly (GLMOCR)"
log.info(f"Using {recognition_method}")
# Try GLM-OCR first, fallback to MinerU if token limit exceeded
try:
ocr_result = glmocr_service.recognize(image)
except Exception as e:
error_msg = str(e)
# Check if error is due to token limit (max_model_len exceeded)
if "max_model_len" in error_msg or "decoder prompt" in error_msg or "BadRequestError" in error_msg:
log.warning(f"GLM-OCR failed due to token limit: {error_msg}")
log.info("Falling back to MinerU for recognition")
recognition_method = "FormulaOnly (MinerU fallback)"
# Convert original image to bytes for MinerU
success, encoded_image = cv2.imencode(".png", image)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
image_bytes.seek(0)
ocr_result = mineru_service.recognize(image_bytes)
else:
# Re-raise other errors
raise
ocr_time = time.time() - ocr_start
total_time = time.time() - preprocess_start
log.info(f"OCR processing completed - Method: {recognition_method}, " f"Layout time: {layout_time:.3f}s, OCR time: {ocr_time:.3f}s, " f"Total time: {total_time:.3f}s")
log.info(f"OCR completed in {time.time() - start:.3f}s")
except RuntimeError as e:
log.error(f"OCR processing failed: {str(e)}", exc_info=True)

View File

@@ -3,9 +3,8 @@
from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
import torch
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
@@ -48,21 +47,25 @@ class Settings(BaseSettings):
is_padding: bool = True
padding_ratio: float = 0.1
max_tokens: int = 4096
# Model Paths
pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
pp_doclayout_model_dir: str | None = (
"/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
)
# Image Processing
max_image_size_mb: int = 10
image_padding_ratio: float = 0.1 # 10% on each side = 20% total expansion
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # cuda:0 or cpu
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Server Settings
host: str = "0.0.0.0"
port: int = 8053
# Logging Settings
log_dir: Optional[str] = None # Defaults to /app/logs in container or ./logs locally
log_dir: str | None = None # Defaults to /app/logs in container or ./logs locally
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
@property

View File

@@ -2,7 +2,7 @@
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService
from app.services.ocr_service import GLMOCREndToEndService
from app.services.converter import Converter
from app.core.config import get_settings
@@ -31,40 +31,17 @@ def get_image_processor() -> ImageProcessor:
return ImageProcessor()
def get_ocr_service() -> OCRService:
"""Get an OCR service instance."""
return OCRService(
vl_server_url=get_settings().paddleocr_vl_url,
layout_detector=get_layout_detector(),
image_processor=get_image_processor(),
converter=get_converter(),
)
def get_converter() -> Converter:
"""Get a DOCX converter instance."""
return Converter()
def get_mineru_ocr_service() -> MineruOCRService:
"""Get a MinerOCR service instance."""
def get_glmocr_endtoend_service() -> GLMOCREndToEndService:
"""Get end-to-end GLM-OCR service (layout detection + per-region OCR)."""
settings = get_settings()
api_url = getattr(settings, "miner_ocr_api_url", "http://127.0.0.1:8000/file_parse")
glm_ocr_url = getattr(settings, "glm_ocr_url", "http://localhost:8002/v1")
return MineruOCRService(
api_url=api_url,
converter=get_converter(),
image_processor=get_image_processor(),
glm_ocr_url=glm_ocr_url,
)
def get_glmocr_service() -> GLMOCRService:
"""Get a GLM OCR service instance."""
settings = get_settings()
glm_ocr_url = getattr(settings, "glm_ocr_url", "http://127.0.0.1:8002/v1")
return GLMOCRService(
vl_server_url=glm_ocr_url,
return GLMOCREndToEndService(
vl_server_url=settings.glm_ocr_url,
image_processor=get_image_processor(),
converter=get_converter(),
layout_detector=get_layout_detector(),
)

View File

@@ -7,6 +7,7 @@ class LayoutRegion(BaseModel):
"""A detected layout region in the document."""
type: str = Field(..., description="Region type: text, formula, table, figure")
native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)")
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
confidence: float = Field(..., description="Detection confidence score")
score: float = Field(..., description="Detection score")

View File

@@ -0,0 +1,412 @@
"""GLM-OCR postprocessing logic adapted for this project.
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
glm-ocr/glmocr/utils/result_postprocess_utils.py.
Covers:
- Repeated-content / hallucination detection
- Per-region content cleaning and formatting (titles, bullets, formulas)
- formula_number merging (→ \\tag{})
- Hyphenated text-block merging (via wordfreq)
- Missing bullet-point detection
"""
from __future__ import annotations
import re
import json
from collections import Counter
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
try:
from wordfreq import zipf_frequency
_WORDFREQ_AVAILABLE = True
except ImportError:
_WORDFREQ_AVAILABLE = False
# ---------------------------------------------------------------------------
# result_postprocess_utils (ported)
# ---------------------------------------------------------------------------
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
    """Detect a consecutively-repeated pattern and truncate it to one copy.

    Args:
        s: Candidate string (typically raw OCR output).
        min_unit_len: Smallest repeating unit considered.
        min_repeats: How many consecutive copies are required to trigger.

    Returns:
        The string with the extra repetitions dropped (prefix + one unit),
        or None when no qualifying repeat is found.
    """
    length = len(s)
    # Too short to contain min_repeats copies of a min_unit_len unit.
    if length < min_unit_len * min_repeats:
        return None
    longest_unit = length // min_repeats
    if longest_unit < min_unit_len:
        return None
    # Lazy quantifier: prefer the shortest unit that still repeats enough
    # times; DOTALL lets the unit span newlines.
    repeat_re = re.compile(
        r"(.{%d,%d}?)\1{%d,}" % (min_unit_len, longest_unit, min_repeats - 1),
        re.DOTALL,
    )
    hit = repeat_re.search(s)
    if hit is None:
        return None
    return s[: hit.start()] + hit.group(1)
def clean_repeated_content(
    content: str,
    min_len: int = 10,
    min_repeats: int = 10,
    line_threshold: int = 10,
) -> str:
    """Remove hallucination-style repeated content (consecutive or line-level).

    Args:
        content: Raw OCR output for one region.
        min_len: Minimum repeating-unit length for the consecutive check.
        min_repeats: Minimum number of consecutive copies to trigger.
        line_threshold: Minimum total lines / occurrence count for the
            line-level check.

    Returns:
        The content truncated at the detected repeat, or unchanged when no
        hallucination pattern is found.
    """
    stripped = content.strip()
    if not stripped:
        return content
    # 1. Consecutive repeat (multi-line aware)
    if len(stripped) > min_len * min_repeats:
        result = find_consecutive_repeat(stripped, min_unit_len=min_len, min_repeats=min_repeats)
        if result is not None:
            return result
    # 2. Line-level repeat
    lines = [line.strip() for line in content.split("\n") if line.strip()]
    total_lines = len(lines)
    if total_lines >= line_threshold and lines:
        # Most frequent non-empty line and its occurrence count.
        common, count = Counter(lines).most_common(1)[0]
        # Trigger only when one line dominates (>= 80% of all non-empty lines).
        if count >= line_threshold and (count / total_lines) >= 0.8:
            for i, line in enumerate(lines):
                if line == common:
                    # Require at least 3 back-to-back occurrences starting at i.
                    consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common)
                    if consecutive >= 3:
                        # Map the i-th non-empty line back to its position in
                        # the original (blank-line-preserving) text and cut
                        # just after the first occurrence of the repeat.
                        original_lines = content.split("\n")
                        non_empty_count = 0
                        for idx, orig_line in enumerate(original_lines):
                            if orig_line.strip():
                                non_empty_count += 1
                                if non_empty_count == i + 1:
                                    return "\n".join(original_lines[: idx + 1])
                        break
    return content
def clean_formula_number(number_content: str) -> str:
    """Strip surrounding parentheses from a formula number, e.g. '(1)' -> '1'.

    Handles both ASCII and fullwidth (CJK) parentheses.

    Args:
        number_content: Raw OCR output of a formula_number region.

    Returns:
        The bare number/letter, with surrounding parentheses removed.
    """
    s = number_content.strip()
    # ASCII parentheses.
    if s.startswith("(") and s.endswith(")"):
        return s[1:-1]
    # Fullwidth parentheses. Bug fix: the original compared against empty
    # strings (startswith("")/endswith("") are always True), which stripped
    # the first and last character of EVERY non-ASCII-parenthesised input.
    if s.startswith("（") and s.endswith("）"):
        return s[1:-1]
    return s
# ---------------------------------------------------------------------------
# GLMResultFormatter
# ---------------------------------------------------------------------------
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
# Used by GLMResultFormatter._map_label to collapse PP-DocLayout's fine-grained
# labels into the four categories handled by the Markdown assembler.
_LABEL_TO_CATEGORY: Dict[str, str] = {
    # text
    "abstract": "text",
    "algorithm": "text",
    "content": "text",
    "doc_title": "text",
    "figure_title": "text",
    "paragraph_title": "text",
    "reference_content": "text",
    "text": "text",
    "vertical_text": "text",
    "vision_footnote": "text",
    "seal": "text",
    # formula_number is recognised as text, then merged into the adjacent
    # formula as a \tag{} by _merge_formula_numbers
    "formula_number": "text",
    # table
    "table": "table",
    # formula
    "display_formula": "formula",
    "inline_formula": "formula",
    # image (skip OCR)
    "chart": "image",
    "image": "image",
}
class GLMResultFormatter:
    """Port of GLM-OCR's ResultFormatter for use in our pipeline.

    Accepts a list of region dicts (each with label, native_label, content,
    bbox_2d) and returns a final Markdown string.
    """

    # ------------------------------------------------------------------ #
    # Public entry-point
    # ------------------------------------------------------------------ #
    def process(self, regions: List[Dict[str, Any]]) -> str:
        """Run the full postprocessing pipeline and return Markdown.

        Args:
            regions: List of dicts with keys:
                - index (int) reading order from layout detection
                - label (str) mapped category: text/formula/table/figure
                - native_label (str) raw PP-DocLayout label (e.g. doc_title)
                - content (str) raw OCR output from vLLM
                - bbox_2d (list) [x1, y1, x2, y2] in 0-1000 normalised coords

        Returns:
            Markdown string.
        """
        # Sort by reading order
        items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))

        # Per-region cleaning + formatting
        processed: List[Dict] = []
        for item in items:
            item["native_label"] = item.get("native_label", item.get("label", "text"))
            item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
            item["content"] = self._format_content(
                item.get("content") or "",
                item["label"],
                item["native_label"],
            )
            # NOTE(review): empty-content regions are dropped here; "image"
            # regions carry no OCR content, so they are filtered out before the
            # placeholder branch below ever sees them — confirm intended.
            if not (item.get("content") or "").strip():
                continue
            processed.append(item)

        # Re-index after filtering
        for i, item in enumerate(processed):
            item["index"] = i

        # Structural merges
        processed = self._merge_formula_numbers(processed)
        processed = self._merge_text_blocks(processed)
        processed = self._format_bullet_points(processed)

        # Assemble Markdown
        parts: List[str] = []
        for item in processed:
            content = item.get("content") or ""
            if item["label"] == "image":
                parts.append(f"![](bbox={item.get('bbox_2d', [])})")
            elif content.strip():
                parts.append(content)
        return "\n\n".join(parts)

    # ------------------------------------------------------------------ #
    # Label mapping
    # ------------------------------------------------------------------ #
    def _map_label(self, label: str, native_label: str) -> str:
        """Map (label, native_label) to a canonical category; default 'text'."""
        return _LABEL_TO_CATEGORY.get(native_label, _LABEL_TO_CATEGORY.get(label, "text"))

    # ------------------------------------------------------------------ #
    # Content cleaning
    # ------------------------------------------------------------------ #
    def _clean_content(self, content: str) -> str:
        """Remove artefacts: leading/trailing \\t, repeated punctuation, long repeats."""
        if content is None:
            return ""
        content = re.sub(r"^(\\t)+", "", content).lstrip()
        content = re.sub(r"(\\t)+$", "", content).rstrip()
        # Cap runs of filler punctuation at three characters.
        content = re.sub(r"(\.)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(·)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(_)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(\\_)\1{2,}", r"\1\1\1", content)
        # Only long outputs are screened for hallucination-style repeats.
        if len(content) >= 2048:
            content = clean_repeated_content(content)
        return content.strip()

    # ------------------------------------------------------------------ #
    # Per-region content formatting
    # ------------------------------------------------------------------ #
    def _format_content(self, content: Any, label: str, native_label: str) -> str:
        """Clean and format a single region's content.

        Args:
            content: Raw OCR output (any type; coerced to str).
            label: Canonical category (text/formula/table/image).
            native_label: Raw PP-DocLayout label (e.g. doc_title).

        Returns:
            Markdown-ready content for the region.
        """
        if content is None:
            return ""
        content = self._clean_content(str(content))

        # Heading formatting
        if native_label == "doc_title":
            content = re.sub(r"^#+\s*", "", content)
            content = "# " + content
        elif native_label == "paragraph_title":
            if content.startswith("- ") or content.startswith("* "):
                content = content[2:].lstrip()
            content = re.sub(r"^#+\s*", "", content)
            content = "## " + content.lstrip()

        # Formula wrapping: normalise any existing delimiters to $$ ... $$
        if label == "formula":
            content = content.strip()
            for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
                if content.startswith(s) and content.endswith(e):
                    content = content[len(s) : -len(e)].strip()
                    break
            content = "$$\n" + content + "\n$$"

        # Text formatting
        if label == "text":
            # Normalise bullet markers (·, •, "* ") to "- ".
            # Bug fix: the original tested content.startswith("") — an empty
            # string (mangled "•") that matched EVERY region, turning all
            # text blocks into bullets.
            if content.startswith("·") or content.startswith("•") or content.startswith("* "):
                content = "- " + content[1:].lstrip()
            # "(1) foo" / "（1）foo" → "(1) foo".
            # Bug fix: the fullwidth parens in these patterns were mangled,
            # leaving unbalanced groups that raise re.error at runtime.
            match = re.match(r"^(\(|（)(\d+|[A-Za-z])(\)|）)(.*)$", content)
            if match:
                _, symbol, _, rest = match.groups()
                content = f"({symbol}) {rest.lstrip()}"
            # "1. foo" / "1) foo" / "1）foo" → keep separator (fullwidth → ASCII).
            match = re.match(r"^(\d+|[A-Za-z])(\.|\)|）)(.*)$", content)
            if match:
                symbol, sep, rest = match.groups()
                sep = ")" if sep == "）" else sep
                content = f"{symbol}{sep} {rest.lstrip()}"

        # Single newline → double newline (Markdown paragraph break)
        content = re.sub(r"(?<!\n)\n(?!\n)", "\n\n", content)
        return content

    # ------------------------------------------------------------------ #
    # Structural merges
    # ------------------------------------------------------------------ #
    def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
        """Merge a formula_number region into the adjacent formula as \\tag{}."""
        if not items:
            return items
        merged: List[Dict] = []
        skip: set = set()
        for i, block in enumerate(items):
            if i in skip:
                continue
            native = block.get("native_label", "")
            # Case 1: formula_number then formula
            if native == "formula_number":
                if i + 1 < len(items) and items[i + 1].get("label") == "formula":
                    num_clean = clean_formula_number(block.get("content", "").strip())
                    formula_content = items[i + 1].get("content", "")
                    merged_block = deepcopy(items[i + 1])
                    if formula_content.endswith("\n$$"):
                        merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
                    merged.append(merged_block)
                    skip.add(i + 1)
                continue  # always skip the formula_number block itself
            # Case 2: formula then formula_number
            if block.get("label") == "formula":
                if i + 1 < len(items) and items[i + 1].get("native_label") == "formula_number":
                    num_clean = clean_formula_number(items[i + 1].get("content", "").strip())
                    formula_content = block.get("content", "")
                    merged_block = deepcopy(block)
                    if formula_content.endswith("\n$$"):
                        merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
                    merged.append(merged_block)
                    skip.add(i + 1)
                    continue
            merged.append(block)
        for i, block in enumerate(merged):
            block["index"] = i
        return merged

    def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
        """Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
        if not items or not _WORDFREQ_AVAILABLE:
            return items
        merged: List[Dict] = []
        skip: set = set()
        for i, block in enumerate(items):
            if i in skip:
                continue
            if block.get("label") != "text":
                merged.append(block)
                continue
            content = block.get("content", "")
            # Only blocks ending in a hyphen are merge candidates.
            if not isinstance(content, str) or not content.rstrip().endswith("-"):
                merged.append(block)
                continue
            content_stripped = content.rstrip()
            did_merge = False
            for j in range(i + 1, len(items)):
                if items[j].get("label") != "text":
                    continue
                next_content = items[j].get("content", "")
                if not isinstance(next_content, str):
                    continue
                next_stripped = next_content.lstrip()
                # Continuation must start lowercase (mid-word split).
                if next_stripped and next_stripped[0].islower():
                    words_before = content_stripped[:-1].split()
                    next_words = next_stripped.split()
                    if words_before and next_words:
                        merged_word = words_before[-1] + next_words[0]
                        # Accept the merge only if the joined word is a real
                        # English word (Zipf frequency >= 2.5).
                        if zipf_frequency(merged_word.lower(), "en") >= 2.5:
                            merged_block = deepcopy(block)
                            merged_block["content"] = content_stripped[:-1] + next_content.lstrip()
                            merged.append(merged_block)
                            skip.add(j)
                            did_merge = True
                            break
            if not did_merge:
                merged.append(block)
        for i, block in enumerate(merged):
            block["index"] = i
        return merged

    def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
        """Add a missing bullet prefix when a text block sits between two bullets.

        A plain text block directly between two "- " blocks, left-aligned with
        both (within left_align_threshold), is assumed to be a bullet whose
        marker the OCR missed.
        """
        if len(items) < 3:
            return items
        for i in range(1, len(items) - 1):
            cur = items[i]
            prev = items[i - 1]
            nxt = items[i + 1]
            if cur.get("native_label") != "text":
                continue
            if prev.get("native_label") != "text" or nxt.get("native_label") != "text":
                continue
            cur_content = cur.get("content", "")
            if cur_content.startswith("- "):
                continue
            prev_content = prev.get("content", "")
            nxt_content = nxt.get("content", "")
            if not (prev_content.startswith("- ") and nxt_content.startswith("- ")):
                continue
            cur_bbox = cur.get("bbox_2d", [])
            prev_bbox = prev.get("bbox_2d", [])
            nxt_bbox = nxt.get("bbox_2d", [])
            if not (cur_bbox and prev_bbox and nxt_bbox):
                continue
            if (
                abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold
                and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold
            ):
                cur["content"] = "- " + cur_content
        return items

View File

@@ -1,9 +1,10 @@
"""PP-DocLayoutV2 wrapper for document layout detection."""
"""PP-DocLayoutV3 wrapper for document layout detection."""
import numpy as np
from app.schemas.image import LayoutInfo, LayoutRegion
from app.core.config import get_settings
from app.services.layout_postprocess import apply_layout_postprocess
from paddleocr import LayoutDetection
from typing import Optional
@@ -116,6 +117,17 @@ class LayoutDetector:
else:
boxes = []
# Apply GLM-OCR layout post-processing (NMS, containment, unclip, clamp)
if boxes:
h, w = image.shape[:2]
boxes = apply_layout_postprocess(
boxes,
img_size=(w, h),
layout_nms=True,
layout_unclip_ratio=None,
layout_merge_bboxes_mode="large",
)
for box in boxes:
cls_id = box.get("cls_id")
label = box.get("label") or self.CLS_ID_TO_LABEL.get(cls_id, "other")
@@ -128,6 +140,7 @@ class LayoutDetector:
regions.append(
LayoutRegion(
type=region_type,
native_label=label,
bbox=coordinate,
confidence=score,
score=score,

View File

@@ -0,0 +1,343 @@
"""Layout post-processing utilities ported from GLM-OCR.
Source: glm-ocr/glmocr/utils/layout_postprocess_utils.py
Algorithms applied after PaddleOCR LayoutDetection.predict():
1. NMS with dual IoU thresholds (same-class vs cross-class)
2. Large-image-region filtering (remove image boxes that fill most of the page)
3. Containment analysis (merge_bboxes_mode: keep large parent, remove contained child)
4. Unclip ratio (optional bbox expansion)
5. Invalid bbox skipping
These steps run on top of PaddleOCR's built-in detection to replicate
the quality of the GLM-OCR SDK's layout pipeline.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
# ---------------------------------------------------------------------------
# Primitive geometry helpers
# ---------------------------------------------------------------------------
def iou(box1: List[float], box2: List[float]) -> float:
    """Compute IoU of two bounding boxes [x1, y1, x2, y2].

    Uses inclusive pixel counting (+1 per side), matching the ported
    GLM-OCR implementation.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    # Intersection rectangle.
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0, ix2 - ix1 + 1) * max(0, iy2 - iy1 + 1)
    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    return inter / float(area_a + area_b - inter)
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
    """Return True if box1 is contained within box2 (overlap ratio >= threshold).

    box format: [cls_id, score, x1, y1, x2, y2]
    """
    _, _, ax1, ay1, ax2, ay2 = box1
    _, _, bx1, by1, bx2, by2 = box2
    area = (ax2 - ax1) * (ay2 - ay1)
    # Degenerate boxes are never "contained".
    if area <= 0:
        return False
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    # Ratio of box1's area covered by box2.
    return (overlap_w * overlap_h) / area >= overlap_threshold
# ---------------------------------------------------------------------------
# NMS
# ---------------------------------------------------------------------------
def nms(
    boxes: np.ndarray,
    iou_same: float = 0.6,
    iou_diff: float = 0.98,
) -> List[int]:
    """NMS with separate IoU thresholds for same-class and cross-class overlaps.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        iou_same: Suppression threshold for boxes of the same class.
        iou_diff: Suppression threshold for boxes of different classes.

    Returns:
        List of kept row indices.
    """
    # Candidates in descending score order.
    order = np.argsort(boxes[:, 1])[::-1].tolist()
    keep: List[int] = []
    while order:
        best, rest = order[0], order[1:]
        keep.append(best)
        best_cls = int(boxes[best, 0])
        best_coords = boxes[best, 2:6].tolist()
        # Retain only candidates that do not overlap the winner too much;
        # cross-class overlaps use the (much looser) iou_diff threshold.
        survivors = []
        for idx in rest:
            limit = iou_same if int(boxes[idx, 0]) == best_cls else iou_diff
            if iou(best_coords, boxes[idx, 2:6].tolist()) < limit:
                survivors.append(idx)
        order = survivors
    return keep
# ---------------------------------------------------------------------------
# Containment analysis
# ---------------------------------------------------------------------------
# Labels whose regions should never be removed even when contained in another box
_PRESERVE_LABELS = {"image", "seal", "chart"}


def check_containment(
    boxes: np.ndarray,
    preserve_cls_ids: Optional[set] = None,
    category_index: Optional[int] = None,
    mode: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute containment flags for each box.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        preserve_cls_ids: Class IDs that must never be marked as contained.
        category_index: If set, apply mode only relative to this class.
        mode: 'large' or 'small' (only used with category_index).

    Returns:
        (contains_other, contained_by_other): int (0/1) arrays of length N.
    """
    n = len(boxes)
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)
    # O(n^2) pairwise scan; n is the per-page region count, so this is cheap.
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            # Preserved classes (image/seal/chart) are never marked contained.
            # NOTE(review): this also skips setting contains_other[j] for the
            # pair — mirrors the ported GLM-OCR behaviour.
            if preserve_cls_ids and int(boxes[i, 0]) in preserve_cls_ids:
                continue
            if category_index is not None and mode is not None:
                # 'large': only test containment inside target-class parents.
                if mode == "large" and int(boxes[j, 0]) == category_index:
                    if is_contained(boxes[i].tolist(), boxes[j].tolist()):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
                # 'small': only test target-class boxes being contained.
                elif mode == "small" and int(boxes[i, 0]) == category_index:
                    if is_contained(boxes[i].tolist(), boxes[j].tolist()):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
            else:
                # Unrestricted mode: flag every containment pair.
                if is_contained(boxes[i].tolist(), boxes[j].tolist()):
                    contained_by_other[i] = 1
                    contains_other[j] = 1
    return contains_other, contained_by_other
# ---------------------------------------------------------------------------
# Box expansion (unclip)
# ---------------------------------------------------------------------------
def unclip_boxes(
    boxes: np.ndarray,
    unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
) -> np.ndarray:
    """Expand bounding boxes about their centers by the given ratio.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        unclip_ratio: Scalar, (w_ratio, h_ratio) tuple, or dict mapping cls_id to ratio.

    Returns:
        Expanded boxes array (input returned unchanged when ratio is None).
    """
    if unclip_ratio is None:
        return boxes

    # Per-class ratios: expand only rows whose class appears in the dict.
    if isinstance(unclip_ratio, dict):
        rows = []
        for row in boxes:
            cls_id = int(row[0])
            if cls_id not in unclip_ratio:
                rows.append(list(row))
                continue
            w_ratio, h_ratio = unclip_ratio[cls_id]
            x1, y1, x2, y2 = row[2], row[3], row[4], row[5]
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            half_w = (x2 - x1) * w_ratio / 2
            half_h = (y2 - y1) * h_ratio / 2
            new_row = list(row)
            new_row[2], new_row[3] = cx - half_w, cy - half_h
            new_row[4], new_row[5] = cx + half_w, cy + half_h
            rows.append(new_row)
        return np.array(rows)

    # Uniform ratio: normalise a scalar to (w, h) and vectorise.
    if isinstance(unclip_ratio, (int, float)):
        unclip_ratio = (float(unclip_ratio), float(unclip_ratio))
    w_ratio, h_ratio = unclip_ratio[0], unclip_ratio[1]
    half_w = (boxes[:, 4] - boxes[:, 2]) * w_ratio / 2
    half_h = (boxes[:, 5] - boxes[:, 3]) * h_ratio / 2
    cx = boxes[:, 2] + (boxes[:, 4] - boxes[:, 2]) / 2
    cy = boxes[:, 3] + (boxes[:, 5] - boxes[:, 3]) / 2
    out = boxes.copy().astype(float)
    out[:, 2] = cx - half_w
    out[:, 3] = cy - half_h
    out[:, 4] = cx + half_w
    out[:, 5] = cy + half_h
    return out
# ---------------------------------------------------------------------------
# Main entry-point
# ---------------------------------------------------------------------------
def apply_layout_postprocess(
    boxes: List[Dict],
    img_size: Tuple[int, int],
    layout_nms: bool = True,
    layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
    layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
) -> List[Dict]:
    """Apply GLM-OCR layout post-processing to PaddleOCR detection results.

    Args:
        boxes: PaddleOCR output — list of dicts with keys:
            cls_id, label, score, coordinate ([x1, y1, x2, y2]).
        img_size: (width, height) of the image.
        layout_nms: Apply dual-threshold NMS.
        layout_unclip_ratio: Optional bbox expansion ratio.
        layout_merge_bboxes_mode: Containment mode — 'large' (default), 'small',
            'union', or per-class dict.

    Returns:
        Filtered and ordered list of box dicts in the same PaddleOCR format
        (coordinates clamped to the image and rounded to int).
    """
    if not boxes:
        return boxes
    img_width, img_height = img_size

    # --- Build working array [cls_id, score, x1, y1, x2, y2] -------------- #
    arr_rows = []
    for b in boxes:
        cls_id = b.get("cls_id", 0)
        score = b.get("score", 0.0)
        x1, y1, x2, y2 = b.get("coordinate", [0, 0, 0, 0])
        arr_rows.append([cls_id, score, x1, y1, x2, y2])
    boxes_array = np.array(arr_rows, dtype=float)
    all_labels: List[str] = [b.get("label", "") for b in boxes]

    # 1. NMS ---------------------------------------------------------------- #
    if layout_nms and len(boxes_array) > 1:
        kept = nms(boxes_array, iou_same=0.6, iou_diff=0.98)
        boxes_array = boxes_array[kept]
        all_labels = [all_labels[k] for k in kept]

    # 2. Filter large image regions ---------------------------------------- #
    # Drop "image" boxes that cover most of the page (likely a full-page
    # false detection). Landscape pages use a tighter fill threshold.
    if len(boxes_array) > 1:
        img_area = img_width * img_height
        area_thres = 0.82 if img_width > img_height else 0.93
        keep_mask = np.ones(len(boxes_array), dtype=bool)
        for i, lbl in enumerate(all_labels):
            if lbl == "image":
                x1, y1, x2, y2 = boxes_array[i, 2:6]
                # Clamp to the page before measuring the fill ratio.
                x1 = max(0.0, x1)
                y1 = max(0.0, y1)
                x2 = min(float(img_width), x2)
                y2 = min(float(img_height), y2)
                if (x2 - x1) * (y2 - y1) > area_thres * img_area:
                    keep_mask[i] = False
        boxes_array = boxes_array[keep_mask]
        all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    # 3. Containment analysis (merge_bboxes_mode) -------------------------- #
    if layout_merge_bboxes_mode and len(boxes_array) > 1:
        preserve_cls_ids = {
            int(boxes_array[i, 0])
            for i, lbl in enumerate(all_labels)
            if lbl in _PRESERVE_LABELS
        }
        if isinstance(layout_merge_bboxes_mode, str):
            mode = layout_merge_bboxes_mode
            if mode in ("large", "small"):
                contains_other, contained_by_other = check_containment(
                    boxes_array, preserve_cls_ids
                )
                if mode == "large":
                    # Keep the large parent; drop contained children.
                    keep_mask = contained_by_other == 0
                else:
                    # Keep contained children; drop pure containers.
                    keep_mask = (contains_other == 0) | (contained_by_other == 1)
                boxes_array = boxes_array[keep_mask]
                all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]
        elif isinstance(layout_merge_bboxes_mode, dict):
            # Per-class modes: AND the keep masks of every (class, mode) rule.
            keep_mask = np.ones(len(boxes_array), dtype=bool)
            for category_index, mode in layout_merge_bboxes_mode.items():
                if mode in ("large", "small"):
                    contains_other, contained_by_other = check_containment(
                        boxes_array, preserve_cls_ids, int(category_index), mode
                    )
                    if mode == "large":
                        keep_mask &= contained_by_other == 0
                    else:
                        keep_mask &= (contains_other == 0) | (contained_by_other == 1)
            boxes_array = boxes_array[keep_mask]
            all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    if len(boxes_array) == 0:
        return []

    # 4. Unclip (bbox expansion) ------------------------------------------- #
    if layout_unclip_ratio is not None:
        boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)

    # 5. Clamp to image boundaries + skip invalid -------------------------- #
    result: List[Dict] = []
    for i, row in enumerate(boxes_array):
        cls_id = int(row[0])
        score = float(row[1])
        x1 = max(0.0, min(float(row[2]), img_width))
        y1 = max(0.0, min(float(row[3]), img_height))
        x2 = max(0.0, min(float(row[4]), img_width))
        y2 = max(0.0, min(float(row[5]), img_height))
        # Degenerate (zero/negative extent) boxes are dropped.
        if x1 >= x2 or y1 >= y2:
            continue
        result.append({
            "cls_id": cls_id,
            "label": all_labels[i],
            "score": score,
            "coordinate": [int(x1), int(y1), int(x2), int(y2)],
        })
    return result

View File

@@ -1,19 +1,23 @@
"""PaddleOCR-VL client service for text and formula recognition."""
import re
import numpy as np
import cv2
import requests
from io import BytesIO
import base64
from app.core.config import get_settings
from paddleocr import PaddleOCRVL
from typing import Optional
from app.services.layout_detector import LayoutDetector
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
import re
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
import cv2
import numpy as np
import requests
from openai import OpenAI
from paddleocr import PaddleOCRVL
from PIL import Image as PILImage
from app.core.config import get_settings
from app.services.converter import Converter
from app.services.glm_postprocess import GLMResultFormatter
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
settings = get_settings()
@@ -144,7 +148,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
# Strategy: remove spaces before \ and between non-command chars,
# but preserve the space after \command when followed by a non-\ char
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
cleaned = re.sub(
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
) # remove space after non-letter non-\
return f"{operator}{{{cleaned}}}"
# Match _{ ... } or ^{ ... }
@@ -383,8 +389,8 @@ class OCRServiceBase(ABC):
class OCRService(OCRServiceBase):
"""Service for OCR using PaddleOCR-VL."""
_pipeline: Optional[PaddleOCRVL] = None
_layout_detector: Optional[LayoutDetector] = None
_pipeline: PaddleOCRVL | None = None
_layout_detector: LayoutDetector | None = None
def __init__(
self,
@@ -549,7 +555,15 @@ class GLMOCRService(OCRServiceBase):
# Call OpenAI-compatible API with formula recognition prompt
prompt = "Formula Recognition:"
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": prompt},
],
}
]
# Don't catch exceptions here - let them propagate for fallback handling
response = self.openai_client.chat.completions.create(
@@ -596,10 +610,10 @@ class MineruOCRService(OCRServiceBase):
def __init__(
self,
api_url: str = "http://127.0.0.1:8000/file_parse",
image_processor: Optional[ImageProcessor] = None,
converter: Optional[Converter] = None,
image_processor: ImageProcessor | None = None,
converter: Converter | None = None,
glm_ocr_url: str = "http://localhost:8002/v1",
layout_detector: Optional[LayoutDetector] = None,
layout_detector: LayoutDetector | None = None,
):
"""Initialize Local API service.
@@ -614,7 +628,9 @@ class MineruOCRService(OCRServiceBase):
self.glm_ocr_url = glm_ocr_url
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
def _recognize_formula_with_paddleocr_vl(
self, image: np.ndarray, prompt: str = "Formula Recognition:"
) -> str:
"""Recognize formula using PaddleOCR-VL API.
Args:
@@ -634,7 +650,15 @@ class MineruOCRService(OCRServiceBase):
image_url = f"data:image/png;base64,{image_base64}"
# Call OpenAI-compatible API
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": prompt},
],
}
]
response = self.openai_client.chat.completions.create(
model="glm-ocr",
@@ -647,7 +671,9 @@ class MineruOCRService(OCRServiceBase):
except Exception as e:
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
def _extract_and_recognize_formulas(
self, markdown_content: str, original_image: np.ndarray
) -> str:
"""Extract image references from markdown and recognize formulas.
Args:
@@ -712,7 +738,13 @@ class MineruOCRService(OCRServiceBase):
}
# Make API request
response = requests.post(self.api_url, files=files, data=data, headers={"accept": "application/json"}, timeout=30)
response = requests.post(
self.api_url,
files=files,
data=data,
headers={"accept": "application/json"},
timeout=30,
)
response.raise_for_status()
result = response.json()
@@ -723,7 +755,9 @@ class MineruOCRService(OCRServiceBase):
markdown_content = result["results"]["image"].get("md_content", "")
if "![](images/" in markdown_content:
markdown_content = self._extract_and_recognize_formulas(markdown_content, original_image)
markdown_content = self._extract_and_recognize_formulas(
markdown_content, original_image
)
# Apply postprocessing to fix OCR errors
markdown_content = _postprocess_markdown(markdown_content)
@@ -751,6 +785,167 @@ class MineruOCRService(OCRServiceBase):
raise RuntimeError(f"Recognition failed: {e}") from e
# Task-specific prompts (from GLM-OCR SDK config.yaml)
_TASK_PROMPTS: dict[str, str] = {
"text": "Text Recognition:",
"formula": "Formula Recognition:",
"table": "Table Recognition:",
}
_DEFAULT_PROMPT = (
"Recognize the text in the image and output in Markdown format. "
"Preserve the original layout (headings/paragraphs/tables/formulas). "
"Do not fabricate content that does not exist in the image."
)
class GLMOCREndToEndService(OCRServiceBase):
    """End-to-end OCR using the GLM-OCR pipeline: layout detection → per-region OCR.

    Pipeline:
        1. Add padding (ImageProcessor)
        2. Detect layout regions (LayoutDetector → PP-DocLayoutV3)
        3. Crop each region and call vLLM with a task-specific prompt (parallel)
        4. GLMResultFormatter: clean, format titles/bullets/formulas, merge tags
        5. _postprocess_markdown: LaTeX math error correction
        6. Converter: markdown → latex/mathml/mml

    This replaces both GLMOCRService (formula-only) and MineruOCRService (mixed).
    """

    def __init__(
        self,
        vl_server_url: str,
        image_processor: ImageProcessor,
        converter: Converter,
        layout_detector: LayoutDetector,
        max_workers: int = 8,
    ):
        """Initialize the end-to-end GLM-OCR service.

        Args:
            vl_server_url: Base URL of the vLLM OpenAI-compatible server;
                falls back to ``settings.glm_ocr_url`` when empty/None.
            image_processor: Adds whitespace padding before detection.
            converter: Converts final markdown to latex/mathml/mml.
            layout_detector: PP-DocLayoutV3 region detector.
            max_workers: Upper bound on parallel per-region OCR calls.
        """
        self.vl_server_url = vl_server_url or settings.glm_ocr_url
        self.image_processor = image_processor
        self.converter = converter
        self.layout_detector = layout_detector
        self.max_workers = max_workers
        # Long timeout: vLLM may queue requests under load.
        self.openai_client = OpenAI(
            api_key="EMPTY", base_url=self.vl_server_url, timeout=3600
        )
        self._formatter = GLMResultFormatter()

    def _encode_region(self, image: np.ndarray) -> str:
        """Convert a BGR numpy array to a base64-encoded JPEG string."""
        # assumes a 3-channel BGR image (cv2 convention) — TODO confirm for
        # grayscale inputs upstream
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_img = PILImage.fromarray(rgb)
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _call_vllm(self, image: np.ndarray, prompt: str) -> str:
        """Send image + prompt to vLLM and return the stripped content string.

        Raises:
            Any exception from the OpenAI client (propagated to the caller,
            which decides whether to fall back or fail the region).
        """
        img_b64 = self._encode_region(image)
        data_url = f"data:image/jpeg;base64,{img_b64}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        response = self.openai_client.chat.completions.create(
            model="glm-ocr",
            messages=messages,
            temperature=0.01,
            max_tokens=settings.max_tokens,
        )
        # message.content may be None (e.g. empty completion) — guard before strip
        return (response.choices[0].message.content or "").strip()

    def _ocr_full_image(self, padded: np.ndarray) -> str:
        """Fallback path: OCR the whole padded image with the generic prompt."""
        raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
        return self._formatter._clean_content(raw_content)

    def _normalize_bbox(self, bbox: list[float], img_w: int, img_h: int) -> list[int]:
        """Convert pixel bbox [x1,y1,x2,y2] to 0-1000 normalised coords."""
        x1, y1, x2, y2 = bbox
        return [
            int(x1 / img_w * 1000),
            int(y1 / img_h * 1000),
            int(x2 / img_w * 1000),
            int(y2 / img_h * 1000),
        ]

    def recognize(self, image: np.ndarray) -> dict:
        """Full pipeline: padding → layout → per-region OCR → postprocess → markdown.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
        """
        # 1. Padding
        padded = self.image_processor.add_padding(image)
        img_h, img_w = padded.shape[:2]

        # 2. Layout detection
        layout_info = self.layout_detector.detect(padded)

        # 3. OCR: per-region (parallel) or full-image fallback
        if not layout_info.regions:
            markdown_content = self._ocr_full_image(padded)
        else:
            # Build task list for non-figure regions (figures are not OCR'd)
            tasks = []
            for idx, region in enumerate(layout_info.regions):
                if region.type == "figure":
                    continue
                x1, y1, x2, y2 = (int(c) for c in region.bbox)
                cropped = padded[y1:y2, x1:x2]
                if cropped.size == 0:
                    continue
                prompt = _TASK_PROMPTS.get(region.type, _DEFAULT_PROMPT)
                tasks.append((idx, region, cropped, prompt))

            if not tasks:
                markdown_content = self._ocr_full_image(padded)
            else:
                # Parallel OCR calls, one per region
                raw_results: dict[int, str] = {}
                with ThreadPoolExecutor(
                    max_workers=min(self.max_workers, len(tasks))
                ) as ex:
                    future_map = {
                        ex.submit(self._call_vllm, cropped, prompt): idx
                        for idx, region, cropped, prompt in tasks
                    }
                    for future in as_completed(future_map):
                        idx = future_map[future]
                        try:
                            raw_results[idx] = future.result()
                        except Exception:
                            # Deliberate best-effort: one failed region must not
                            # sink the whole page; it contributes empty content.
                            raw_results[idx] = ""

                # Build structured region dicts for GLMResultFormatter
                region_dicts = []
                for idx, region, _cropped, _prompt in tasks:
                    region_dicts.append(
                        {
                            "index": idx,
                            "label": region.type,
                            "native_label": region.native_label,
                            "content": raw_results.get(idx, ""),
                            "bbox_2d": self._normalize_bbox(region.bbox, img_w, img_h),
                        }
                    )

                # 4. GLM-OCR postprocessing: clean, format, merge, bullets
                markdown_content = self._formatter.process(region_dicts)

        # 5. LaTeX math error correction (our existing pipeline)
        markdown_content = _postprocess_markdown(markdown_content)

        # 6. Format conversion
        latex, mathml, mml = "", "", ""
        if markdown_content and self.converter:
            fmt = self.converter.convert_to_formats(markdown_content)
            latex, mathml, mml = fmt.latex, fmt.mathml, fmt.mml

        return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
if __name__ == "__main__":
mineru_service = MineruOCRService()
image = cv2.imread("test/formula2.jpg")