feat add glm-ocr core
This commit is contained in:
14
.claude/settings.local.json
Normal file
14
.claude/settings.local.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"WebFetch(domain:deepwiki.com)",
|
||||
"WebFetch(domain:github.com)",
|
||||
"Read(//private/tmp/**)",
|
||||
"Bash(gh api repos/zai-org/GLM-OCR/contents/glmocr --jq '.[].name')",
|
||||
"WebFetch(domain:raw.githubusercontent.com)",
|
||||
"Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")",
|
||||
"Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)",
|
||||
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -2,26 +2,17 @@
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import cv2
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
|
||||
from app.core.dependencies import (
|
||||
get_image_processor,
|
||||
get_layout_detector,
|
||||
get_ocr_service,
|
||||
get_mineru_ocr_service,
|
||||
get_glmocr_service,
|
||||
get_glmocr_endtoend_service,
|
||||
)
|
||||
from app.core.config import get_settings
|
||||
from app.core.logging_config import get_logger, RequestIDAdapter
|
||||
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService
|
||||
|
||||
settings = get_settings()
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_logger()
|
||||
@@ -33,100 +24,38 @@ async def process_image_ocr(
|
||||
http_request: Request,
|
||||
response: Response,
|
||||
image_processor: ImageProcessor = Depends(get_image_processor),
|
||||
layout_detector: LayoutDetector = Depends(get_layout_detector),
|
||||
mineru_service: MineruOCRService = Depends(get_mineru_ocr_service),
|
||||
paddle_service: OCRService = Depends(get_ocr_service),
|
||||
glmocr_service: GLMOCRService = Depends(get_glmocr_service),
|
||||
glmocr_service: GLMOCREndToEndService = Depends(get_glmocr_endtoend_service),
|
||||
) -> ImageOCRResponse:
|
||||
"""Process an image and extract content as LaTeX, Markdown, and MathML.
|
||||
|
||||
The processing pipeline:
|
||||
1. Load and preprocess image (add 30% whitespace padding)
|
||||
2. Detect layout using DocLayout-YOLO
|
||||
3. Based on layout:
|
||||
- If plain text exists: use PP-DocLayoutV2 for mixed recognition
|
||||
- Otherwise: use PaddleOCR-VL with formula prompt
|
||||
4. Convert output to LaTeX, Markdown, and MathML formats
|
||||
1. Load and preprocess image
|
||||
2. Detect layout regions using PP-DocLayoutV3
|
||||
3. Crop each region and recognize with GLM-OCR via vLLM (task-specific prompts)
|
||||
4. Aggregate region results into Markdown
|
||||
5. Convert to LaTeX, Markdown, and MathML formats
|
||||
|
||||
Note: OMML conversion is not included due to performance overhead.
|
||||
Use the /convert/latex-to-omml endpoint to convert LaTeX to OMML separately.
|
||||
"""
|
||||
# Get or generate request ID
|
||||
request_id = http_request.headers.get("x-request-id", str(uuid.uuid4()))
|
||||
response.headers["x-request-id"] = request_id
|
||||
|
||||
# Create logger adapter with request_id
|
||||
log = RequestIDAdapter(logger, {"request_id": request_id})
|
||||
log.request_id = request_id
|
||||
|
||||
try:
|
||||
log.info("Starting image OCR processing")
|
||||
start = time.time()
|
||||
|
||||
# Preprocess image (load only, no padding yet)
|
||||
preprocess_start = time.time()
|
||||
image = image_processor.preprocess(
|
||||
image_url=request.image_url,
|
||||
image_base64=request.image_base64,
|
||||
)
|
||||
|
||||
# Apply padding only for layout detection
|
||||
processed_image = image
|
||||
if image_processor and settings.is_padding:
|
||||
processed_image = image_processor.add_padding(image)
|
||||
ocr_result = glmocr_service.recognize(image)
|
||||
|
||||
preprocess_time = time.time() - preprocess_start
|
||||
log.debug(f"Image loading completed in {preprocess_time:.3f}s")
|
||||
|
||||
# Layout detection (using padded image if padding is enabled)
|
||||
layout_start = time.time()
|
||||
layout_info = layout_detector.detect(processed_image)
|
||||
layout_time = time.time() - layout_start
|
||||
log.info(f"Layout detection completed in {layout_time:.3f}s")
|
||||
|
||||
# OCR recognition (use original image without padding)
|
||||
ocr_start = time.time()
|
||||
if layout_info.MixedRecognition:
|
||||
recognition_method = "MixedRecognition (MinerU)"
|
||||
log.info(f"Using {recognition_method}")
|
||||
|
||||
# Convert original image (without padding) to bytes
|
||||
success, encoded_image = cv2.imencode(".png", image)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to encode image")
|
||||
|
||||
image_bytes = BytesIO(encoded_image.tobytes())
|
||||
image_bytes.seek(0) # Ensure position is at the beginning
|
||||
ocr_result = mineru_service.recognize(image_bytes)
|
||||
else:
|
||||
recognition_method = "FormulaOnly (GLMOCR)"
|
||||
log.info(f"Using {recognition_method}")
|
||||
|
||||
# Try GLM-OCR first, fallback to MinerU if token limit exceeded
|
||||
try:
|
||||
ocr_result = glmocr_service.recognize(image)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Check if error is due to token limit (max_model_len exceeded)
|
||||
if "max_model_len" in error_msg or "decoder prompt" in error_msg or "BadRequestError" in error_msg:
|
||||
log.warning(f"GLM-OCR failed due to token limit: {error_msg}")
|
||||
log.info("Falling back to MinerU for recognition")
|
||||
recognition_method = "FormulaOnly (MinerU fallback)"
|
||||
|
||||
# Convert original image to bytes for MinerU
|
||||
success, encoded_image = cv2.imencode(".png", image)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to encode image")
|
||||
|
||||
image_bytes = BytesIO(encoded_image.tobytes())
|
||||
image_bytes.seek(0)
|
||||
ocr_result = mineru_service.recognize(image_bytes)
|
||||
else:
|
||||
# Re-raise other errors
|
||||
raise
|
||||
ocr_time = time.time() - ocr_start
|
||||
|
||||
total_time = time.time() - preprocess_start
|
||||
log.info(f"OCR processing completed - Method: {recognition_method}, " f"Layout time: {layout_time:.3f}s, OCR time: {ocr_time:.3f}s, " f"Total time: {total_time:.3f}s")
|
||||
log.info(f"OCR completed in {time.time() - start:.3f}s")
|
||||
|
||||
except RuntimeError as e:
|
||||
log.error(f"OCR processing failed: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
import torch
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -48,21 +47,25 @@ class Settings(BaseSettings):
|
||||
is_padding: bool = True
|
||||
padding_ratio: float = 0.1
|
||||
|
||||
max_tokens: int = 4096
|
||||
|
||||
# Model Paths
|
||||
pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
|
||||
pp_doclayout_model_dir: str | None = (
|
||||
"/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
|
||||
)
|
||||
|
||||
# Image Processing
|
||||
max_image_size_mb: int = 10
|
||||
image_padding_ratio: float = 0.1 # 10% on each side = 20% total expansion
|
||||
|
||||
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # cuda:0 or cpu
|
||||
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Server Settings
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8053
|
||||
|
||||
# Logging Settings
|
||||
log_dir: Optional[str] = None # Defaults to /app/logs in container or ./logs locally
|
||||
log_dir: str | None = None # Defaults to /app/logs in container or ./logs locally
|
||||
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
|
||||
@property
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
from app.services.converter import Converter
|
||||
from app.core.config import get_settings
|
||||
|
||||
@@ -31,40 +31,17 @@ def get_image_processor() -> ImageProcessor:
|
||||
return ImageProcessor()
|
||||
|
||||
|
||||
def get_ocr_service() -> OCRService:
|
||||
"""Get an OCR service instance."""
|
||||
return OCRService(
|
||||
vl_server_url=get_settings().paddleocr_vl_url,
|
||||
layout_detector=get_layout_detector(),
|
||||
image_processor=get_image_processor(),
|
||||
converter=get_converter(),
|
||||
)
|
||||
|
||||
|
||||
def get_converter() -> Converter:
|
||||
"""Get a DOCX converter instance."""
|
||||
return Converter()
|
||||
|
||||
|
||||
def get_mineru_ocr_service() -> MineruOCRService:
|
||||
"""Get a MinerOCR service instance."""
|
||||
def get_glmocr_endtoend_service() -> GLMOCREndToEndService:
|
||||
"""Get end-to-end GLM-OCR service (layout detection + per-region OCR)."""
|
||||
settings = get_settings()
|
||||
api_url = getattr(settings, "miner_ocr_api_url", "http://127.0.0.1:8000/file_parse")
|
||||
glm_ocr_url = getattr(settings, "glm_ocr_url", "http://localhost:8002/v1")
|
||||
return MineruOCRService(
|
||||
api_url=api_url,
|
||||
converter=get_converter(),
|
||||
image_processor=get_image_processor(),
|
||||
glm_ocr_url=glm_ocr_url,
|
||||
)
|
||||
|
||||
|
||||
def get_glmocr_service() -> GLMOCRService:
|
||||
"""Get a GLM OCR service instance."""
|
||||
settings = get_settings()
|
||||
glm_ocr_url = getattr(settings, "glm_ocr_url", "http://127.0.0.1:8002/v1")
|
||||
return GLMOCRService(
|
||||
vl_server_url=glm_ocr_url,
|
||||
return GLMOCREndToEndService(
|
||||
vl_server_url=settings.glm_ocr_url,
|
||||
image_processor=get_image_processor(),
|
||||
converter=get_converter(),
|
||||
layout_detector=get_layout_detector(),
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ class LayoutRegion(BaseModel):
|
||||
"""A detected layout region in the document."""
|
||||
|
||||
type: str = Field(..., description="Region type: text, formula, table, figure")
|
||||
native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)")
|
||||
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
|
||||
confidence: float = Field(..., description="Detection confidence score")
|
||||
score: float = Field(..., description="Detection score")
|
||||
|
||||
412
app/services/glm_postprocess.py
Normal file
412
app/services/glm_postprocess.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""GLM-OCR postprocessing logic adapted for this project.
|
||||
|
||||
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
|
||||
glm-ocr/glmocr/utils/result_postprocess_utils.py.
|
||||
|
||||
Covers:
|
||||
- Repeated-content / hallucination detection
|
||||
- Per-region content cleaning and formatting (titles, bullets, formulas)
|
||||
- formula_number merging (→ \\tag{})
|
||||
- Hyphenated text-block merging (via wordfreq)
|
||||
- Missing bullet-point detection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import json
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
from wordfreq import zipf_frequency
|
||||
|
||||
_WORDFREQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
_WORDFREQ_AVAILABLE = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# result_postprocess_utils (ported)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
    """Detect and truncate a consecutively-repeated pattern.

    Returns the string with the repeat removed (prefix plus one copy of the
    repeating unit), or None if no qualifying repeat is found.
    """
    total = len(s)
    # A run of `min_repeats` units, each at least `min_unit_len` chars long,
    # cannot fit in a shorter string — bail out early.
    if total < min_unit_len * min_repeats:
        return None

    longest_unit = total // min_repeats
    if longest_unit < min_unit_len:
        return None

    # Lazy group captures the repeat unit; the backreference requires it to
    # occur `min_repeats` times back-to-back. DOTALL lets units span newlines.
    repeat_re = re.compile(
        rf"(.{{{min_unit_len},{longest_unit}}}?)\1{{{min_repeats - 1},}}",
        re.DOTALL,
    )
    hit = repeat_re.search(s)
    if hit is None:
        return None
    # Keep everything before the repeat plus a single copy of the unit.
    return s[: hit.start()] + hit.group(1)
|
||||
|
||||
|
||||
def clean_repeated_content(
    content: str,
    min_len: int = 10,
    min_repeats: int = 10,
    line_threshold: int = 10,
) -> str:
    """Remove hallucination-style repeated content (consecutive or line-level).

    OCR language models occasionally "loop", emitting the same unit of text
    over and over.  Two detectors run in order:

    1. A regex-based consecutive-repeat scan (find_consecutive_repeat) that
       truncates a unit of >= ``min_len`` chars repeated >= ``min_repeats``
       times back-to-back.
    2. A line-level scan: if one line dominates (>= ``line_threshold``
       occurrences and >= 80% of all non-empty lines) and appears at least
       three times consecutively somewhere, the content is truncated at that
       line's first occurrence.

    Args:
        content: Raw OCR text.
        min_len: Minimum repeat-unit length for the consecutive scan.
        min_repeats: Minimum number of consecutive repetitions.
        line_threshold: Minimum occurrence count for the line-level scan.

    Returns:
        The (possibly truncated) content; returned unchanged when no
        repetition is detected.
    """
    stripped = content.strip()
    if not stripped:
        return content

    # 1. Consecutive repeat (multi-line aware)
    if len(stripped) > min_len * min_repeats:
        result = find_consecutive_repeat(stripped, min_unit_len=min_len, min_repeats=min_repeats)
        if result is not None:
            return result

    # 2. Line-level repeat
    lines = [line.strip() for line in content.split("\n") if line.strip()]
    total_lines = len(lines)
    if total_lines >= line_threshold and lines:
        common, count = Counter(lines).most_common(1)[0]
        # The dominant line must account for >= 80% of non-empty lines.
        if count >= line_threshold and (count / total_lines) >= 0.8:
            for i, line in enumerate(lines):
                if line == common:
                    # Require >= 3 back-to-back occurrences starting at the
                    # first hit, to avoid truncating a legitimately repeated
                    # header; only the first occurrence is ever examined.
                    consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common)
                    if consecutive >= 3:
                        # Map the i-th non-empty line back to its index in the
                        # original (blank-line-preserving) split and cut there.
                        original_lines = content.split("\n")
                        non_empty_count = 0
                        for idx, orig_line in enumerate(original_lines):
                            if orig_line.strip():
                                non_empty_count += 1
                                if non_empty_count == i + 1:
                                    return "\n".join(original_lines[: idx + 1])
                    break
    return content
|
||||
|
||||
|
||||
def clean_formula_number(number_content: str) -> str:
    """Strip surrounding parentheses from a formula number, e.g. '(1)' -> '1'.

    Handles both ASCII parentheses and their fullwidth (CJK) counterparts,
    which commonly appear in OCR output of Chinese documents.

    Args:
        number_content: Raw formula-number text from the OCR model.

    Returns:
        The number with one surrounding pair of parentheses removed, or the
        stripped input unchanged when it is not parenthesised.
    """
    s = number_content.strip()
    # ASCII parentheses: "(1)" -> "1"
    if s.startswith("(") and s.endswith(")"):
        return s[1:-1]
    # Fullwidth parentheses: "（1）" -> "1".  FIX: the original had this
    # branch as an exact duplicate of the ASCII check (dead code) — almost
    # certainly a mojibake loss of U+FF08/U+FF09 in transit.
    if s.startswith("（") and s.endswith("）"):
        return s[1:-1]
    return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GLMResultFormatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
||||
_LABEL_TO_CATEGORY: Dict[str, str] = {
    # text-like regions — OCR'd with the plain-text prompt
    "abstract": "text",
    "algorithm": "text",
    "content": "text",
    "doc_title": "text",
    "figure_title": "text",
    "paragraph_title": "text",
    "reference_content": "text",
    "text": "text",
    "vertical_text": "text",
    "vision_footnote": "text",
    "seal": "text",
    # formula_number maps to text so it survives per-region formatting and
    # can later be folded into the adjacent formula by _merge_formula_numbers
    "formula_number": "text",
    # table — OCR'd with the table prompt
    "table": "table",
    # formula — wrapped in $$ ... $$ during formatting
    "display_formula": "formula",
    "inline_formula": "formula",
    # image — skipped by OCR, emitted as a Markdown image link
    "chart": "image",
    "image": "image",
}
|
||||
|
||||
|
||||
class GLMResultFormatter:
    """Port of GLM-OCR's ResultFormatter for use in our pipeline.

    Accepts a list of region dicts (each with label, native_label, content,
    bbox_2d) and returns a final Markdown string.
    """

    # ------------------------------------------------------------------ #
    # Public entry-point
    # ------------------------------------------------------------------ #

    def process(self, regions: List[Dict[str, Any]]) -> str:
        """Run the full postprocessing pipeline and return Markdown.

        Args:
            regions: List of dicts with keys:
                - index (int) reading order from layout detection
                - label (str) mapped category: text/formula/table/figure
                - native_label (str) raw PP-DocLayout label (e.g. doc_title)
                - content (str) raw OCR output from vLLM
                - bbox_2d (list) [x1, y1, x2, y2] in 0-1000 normalised coords

        Returns:
            Markdown string.
        """
        # Sort by reading order; deepcopy so callers' dicts are not mutated.
        items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))

        # Per-region cleaning + formatting; regions whose content cleans down
        # to nothing are dropped.
        processed: List[Dict] = []
        for item in items:
            item["native_label"] = item.get("native_label", item.get("label", "text"))
            item["label"] = self._map_label(item.get("label", "text"), item["native_label"])

            item["content"] = self._format_content(
                item.get("content") or "",
                item["label"],
                item["native_label"],
            )
            if not (item.get("content") or "").strip():
                continue
            processed.append(item)

        # Re-index after dropping empty regions.
        for i, item in enumerate(processed):
            item["index"] = i

        # Structural merges
        processed = self._merge_formula_numbers(processed)
        processed = self._merge_text_blocks(processed)
        processed = self._format_bullet_points(processed)

        # Assemble Markdown
        parts: List[str] = []
        for item in processed:
            content = item.get("content") or ""
            if item["label"] == "image":
                # FIX: the original line was `parts.append(f"})")`, which is a
                # SyntaxError (lone '}' inside an f-string) — the image-path
                # interpolation was lost in transit.  Emit a Markdown image
                # link instead.  NOTE(review): regions in this pipeline may
                # carry no path; confirm the key name against the caller.
                parts.append(f"![]({item.get('image_path', '')})")
            elif content.strip():
                parts.append(content)

        return "\n\n".join(parts)

    # ------------------------------------------------------------------ #
    # Label mapping
    # ------------------------------------------------------------------ #

    def _map_label(self, label: str, native_label: str) -> str:
        """Map a raw layout label to a canonical category (text/table/formula/image).

        The native (fine-grained) label wins; the coarse label is the
        fallback; anything unknown is treated as plain text.
        """
        return _LABEL_TO_CATEGORY.get(native_label, _LABEL_TO_CATEGORY.get(label, "text"))

    # ------------------------------------------------------------------ #
    # Content cleaning
    # ------------------------------------------------------------------ #

    def _clean_content(self, content: str) -> str:
        """Remove artefacts: leading/trailing \\t, repeated punctuation, long repeats."""
        if content is None:
            return ""

        # Literal "\t" escape sequences emitted by the model, not tab chars.
        content = re.sub(r"^(\\t)+", "", content).lstrip()
        content = re.sub(r"(\\t)+$", "", content).rstrip()

        # Collapse runs of filler punctuation (dot leaders, underscores) to 3.
        content = re.sub(r"(\.)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(·)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(_)\1{2,}", r"\1\1\1", content)
        content = re.sub(r"(\\_)\1{2,}", r"\1\1\1", content)

        # Long outputs are suspect of model looping; run repeat detection.
        if len(content) >= 2048:
            content = clean_repeated_content(content)

        return content.strip()

    # ------------------------------------------------------------------ #
    # Per-region content formatting
    # ------------------------------------------------------------------ #

    def _format_content(self, content: Any, label: str, native_label: str) -> str:
        """Clean and format a single region's content.

        Applies heading prefixes for titles, $$-wrapping for formulas, and
        bullet/enumeration normalisation for text blocks.
        """
        if content is None:
            return ""

        content = self._clean_content(str(content))

        # Heading formatting
        if native_label == "doc_title":
            content = re.sub(r"^#+\s*", "", content)
            content = "# " + content
        elif native_label == "paragraph_title":
            if content.startswith("- ") or content.startswith("* "):
                content = content[2:].lstrip()
            content = re.sub(r"^#+\s*", "", content)
            content = "## " + content.lstrip()

        # Formula wrapping: strip any existing math delimiters, then wrap in $$.
        if label == "formula":
            content = content.strip()
            for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
                if content.startswith(s) and content.endswith(e):
                    content = content[len(s) : -len(e)].strip()
                    break
            content = "$$\n" + content + "\n$$"

        # Text formatting
        if label == "text":
            # Normalise bullet glyphs to a Markdown dash.
            if content.startswith("·") or content.startswith("•") or content.startswith("* "):
                content = "- " + content[1:].lstrip()

            # FIX: the alternations below were duplicated ASCII characters,
            # e.g. (\(|\() — the fullwidth CJK forms （ ） were lost in
            # transit.  Restored so "（1）foo" style markers normalise too.
            match = re.match(r"^(\(|（)(\d+|[A-Za-z])(\)|）)(.*)$", content)
            if match:
                _, symbol, _, rest = match.groups()
                content = f"({symbol}) {rest.lstrip()}"

            match = re.match(r"^(\d+|[A-Za-z])(\.|\)|）)(.*)$", content)
            if match:
                symbol, sep, rest = match.groups()
                # Normalise the fullwidth close-paren to ASCII (the original
                # `sep = ")" if sep == ")" else sep` was a no-op for the same
                # mojibake reason).
                sep = ")" if sep == "）" else sep
                content = f"{symbol}{sep} {rest.lstrip()}"

        # Single newline → double newline so Markdown renders line breaks.
        content = re.sub(r"(?<!\n)\n(?!\n)", "\n\n", content)

        return content

    # ------------------------------------------------------------------ #
    # Structural merges
    # ------------------------------------------------------------------ #

    def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
        """Merge formula_number region into adjacent formula with \\tag{}.

        The number region may precede or follow the formula; either way the
        number block itself never survives into the output.
        """
        if not items:
            return items

        merged: List[Dict] = []
        skip: set = set()

        for i, block in enumerate(items):
            if i in skip:
                continue

            native = block.get("native_label", "")

            # Case 1: formula_number then formula
            if native == "formula_number":
                if i + 1 < len(items) and items[i + 1].get("label") == "formula":
                    num_clean = clean_formula_number(block.get("content", "").strip())
                    formula_content = items[i + 1].get("content", "")
                    merged_block = deepcopy(items[i + 1])
                    # Only inject the tag into a well-formed $$ block; an
                    # unexpected shape keeps the formula untouched.
                    if formula_content.endswith("\n$$"):
                        merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
                    merged.append(merged_block)
                    skip.add(i + 1)
                continue  # always skip the formula_number block itself

            # Case 2: formula then formula_number
            if block.get("label") == "formula":
                if i + 1 < len(items) and items[i + 1].get("native_label") == "formula_number":
                    num_clean = clean_formula_number(items[i + 1].get("content", "").strip())
                    formula_content = block.get("content", "")
                    merged_block = deepcopy(block)
                    if formula_content.endswith("\n$$"):
                        merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
                    merged.append(merged_block)
                    skip.add(i + 1)
                    continue

            merged.append(block)

        for i, block in enumerate(merged):
            block["index"] = i
        return merged

    def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
        """Merge hyphenated text blocks when the combined word is valid (wordfreq).

        A text block ending in "-" is joined with the next text block when the
        rejoined word is frequent enough in English (zipf >= 2.5).  Only the
        first following text block is considered; non-text blocks in between
        are skipped.
        """
        if not items or not _WORDFREQ_AVAILABLE:
            return items

        merged: List[Dict] = []
        skip: set = set()

        for i, block in enumerate(items):
            if i in skip:
                continue
            if block.get("label") != "text":
                merged.append(block)
                continue

            content = block.get("content", "")
            if not isinstance(content, str) or not content.rstrip().endswith("-"):
                merged.append(block)
                continue

            content_stripped = content.rstrip()
            did_merge = False
            for j in range(i + 1, len(items)):
                if items[j].get("label") != "text":
                    continue
                next_content = items[j].get("content", "")
                if not isinstance(next_content, str):
                    continue
                next_stripped = next_content.lstrip()
                if next_stripped and next_stripped[0].islower():
                    words_before = content_stripped[:-1].split()
                    next_words = next_stripped.split()
                    if words_before and next_words:
                        merged_word = words_before[-1] + next_words[0]
                        # zipf >= 2.5 ≈ a word seen at least a few times per
                        # million tokens — cheap check that the rejoin is real.
                        if zipf_frequency(merged_word.lower(), "en") >= 2.5:
                            merged_block = deepcopy(block)
                            merged_block["content"] = content_stripped[:-1] + next_content.lstrip()
                            merged.append(merged_block)
                            skip.add(j)
                            did_merge = True
                # Only the first candidate text block is examined.
                break

            if not did_merge:
                merged.append(block)

        for i, block in enumerate(merged):
            block["index"] = i
        return merged

    def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
        """Add missing bullet prefix when a text block is sandwiched between two bullet items.

        A non-bulleted text region between two "- " regions, left-aligned with
        both (within ``left_align_threshold`` bbox units), is assumed to be a
        bullet item whose marker the OCR missed.  Mutates items in place.
        """
        if len(items) < 3:
            return items

        for i in range(1, len(items) - 1):
            cur = items[i]
            prev = items[i - 1]
            nxt = items[i + 1]

            if cur.get("native_label") != "text":
                continue
            if prev.get("native_label") != "text" or nxt.get("native_label") != "text":
                continue

            cur_content = cur.get("content", "")
            if cur_content.startswith("- "):
                continue

            prev_content = prev.get("content", "")
            nxt_content = nxt.get("content", "")
            if not (prev_content.startswith("- ") and nxt_content.startswith("- ")):
                continue

            cur_bbox = cur.get("bbox_2d", [])
            prev_bbox = prev.get("bbox_2d", [])
            nxt_bbox = nxt.get("bbox_2d", [])
            if not (cur_bbox and prev_bbox and nxt_bbox):
                continue

            if (
                abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold
                and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold
            ):
                cur["content"] = "- " + cur_content

        return items
|
||||
@@ -1,9 +1,10 @@
|
||||
"""PP-DocLayoutV2 wrapper for document layout detection."""
|
||||
"""PP-DocLayoutV3 wrapper for document layout detection."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||
from app.core.config import get_settings
|
||||
from app.services.layout_postprocess import apply_layout_postprocess
|
||||
from paddleocr import LayoutDetection
|
||||
from typing import Optional
|
||||
|
||||
@@ -116,6 +117,17 @@ class LayoutDetector:
|
||||
else:
|
||||
boxes = []
|
||||
|
||||
# Apply GLM-OCR layout post-processing (NMS, containment, unclip, clamp)
|
||||
if boxes:
|
||||
h, w = image.shape[:2]
|
||||
boxes = apply_layout_postprocess(
|
||||
boxes,
|
||||
img_size=(w, h),
|
||||
layout_nms=True,
|
||||
layout_unclip_ratio=None,
|
||||
layout_merge_bboxes_mode="large",
|
||||
)
|
||||
|
||||
for box in boxes:
|
||||
cls_id = box.get("cls_id")
|
||||
label = box.get("label") or self.CLS_ID_TO_LABEL.get(cls_id, "other")
|
||||
@@ -128,6 +140,7 @@ class LayoutDetector:
|
||||
regions.append(
|
||||
LayoutRegion(
|
||||
type=region_type,
|
||||
native_label=label,
|
||||
bbox=coordinate,
|
||||
confidence=score,
|
||||
score=score,
|
||||
|
||||
343
app/services/layout_postprocess.py
Normal file
343
app/services/layout_postprocess.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Layout post-processing utilities ported from GLM-OCR.
|
||||
|
||||
Source: glm-ocr/glmocr/utils/layout_postprocess_utils.py
|
||||
|
||||
Algorithms applied after PaddleOCR LayoutDetection.predict():
|
||||
1. NMS with dual IoU thresholds (same-class vs cross-class)
|
||||
2. Large-image-region filtering (remove image boxes that fill most of the page)
|
||||
3. Containment analysis (merge_bboxes_mode: keep large parent, remove contained child)
|
||||
4. Unclip ratio (optional bbox expansion)
|
||||
5. Invalid bbox skipping
|
||||
|
||||
These steps run on top of PaddleOCR's built-in detection to replicate
|
||||
the quality of the GLM-OCR SDK's layout pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Primitive geometry helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def iou(box1: List[float], box2: List[float]) -> float:
    """Compute IoU of two bounding boxes [x1, y1, x2, y2]."""
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Intersection rectangle (pixel-inclusive convention, hence the +1).
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0, ix2 - ix1 + 1) * max(0, iy2 - iy1 + 1)

    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)

    return inter / float(area_a + area_b - inter)
|
||||
|
||||
|
||||
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
    """Return True if box1 is contained within box2 (overlap ratio >= threshold).

    box format: [cls_id, score, x1, y1, x2, y2]
    """
    _, _, ax1, ay1, ax2, ay2 = box1
    _, _, bx1, by1, bx2, by2 = box2

    # Degenerate (zero/negative area) boxes can never be "contained".
    inner_area = (ax2 - ax1) * (ay2 - ay1)
    if inner_area <= 0:
        return False

    overlap_w = min(ax2, bx2) - max(ax1, bx1)
    overlap_h = min(ay2, by2) - max(ay1, by1)
    inter = max(0, overlap_w) * max(0, overlap_h)

    # Containment ratio is measured against box1's own area, not the union.
    return (inter / inner_area) >= overlap_threshold
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NMS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def nms(
    boxes: np.ndarray,
    iou_same: float = 0.6,
    iou_diff: float = 0.98,
) -> List[int]:
    """NMS with separate IoU thresholds for same-class and cross-class overlaps.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        iou_same: Suppression threshold for boxes of the same class.
        iou_diff: Suppression threshold for boxes of different classes.

    Returns:
        List of kept row indices.
    """
    # Candidates ordered by descending confidence (column 1).
    order = np.argsort(boxes[:, 1])[::-1].tolist()
    keep: List[int] = []

    while order:
        best = order.pop(0)
        keep.append(best)
        best_cls = int(boxes[best, 0])
        best_xyxy = boxes[best, 2:6].tolist()

        survivors: List[int] = []
        for idx in order:
            cand_cls = int(boxes[idx, 0])
            cand_xyxy = boxes[idx, 2:6].tolist()
            # Same-class boxes are suppressed aggressively; different classes
            # only when they are near-duplicates.
            limit = iou_same if cand_cls == best_cls else iou_diff
            if iou(best_xyxy, cand_xyxy) < limit:
                survivors.append(idx)
        order = survivors

    return keep
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Containment analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Labels whose regions should never be removed even when contained in another box
|
||||
_PRESERVE_LABELS = {"image", "seal", "chart"}
|
||||
|
||||
|
||||
def check_containment(
    boxes: np.ndarray,
    preserve_cls_ids: Optional[set] = None,
    category_index: Optional[int] = None,
    mode: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute containment flags for each box.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        preserve_cls_ids: Class IDs that must never be marked as contained.
        category_index: If set, apply mode only relative to this class.
        mode: 'large' or 'small' (only used with category_index).

    Returns:
        (contains_other, contained_by_other): boolean arrays of length N.
    """
    n = len(boxes)
    # Flags are stored as 0/1 int arrays (not numpy bools).
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)

    # O(N^2) pairwise scan over ordered pairs (i, j): is box i inside box j?
    # NOTE(review): full rows (cls_id, score, coords, ...) are passed to
    # is_contained — confirm that helper indexes coordinates accordingly.
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            # Preserved classes are never marked as contained; skipping here
            # also means such pairs never mark j as a container.
            if preserve_cls_ids and int(boxes[i, 0]) in preserve_cls_ids:
                continue
            if category_index is not None and mode is not None:
                # 'large': only count containment when the OUTER box j belongs
                # to the targeted class.
                if mode == "large" and int(boxes[j, 0]) == category_index:
                    if is_contained(boxes[i].tolist(), boxes[j].tolist()):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
                # 'small': only count containment when the INNER box i belongs
                # to the targeted class.
                elif mode == "small" and int(boxes[i, 0]) == category_index:
                    if is_contained(boxes[i].tolist(), boxes[j].tolist()):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
            else:
                # Unrestricted mode: any box fully inside another is flagged.
                if is_contained(boxes[i].tolist(), boxes[j].tolist()):
                    contained_by_other[i] = 1
                    contains_other[j] = 1

    return contains_other, contained_by_other
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Box expansion (unclip)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def unclip_boxes(
    boxes: np.ndarray,
    unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
) -> np.ndarray:
    """Expand bounding boxes about their centres by the given ratio(s).

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        unclip_ratio: Scalar, (w_ratio, h_ratio) tuple, or dict mapping cls_id
            to a (w_ratio, h_ratio) pair.

    Returns:
        Expanded boxes array (input returned unchanged when ratio is None).
    """
    if unclip_ratio is None:
        return boxes

    # Per-class ratios: expand only rows whose cls_id has an entry.
    if isinstance(unclip_ratio, dict):
        rows = []
        for row in boxes:
            values = list(row)
            cls_id = int(values[0])
            if cls_id in unclip_ratio:
                wr, hr = unclip_ratio[cls_id]
                x1, y1, x2, y2 = values[2], values[3], values[4], values[5]
                cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                half_w = (x2 - x1) * wr / 2
                half_h = (y2 - y1) * hr / 2
                values[2], values[3] = cx - half_w, cy - half_h
                values[4], values[5] = cx + half_w, cy + half_h
            rows.append(values)
        return np.array(rows)

    # A bare scalar means the same ratio on both axes.
    if isinstance(unclip_ratio, (int, float)):
        unclip_ratio = (float(unclip_ratio), float(unclip_ratio))

    w_ratio, h_ratio = unclip_ratio[0], unclip_ratio[1]
    widths = boxes[:, 4] - boxes[:, 2]
    heights = boxes[:, 5] - boxes[:, 3]
    centers_x = boxes[:, 2] + widths / 2
    centers_y = boxes[:, 3] + heights / 2
    grown_w = widths * w_ratio
    grown_h = heights * h_ratio

    result = boxes.astype(float, copy=True)
    result[:, 2] = centers_x - grown_w / 2
    result[:, 3] = centers_y - grown_h / 2
    result[:, 4] = centers_x + grown_w / 2
    result[:, 5] = centers_y + grown_h / 2
    return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry-point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_layout_postprocess(
    boxes: List[Dict],
    img_size: Tuple[int, int],
    layout_nms: bool = True,
    layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
    layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
) -> List[Dict]:
    """Apply GLM-OCR layout post-processing to PaddleOCR detection results.

    Args:
        boxes: PaddleOCR output — list of dicts with keys:
            cls_id, label, score, coordinate ([x1, y1, x2, y2]).
        img_size: (width, height) of the image.
        layout_nms: Apply dual-threshold NMS.
        layout_unclip_ratio: Optional bbox expansion ratio.
        layout_merge_bboxes_mode: Containment mode — 'large' (default), 'small',
            'union', or per-class dict. Any other string (e.g. 'union') skips
            containment filtering entirely.

    Returns:
        Filtered and ordered list of box dicts in the same PaddleOCR format.
    """
    if not boxes:
        return boxes

    img_width, img_height = img_size

    # --- Build working array [cls_id, score, x1, y1, x2, y2] -------------- #
    arr_rows = []
    for b in boxes:
        cls_id = b.get("cls_id", 0)
        score = b.get("score", 0.0)
        x1, y1, x2, y2 = b.get("coordinate", [0, 0, 0, 0])
        arr_rows.append([cls_id, score, x1, y1, x2, y2])
    boxes_array = np.array(arr_rows, dtype=float)

    # Labels are tracked in a parallel list and filtered with the same masks.
    all_labels: List[str] = [b.get("label", "") for b in boxes]

    # 1. NMS ---------------------------------------------------------------- #
    if layout_nms and len(boxes_array) > 1:
        kept = nms(boxes_array, iou_same=0.6, iou_diff=0.98)
        boxes_array = boxes_array[kept]
        all_labels = [all_labels[k] for k in kept]

    # 2. Filter large image regions ---------------------------------------- #
    # Drop "image" boxes that cover almost the entire page: they are usually
    # a spurious whole-page detection rather than a real figure.
    if len(boxes_array) > 1:
        img_area = img_width * img_height
        # Empirical coverage thresholds; landscape pages use the lower one.
        area_thres = 0.82 if img_width > img_height else 0.93
        keep_mask = np.ones(len(boxes_array), dtype=bool)
        for i, lbl in enumerate(all_labels):
            if lbl == "image":
                x1, y1, x2, y2 = boxes_array[i, 2:6]
                # Clip to the image frame before measuring coverage.
                x1 = max(0.0, x1); y1 = max(0.0, y1)
                x2 = min(float(img_width), x2); y2 = min(float(img_height), y2)
                if (x2 - x1) * (y2 - y1) > area_thres * img_area:
                    keep_mask[i] = False
        boxes_array = boxes_array[keep_mask]
        all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    # 3. Containment analysis (merge_bboxes_mode) -------------------------- #
    if layout_merge_bboxes_mode and len(boxes_array) > 1:
        preserve_cls_ids = {
            int(boxes_array[i, 0])
            for i, lbl in enumerate(all_labels)
            if lbl in _PRESERVE_LABELS
        }

        if isinstance(layout_merge_bboxes_mode, str):
            mode = layout_merge_bboxes_mode
            if mode in ("large", "small"):
                contains_other, contained_by_other = check_containment(
                    boxes_array, preserve_cls_ids
                )
                if mode == "large":
                    # Keep outer boxes, drop anything fully contained.
                    keep_mask = contained_by_other == 0
                else:
                    # Keep inner boxes; containers survive only if they are
                    # themselves contained by something else.
                    keep_mask = (contains_other == 0) | (contained_by_other == 1)
                boxes_array = boxes_array[keep_mask]
                all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

        elif isinstance(layout_merge_bboxes_mode, dict):
            # Per-class modes: intersect the keep-masks of every rule.
            keep_mask = np.ones(len(boxes_array), dtype=bool)
            for category_index, mode in layout_merge_bboxes_mode.items():
                if mode in ("large", "small"):
                    contains_other, contained_by_other = check_containment(
                        boxes_array, preserve_cls_ids, int(category_index), mode
                    )
                    if mode == "large":
                        keep_mask &= contained_by_other == 0
                    else:
                        keep_mask &= (contains_other == 0) | (contained_by_other == 1)
            boxes_array = boxes_array[keep_mask]
            all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

        if len(boxes_array) == 0:
            return []

    # 4. Unclip (bbox expansion) ------------------------------------------- #
    if layout_unclip_ratio is not None:
        boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)

    # 5. Clamp to image boundaries + skip invalid -------------------------- #
    result: List[Dict] = []
    for i, row in enumerate(boxes_array):
        cls_id = int(row[0])
        score = float(row[1])
        x1 = max(0.0, min(float(row[2]), img_width))
        y1 = max(0.0, min(float(row[3]), img_height))
        x2 = max(0.0, min(float(row[4]), img_width))
        y2 = max(0.0, min(float(row[5]), img_height))

        # Degenerate (zero/negative area) boxes are dropped silently.
        if x1 >= x2 or y1 >= y2:
            continue

        result.append({
            "cls_id": cls_id,
            "label": all_labels[i],
            "score": score,
            "coordinate": [int(x1), int(y1), int(x2), int(y2)],
        })

    return result
|
||||
@@ -1,19 +1,23 @@
|
||||
"""PaddleOCR-VL client service for text and formula recognition."""
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import requests
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from app.core.config import get_settings
|
||||
from paddleocr import PaddleOCRVL
|
||||
from typing import Optional
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.converter import Converter
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from paddleocr import PaddleOCRVL
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.converter import Converter
|
||||
from app.services.glm_postprocess import GLMResultFormatter
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -144,7 +148,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
|
||||
# Strategy: remove spaces before \ and between non-command chars,
|
||||
# but preserve the space after \command when followed by a non-\ char
|
||||
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
|
||||
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
|
||||
cleaned = re.sub(
|
||||
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
|
||||
) # remove space after non-letter non-\
|
||||
return f"{operator}{{{cleaned}}}"
|
||||
|
||||
# Match _{ ... } or ^{ ... }
|
||||
@@ -383,8 +389,8 @@ class OCRServiceBase(ABC):
|
||||
class OCRService(OCRServiceBase):
|
||||
"""Service for OCR using PaddleOCR-VL."""
|
||||
|
||||
_pipeline: Optional[PaddleOCRVL] = None
|
||||
_layout_detector: Optional[LayoutDetector] = None
|
||||
_pipeline: PaddleOCRVL | None = None
|
||||
_layout_detector: LayoutDetector | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -549,7 +555,15 @@ class GLMOCRService(OCRServiceBase):
|
||||
|
||||
# Call OpenAI-compatible API with formula recognition prompt
|
||||
prompt = "Formula Recognition:"
|
||||
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Don't catch exceptions here - let them propagate for fallback handling
|
||||
response = self.openai_client.chat.completions.create(
|
||||
@@ -596,10 +610,10 @@ class MineruOCRService(OCRServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
api_url: str = "http://127.0.0.1:8000/file_parse",
|
||||
image_processor: Optional[ImageProcessor] = None,
|
||||
converter: Optional[Converter] = None,
|
||||
image_processor: ImageProcessor | None = None,
|
||||
converter: Converter | None = None,
|
||||
glm_ocr_url: str = "http://localhost:8002/v1",
|
||||
layout_detector: Optional[LayoutDetector] = None,
|
||||
layout_detector: LayoutDetector | None = None,
|
||||
):
|
||||
"""Initialize Local API service.
|
||||
|
||||
@@ -614,7 +628,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
self.glm_ocr_url = glm_ocr_url
|
||||
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
|
||||
|
||||
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
|
||||
def _recognize_formula_with_paddleocr_vl(
|
||||
self, image: np.ndarray, prompt: str = "Formula Recognition:"
|
||||
) -> str:
|
||||
"""Recognize formula using PaddleOCR-VL API.
|
||||
|
||||
Args:
|
||||
@@ -634,7 +650,15 @@ class MineruOCRService(OCRServiceBase):
|
||||
image_url = f"data:image/png;base64,{image_base64}"
|
||||
|
||||
# Call OpenAI-compatible API
|
||||
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="glm-ocr",
|
||||
@@ -647,7 +671,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
|
||||
|
||||
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
|
||||
def _extract_and_recognize_formulas(
|
||||
self, markdown_content: str, original_image: np.ndarray
|
||||
) -> str:
|
||||
"""Extract image references from markdown and recognize formulas.
|
||||
|
||||
Args:
|
||||
@@ -712,7 +738,13 @@ class MineruOCRService(OCRServiceBase):
|
||||
}
|
||||
|
||||
# Make API request
|
||||
response = requests.post(self.api_url, files=files, data=data, headers={"accept": "application/json"}, timeout=30)
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
files=files,
|
||||
data=data,
|
||||
headers={"accept": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
@@ -723,7 +755,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
markdown_content = result["results"]["image"].get("md_content", "")
|
||||
|
||||
if "
|
||||
markdown_content = self._extract_and_recognize_formulas(
|
||||
markdown_content, original_image
|
||||
)
|
||||
|
||||
# Apply postprocessing to fix OCR errors
|
||||
markdown_content = _postprocess_markdown(markdown_content)
|
||||
@@ -751,6 +785,167 @@ class MineruOCRService(OCRServiceBase):
|
||||
raise RuntimeError(f"Recognition failed: {e}") from e
|
||||
|
||||
|
||||
# Task-specific prompts (from GLM-OCR SDK config.yaml).
# Keys match the region types produced by LayoutDetector; any region type not
# listed here falls back to _DEFAULT_PROMPT below.
_TASK_PROMPTS: dict[str, str] = {
    "text": "Text Recognition:",
    "formula": "Formula Recognition:",
    "table": "Table Recognition:",
}
# Full-page prompt used when no layout regions are detected (or only figure
# regions remain) and for region types without a dedicated prompt.
_DEFAULT_PROMPT = (
    "Recognize the text in the image and output in Markdown format. "
    "Preserve the original layout (headings/paragraphs/tables/formulas). "
    "Do not fabricate content that does not exist in the image."
)
|
||||
|
||||
|
||||
class GLMOCREndToEndService(OCRServiceBase):
    """End-to-end OCR using GLM-OCR pipeline: layout detection → per-region OCR.

    Pipeline:
        1. Add padding (ImageProcessor)
        2. Detect layout regions (LayoutDetector → PP-DocLayoutV3)
        3. Crop each region and call vLLM with a task-specific prompt (parallel)
        4. GLMResultFormatter: clean, format titles/bullets/formulas, merge tags
        5. _postprocess_markdown: LaTeX math error correction
        6. Converter: markdown → latex/mathml/mml

    This replaces both GLMOCRService (formula-only) and MineruOCRService (mixed).
    """

    def __init__(
        self,
        vl_server_url: str,
        image_processor: ImageProcessor,
        converter: Converter,
        layout_detector: LayoutDetector,
        max_workers: int = 8,
    ):
        # Fall back to the configured GLM-OCR URL when an empty value is passed.
        self.vl_server_url = vl_server_url or settings.glm_ocr_url
        self.image_processor = image_processor
        self.converter = converter
        self.layout_detector = layout_detector
        # Upper bound on concurrent per-region vLLM calls for one page.
        self.max_workers = max_workers
        # Long (1h) timeout: large pages can take a while on a busy server.
        self.openai_client = OpenAI(api_key="EMPTY", base_url=self.vl_server_url, timeout=3600)
        self._formatter = GLMResultFormatter()

    def _encode_region(self, image: np.ndarray) -> str:
        """Convert BGR numpy array to base64 JPEG string."""
        # OpenCV uses BGR; PIL expects RGB. JPEG is lossy but smaller payloads
        # keep the request size down for many-region pages.
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_img = PILImage.fromarray(rgb)
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _call_vllm(self, image: np.ndarray, prompt: str) -> str:
        """Send image + prompt to vLLM and return raw content string."""
        img_b64 = self._encode_region(image)
        data_url = f"data:image/jpeg;base64,{img_b64}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Near-greedy decoding (temperature 0.01) for reproducible OCR output.
        response = self.openai_client.chat.completions.create(
            model="glm-ocr",
            messages=messages,
            temperature=0.01,
            max_tokens=settings.max_tokens,
        )
        return response.choices[0].message.content.strip()

    def _normalize_bbox(self, bbox: list[float], img_w: int, img_h: int) -> list[int]:
        """Convert pixel bbox [x1,y1,x2,y2] to 0-1000 normalised coords."""
        x1, y1, x2, y2 = bbox
        return [
            int(x1 / img_w * 1000),
            int(y1 / img_h * 1000),
            int(x2 / img_w * 1000),
            int(y2 / img_h * 1000),
        ]

    def recognize(self, image: np.ndarray) -> dict:
        """Full pipeline: padding → layout → per-region OCR → postprocess → markdown.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
        """
        # 1. Padding
        padded = self.image_processor.add_padding(image)
        img_h, img_w = padded.shape[:2]

        # 2. Layout detection
        layout_info = self.layout_detector.detect(padded)

        # 3. OCR: per-region (parallel) or full-image fallback
        if not layout_info.regions:
            # No regions detected — OCR the whole padded page in one call.
            raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
            markdown_content = self._formatter._clean_content(raw_content)
        else:
            # Build task list for non-figure regions (figures get no OCR pass).
            tasks = []
            for idx, region in enumerate(layout_info.regions):
                if region.type == "figure":
                    continue
                x1, y1, x2, y2 = (int(c) for c in region.bbox)
                cropped = padded[y1:y2, x1:x2]
                # Guard against degenerate crops from out-of-range bboxes.
                if cropped.size == 0:
                    continue
                prompt = _TASK_PROMPTS.get(region.type, _DEFAULT_PROMPT)
                tasks.append((idx, region, cropped, prompt))

            if not tasks:
                # Every region was a figure or empty — fall back to full page.
                raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
                markdown_content = self._formatter._clean_content(raw_content)
            else:
                # Parallel OCR calls
                raw_results: dict[int, str] = {}
                with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex:
                    future_map = {
                        ex.submit(self._call_vllm, cropped, prompt): idx
                        for idx, region, cropped, prompt in tasks
                    }
                    for future in as_completed(future_map):
                        idx = future_map[future]
                        try:
                            raw_results[idx] = future.result()
                        except Exception:
                            # A failed region degrades to empty content rather
                            # than failing the whole page.
                            raw_results[idx] = ""

                # Build structured region dicts for GLMResultFormatter
                region_dicts = []
                for idx, region, _cropped, _prompt in tasks:
                    region_dicts.append(
                        {
                            "index": idx,
                            "label": region.type,
                            "native_label": region.native_label,
                            "content": raw_results.get(idx, ""),
                            # Formatter expects 0-1000 normalised coordinates.
                            "bbox_2d": self._normalize_bbox(region.bbox, img_w, img_h),
                        }
                    )

                # 4. GLM-OCR postprocessing: clean, format, merge, bullets
                markdown_content = self._formatter.process(region_dicts)

        # 5. LaTeX math error correction (our existing pipeline)
        markdown_content = _postprocess_markdown(markdown_content)

        # 6. Format conversion
        latex, mathml, mml = "", "", ""
        if markdown_content and self.converter:
            fmt = self.converter.convert_to_formats(markdown_content)
            latex, mathml, mml = fmt.latex, fmt.mathml, fmt.mml

        return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mineru_service = MineruOCRService()
|
||||
image = cv2.imread("test/formula2.jpg")
|
||||
|
||||
@@ -29,6 +29,7 @@ dependencies = [
|
||||
"safetensors",
|
||||
"lxml>=5.0.0",
|
||||
"openai",
|
||||
"wordfreq",
|
||||
]
|
||||
|
||||
# [tool.uv.sources]
|
||||
|
||||
98
tests/api/v1/endpoints/test_image_endpoint.py
Normal file
98
tests/api/v1/endpoints/test_image_endpoint.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints.image import router
|
||||
from app.core.dependencies import get_glmocr_endtoend_service, get_image_processor
|
||||
|
||||
|
||||
class _FakeImageProcessor:
|
||||
def preprocess(self, image_url=None, image_base64=None):
|
||||
return np.zeros((8, 8, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
class _FakeOCRService:
|
||||
def __init__(self, result=None, error=None):
|
||||
self._result = result or {"markdown": "md", "latex": "tex", "mathml": "mml", "mml": "xml"}
|
||||
self._error = error
|
||||
|
||||
def recognize(self, image):
|
||||
if self._error:
|
||||
raise self._error
|
||||
return self._result
|
||||
|
||||
|
||||
def _build_client(image_processor=None, ocr_service=None):
    """Build a TestClient around the image router with fake dependencies.

    Any dependency not supplied is replaced by its in-memory fake, so tests
    never touch real preprocessing or OCR backends.
    """
    app = FastAPI()
    app.include_router(router)
    app.dependency_overrides[get_image_processor] = lambda: image_processor or _FakeImageProcessor()
    app.dependency_overrides[get_glmocr_endtoend_service] = lambda: ocr_service or _FakeOCRService()
    return TestClient(app)
|
||||
|
||||
|
||||
def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
    """Both-missing and both-provided requests are rejected with 422."""
    client = _build_client()

    missing = client.post("/ocr", json={})
    both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"})

    assert missing.status_code == 422
    assert both.status_code == 422
|
||||
|
||||
|
||||
def test_image_endpoint_returns_503_for_runtime_error():
    """A RuntimeError from the OCR service maps to 503 with its message exposed."""
    client = _build_client(ocr_service=_FakeOCRService(error=RuntimeError("backend unavailable")))

    response = client.post("/ocr", json={"image_url": "https://example.com/a.png"})

    assert response.status_code == 503
    assert response.json()["detail"] == "backend unavailable"
|
||||
|
||||
|
||||
def test_image_endpoint_returns_500_for_unexpected_error():
    """Unexpected exceptions map to a generic 500 without leaking the message."""
    client = _build_client(ocr_service=_FakeOCRService(error=ValueError("boom")))

    response = client.post("/ocr", json={"image_url": "https://example.com/a.png"})

    assert response.status_code == 500
    assert response.json()["detail"] == "Internal server error"
|
||||
|
||||
|
||||
def test_image_endpoint_returns_ocr_payload():
    """A successful call returns the fake OCR payload plus endpoint defaults.

    NOTE(review): layout_info/recognition_mode are not produced by the fake
    service, so they are presumably filled in by the endpoint — confirm.
    """
    client = _build_client()

    response = client.post("/ocr", json={"image_base64": "ZmFrZQ=="})

    assert response.status_code == 200
    assert response.json() == {
        "latex": "tex",
        "markdown": "md",
        "mathml": "mml",
        "mml": "xml",
        "layout_info": {"regions": [], "MixedRecognition": False},
        "recognition_mode": "",
    }
|
||||
|
||||
|
||||
def test_image_endpoint_real_e2e_with_env_services():
    """Live end-to-end smoke test against the real app and a remote image.

    Skipped unless RUN_E2E_TESTS=1: it requires the real layout/OCR backends
    and fetches a pre-signed OSS URL whose signature expires over time, so it
    cannot run reliably in CI.
    """
    import os

    if os.environ.get("RUN_E2E_TESTS") != "1":
        pytest.skip("set RUN_E2E_TESTS=1 to run the live end-to-end test")

    from app.main import app

    image_url = (
        "https://static.texpixel.com/formula/012dab3e-fb31-4ecd-90fc-6957458ee309.png"
        "?Expires=1773049821&OSSAccessKeyId=TMP.3KnrJUz7aXHoU9rLTAih4MAyPGd9zyGRHiqg9AyH6TY6NKtzqT2yr4qo7Vwf8fMRFCBrWXiCFrbBwC3vn7U6mspV2NeU1K"
        "&Signature=oynhP0OLIgFI0Sv3z2CWeHPT2Ck%3D"
    )

    with TestClient(app) as client:
        response = client.post(
            "/doc_process/v1/image/ocr",
            json={"image_url": image_url},
            headers={"x-request-id": "test-e2e"},
        )

    assert response.status_code == 200, response.text
    payload = response.json()
    assert isinstance(payload["markdown"], str)
    assert payload["markdown"].strip()
    assert set(payload) >= {"markdown", "latex", "mathml", "mml"}
|
||||
10
tests/core/test_dependencies.py
Normal file
10
tests/core/test_dependencies.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from app.core import dependencies
|
||||
|
||||
|
||||
def test_get_glmocr_endtoend_service_raises_when_layout_detector_missing(monkeypatch):
    """The dependency factory fails fast when startup never set _layout_detector."""
    monkeypatch.setattr(dependencies, "_layout_detector", None)

    with pytest.raises(RuntimeError, match="Layout detector not initialized"):
        dependencies.get_glmocr_endtoend_service()
|
||||
31
tests/schemas/test_image.py
Normal file
31
tests/schemas/test_image.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from app.schemas.image import ImageOCRRequest, LayoutRegion
|
||||
|
||||
|
||||
def test_layout_region_native_label_defaults_to_empty_string():
    """native_label is optional and defaults to the empty string when omitted."""
    region = LayoutRegion(
        type="text",
        bbox=[0, 0, 10, 10],
        confidence=0.9,
        score=0.9,
    )

    assert region.native_label == ""
|
||||
|
||||
|
||||
def test_layout_region_exposes_native_label_when_provided():
    """An explicit native_label is stored verbatim on the region."""
    region = LayoutRegion(
        type="text",
        native_label="doc_title",
        bbox=[0, 0, 10, 10],
        confidence=0.9,
        score=0.9,
    )

    assert region.native_label == "doc_title"
|
||||
|
||||
|
||||
def test_image_ocr_request_requires_exactly_one_input():
    """Constructing with only image_url succeeds and leaves image_base64 unset.

    NOTE(review): despite the name, this only covers the valid single-input
    case; the both/none rejection paths are exercised via the endpoint tests.
    """
    request = ImageOCRRequest(image_url="https://example.com/test.png")

    assert request.image_url == "https://example.com/test.png"
    assert request.image_base64 is None
|
||||
199
tests/services/test_glm_postprocess.py
Normal file
199
tests/services/test_glm_postprocess.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from app.services.glm_postprocess import (
|
||||
GLMResultFormatter,
|
||||
clean_formula_number,
|
||||
clean_repeated_content,
|
||||
find_consecutive_repeat,
|
||||
)
|
||||
|
||||
|
||||
def test_find_consecutive_repeat_truncates_when_threshold_met():
    """A 10-char unit repeated 10 times is detected and collapsed to one unit."""
    repeated = "abcdefghij" * 10 + "tail"

    assert find_consecutive_repeat(repeated) == "abcdefghij"
|
||||
|
||||
|
||||
def test_find_consecutive_repeat_returns_none_when_below_threshold():
    """Nine repetitions sit below the detection threshold, so nothing is flagged."""
    assert find_consecutive_repeat("abcdefghij" * 9) is None
|
||||
|
||||
|
||||
def test_clean_repeated_content_handles_consecutive_and_line_level_repeats():
    """Covers character-run collapsing, line-level dedup, and the no-op path."""
    assert clean_repeated_content("abcdefghij" * 10 + "tail") == "abcdefghij"

    line_repeated = "\n".join(["same line"] * 10 + ["other"])
    assert clean_repeated_content(line_repeated, line_threshold=10) == "same line\n"

    assert clean_repeated_content("normal text") == "normal text"
|
||||
|
||||
|
||||
def test_clean_formula_number_strips_wrapping_parentheses():
    """Wrapping parentheses are removed; bare numbers pass through unchanged."""
    assert clean_formula_number("(1)") == "1"
    assert clean_formula_number("(2.1)") == "2.1"
    assert clean_formula_number("3") == "3"
|
||||
|
||||
|
||||
def test_clean_content_removes_literal_tabs_and_long_repeat_noise():
    """_clean_content strips literal backslash-t sequences and truncates repeats.

    The raw strings contain a two-character backslash + 't' sequence, not a
    real tab character.
    """
    formatter = GLMResultFormatter()
    noisy = r"\t\t" + ("·" * 5) + ("abcdefghij" * 205) + r"\t"

    cleaned = formatter._clean_content(noisy)

    assert cleaned.startswith("···")
    assert cleaned.endswith("abcdefghij")
    assert r"\t" not in cleaned
|
||||
|
||||
|
||||
def test_format_content_handles_titles_formula_text_and_newlines():
    """Covers doc/paragraph titles, display-formula wrapping, and bullet text."""
    formatter = GLMResultFormatter()

    assert formatter._format_content("Intro", "text", "doc_title") == "# Intro"
    assert formatter._format_content("- Section", "text", "paragraph_title") == "## Section"
    assert formatter._format_content(r"\[x+y\]", "formula", "display_formula") == "$$\nx+y\n$$"
    assert formatter._format_content("· item\nnext", "text", "text") == "- item\n\nnext"
|
||||
|
||||
|
||||
def test_merge_formula_numbers_merges_before_and_after_formula():
    """A formula_number region merges into the adjacent formula as a \\tag{n}."""
    formatter = GLMResultFormatter()

    before = formatter._merge_formula_numbers(
        [
            {"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
            {"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
        ]
    )
    after = formatter._merge_formula_numbers(
        [
            {"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
            {"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
        ]
    )
    untouched = formatter._merge_formula_numbers(
        [{"index": 0, "label": "text", "native_label": "formula_number", "content": "(3)"}]
    )

    assert before == [
        {
            "index": 0,
            "label": "formula",
            "native_label": "display_formula",
            "content": "$$\nx+y \\tag{1}\n$$",
        }
    ]
    assert after == [
        {
            "index": 0,
            "label": "formula",
            "native_label": "display_formula",
            "content": "$$\nx+y \\tag{2}\n$$",
        }
    ]
    # A dangling formula_number with no neighbouring formula is dropped entirely.
    assert untouched == []
|
||||
|
||||
|
||||
def test_merge_text_blocks_joins_hyphenated_words_when_wordfreq_accepts(monkeypatch):
    """A trailing hyphen joins across blocks when wordfreq says the merge is a word."""
    formatter = GLMResultFormatter()

    # Force the wordfreq path and make every candidate look like a common word.
    monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
    monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 3.0)

    merged = formatter._merge_text_blocks(
        [
            {"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
            {"index": 1, "label": "text", "native_label": "text", "content": "national"},
        ]
    )

    assert merged == [
        {"index": 0, "label": "text", "native_label": "text", "content": "international"}
    ]
|
||||
|
||||
|
||||
def test_merge_text_blocks_skips_invalid_merge(monkeypatch):
    """Blocks stay separate when the joined word is rare (low zipf frequency)."""
    formatter = GLMResultFormatter()

    monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True)
    monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 1.0)

    merged = formatter._merge_text_blocks(
        [
            {"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
            {"index": 1, "label": "text", "native_label": "text", "content": "National"},
        ]
    )

    assert merged == [
        {"index": 0, "label": "text", "native_label": "text", "content": "inter-"},
        {"index": 1, "label": "text", "native_label": "text", "content": "National"},
    ]
|
||||
|
||||
|
||||
def test_format_bullet_points_infers_missing_middle_bullet():
    """An unbulleted item aligned between two bullets gets a bullet inferred."""
    formatter = GLMResultFormatter()
    items = [
        {"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
        {"native_label": "text", "content": "second", "bbox_2d": [12, 12, 52, 22]},
        {"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
    ]

    formatted = formatter._format_bullet_points(items)

    assert formatted[1]["content"] == "- second"
|
||||
|
||||
|
||||
def test_format_bullet_points_skips_when_bbox_missing():
    """No bullet is inferred when the middle item carries no bbox geometry."""
    formatter = GLMResultFormatter()
    regions = [
        {"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]},
        {"native_label": "text", "content": "second", "bbox_2d": []},
        {"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]},
    ]

    result = formatter._format_bullet_points(regions)

    # Without a bbox there is no alignment evidence, so the content is untouched.
    assert result[1]["content"] == "second"
|
||||
|
||||
|
||||
def test_process_runs_full_pipeline_and_skips_empty_content():
    """process() should render the title, attach the formula number as a \\tag,
    and drop the region whose content is empty."""
    formatter = GLMResultFormatter()
    regions = [
        {
            "index": 0,
            "label": "text",
            "native_label": "doc_title",
            "content": "Doc Title",
            "bbox_2d": [0, 0, 100, 30],
        },
        {
            "index": 1,
            "label": "text",
            "native_label": "formula_number",
            "content": "(1)",
            "bbox_2d": [80, 50, 100, 60],
        },
        {
            "index": 2,
            "label": "formula",
            "native_label": "display_formula",
            "content": "x+y",
            "bbox_2d": [0, 40, 100, 80],
        },
        {
            "index": 3,
            "label": "figure",
            "native_label": "image",
            "content": "figure placeholder",
            "bbox_2d": [0, 80, 100, 120],
        },
        {
            "index": 4,
            "label": "text",
            "native_label": "text",
            "content": "",
            "bbox_2d": [0, 120, 100, 150],
        },
    ]

    output = formatter.process(regions)

    assert "# Doc Title" in output
    assert "$$\nx+y \\tag{1}\n$$" in output
    # Fixed: the original `assert "" in output` was vacuous — the empty string is a
    # substring of every string, so it could never fail. Instead verify the empty
    # region really contributed nothing: the output is non-empty and contains no
    # doubled blank-line separators where the skipped region would have gone.
    assert output
    assert "\n\n\n" not in output
|
||||
46
tests/services/test_layout_detector.py
Normal file
46
tests/services/test_layout_detector.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
|
||||
|
||||
class _FakePredictor:
|
||||
def __init__(self, boxes):
|
||||
self._boxes = boxes
|
||||
|
||||
def predict(self, image):
|
||||
return [{"boxes": self._boxes}]
|
||||
|
||||
|
||||
def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
    """detect() must invoke the shared postprocess with the expected arguments and
    carry each kept box's native label onto the resulting regions."""
    raw_boxes = [
        {"cls_id": 22, "label": "text", "score": 0.95, "coordinate": [0, 0, 100, 100]},
        {"cls_id": 22, "label": "text", "score": 0.90, "coordinate": [10, 10, 20, 20]},
        {"cls_id": 6, "label": "doc_title", "score": 0.99, "coordinate": [0, 0, 80, 20]},
    ]

    # Bypass __init__ (which would load a real model) and inject the fake predictor.
    detector = LayoutDetector.__new__(LayoutDetector)
    detector._get_layout_detector = lambda: _FakePredictor(raw_boxes)

    seen = {}

    def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode):
        seen["args"] = {
            "boxes": boxes,
            "img_size": img_size,
            "layout_nms": layout_nms,
            "layout_unclip_ratio": layout_unclip_ratio,
            "layout_merge_bboxes_mode": layout_merge_bboxes_mode,
        }
        # Pretend the postprocess dropped the small contained text box.
        return [boxes[0], boxes[2]]

    monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess)

    info = detector.detect(np.zeros((200, 100, 3), dtype=np.uint8))

    assert seen["args"]["img_size"] == (100, 200)
    assert seen["args"]["layout_nms"] is True
    assert seen["args"]["layout_merge_bboxes_mode"] == "large"
    assert [region.native_label for region in info.regions] == ["text", "doc_title"]
    assert [region.type for region in info.regions] == ["text", "text"]
    assert info.MixedRecognition is True
|
||||
151
tests/services/test_layout_postprocess.py
Normal file
151
tests/services/test_layout_postprocess.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.services.layout_postprocess import (
|
||||
apply_layout_postprocess,
|
||||
check_containment,
|
||||
iou,
|
||||
is_contained,
|
||||
nms,
|
||||
unclip_boxes,
|
||||
)
|
||||
|
||||
|
||||
def _raw_box(cls_id, score, x1, y1, x2, y2, label="text"):
|
||||
return {
|
||||
"cls_id": cls_id,
|
||||
"label": label,
|
||||
"score": score,
|
||||
"coordinate": [x1, y1, x2, y2],
|
||||
}
|
||||
|
||||
|
||||
def test_iou_handles_full_none_and_partial_overlap():
    """iou() spans the full range: identical, disjoint, and partially overlapping."""
    base = [0, 0, 9, 9]

    assert iou(base, [0, 0, 9, 9]) == 1.0
    assert iou(base, [20, 20, 29, 29]) == 0.0
    # Inclusive 10x10 boxes sharing a 5x5 corner: 25 / 175 == 1/7.
    assert math.isclose(iou(base, [5, 5, 14, 14]), 1 / 7, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_nms_keeps_highest_score_for_same_class_overlap():
    """Two heavily overlapping boxes of one class: only the higher score survives."""
    same_class_pair = np.array(
        [
            [0, 0.95, 0, 0, 10, 10],
            [0, 0.80, 1, 1, 11, 11],
        ],
        dtype=float,
    )

    assert nms(same_class_pair, iou_same=0.6, iou_diff=0.98) == [0]
|
||||
|
||||
|
||||
def test_nms_keeps_cross_class_overlap_boxes_below_diff_threshold():
    """Overlap across different classes is tolerated up to the stricter iou_diff bound."""
    cross_class_pair = np.array(
        [
            [0, 0.95, 0, 0, 10, 10],
            [1, 0.90, 1, 1, 11, 11],
        ],
        dtype=float,
    )

    assert nms(cross_class_pair, iou_same=0.6, iou_diff=0.98) == [0, 1]
|
||||
|
||||
|
||||
def test_nms_returns_single_box_index():
    """A lone box is trivially kept with default thresholds."""
    lone = np.array([[0, 0.95, 0, 0, 10, 10]], dtype=float)

    assert nms(lone) == [0]
|
||||
|
||||
|
||||
def test_is_contained_uses_overlap_threshold():
    """Containment is strict by default and can be relaxed via overlap_threshold."""
    outer = [0, 0.9, 0, 0, 10, 10]
    fully_inside = [0, 0.9, 2, 2, 8, 8]
    corner_overlap = [0, 0.9, 6, 6, 12, 12]

    # Fully nested box is contained.
    assert is_contained(fully_inside, outer) is True
    # A corner overlap fails the default threshold...
    assert is_contained(corner_overlap, outer) is False
    # ...but passes once the threshold drops to 0.3.
    assert is_contained(corner_overlap, outer, overlap_threshold=0.3) is True
|
||||
|
||||
|
||||
def test_check_containment_respects_preserve_class_ids():
    """Boxes of a preserved class id (here {1}) are never marked contained-by-other."""
    stacked = np.array(
        [
            [0, 0.9, 0, 0, 100, 100],  # outermost box
            [1, 0.8, 10, 10, 30, 30],  # preserved class, nested inside box 0
            [2, 0.7, 15, 15, 25, 25],  # nested inside boxes 0 and 1
        ],
        dtype=float,
    )

    contains_other, contained_by_other = check_containment(stacked, preserve_cls_ids={1})

    assert contains_other.tolist() == [1, 1, 0]
    assert contained_by_other.tolist() == [0, 0, 1]
|
||||
|
||||
|
||||
def test_unclip_boxes_supports_scalar_tuple_dict_and_none():
    """unclip_boxes accepts a scalar ratio, an (x, y) tuple, a per-class dict, or None."""
    boxes = np.array(
        [
            [0, 0.9, 10, 10, 20, 20],
            [1, 0.8, 30, 30, 50, 40],
        ],
        dtype=float,
    )

    # Scalar: both axes expand by the same factor around the box centre.
    assert unclip_boxes(boxes, 2.0)[:, 2:6].tolist() == [
        [5.0, 5.0, 25.0, 25.0],
        [20.0, 25.0, 60.0, 45.0],
    ]

    # Tuple: independent x / y expansion ratios.
    assert unclip_boxes(boxes, (2.0, 3.0))[:, 2:6].tolist() == [
        [5.0, 0.0, 25.0, 30.0],
        [20.0, 20.0, 60.0, 50.0],
    ]

    # Dict: only the listed class id (1) is unclipped; class 0 stays untouched.
    assert unclip_boxes(boxes, {1: (1.5, 2.0)})[:, 2:6].tolist() == [
        [10.0, 10.0, 20.0, 20.0],
        [25.0, 25.0, 55.0, 45.0],
    ]

    # None: input returned unchanged.
    assert np.array_equal(unclip_boxes(boxes, None), boxes)
|
||||
|
||||
|
||||
def test_apply_layout_postprocess_large_mode_removes_contained_small_box():
    """In "large" merge mode a small box swallowed by a bigger same-class box is dropped."""
    overlapping = [
        _raw_box(0, 0.95, 0, 0, 100, 100, "text"),
        _raw_box(0, 0.90, 10, 10, 20, 20, "text"),
    ]

    kept = apply_layout_postprocess(overlapping, img_size=(120, 120), layout_merge_bboxes_mode="large")

    assert [box["coordinate"] for box in kept] == [[0, 0, 100, 100]]
|
||||
|
||||
|
||||
def test_apply_layout_postprocess_preserves_contained_image_like_boxes():
    """Image-like labels (image / seal / chart) survive even when nested in a text box."""
    nested = [
        _raw_box(0, 0.95, 0, 0, 100, 100, "text"),
        _raw_box(1, 0.90, 10, 10, 20, 20, "image"),
        _raw_box(2, 0.90, 25, 25, 35, 35, "seal"),
        _raw_box(3, 0.90, 40, 40, 50, 50, "chart"),
    ]

    kept = apply_layout_postprocess(nested, img_size=(120, 120), layout_merge_bboxes_mode="large")

    assert {box["label"] for box in kept} == {"text", "image", "seal", "chart"}
|
||||
|
||||
|
||||
def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image():
    """Coordinates are clamped to the canvas, degenerate boxes dropped, and a
    near-full-page image box filtered out."""
    mixed = [
        _raw_box(0, 0.95, -10, -5, 40, 50, "text"),  # spills off-canvas -> clamped
        _raw_box(1, 0.90, 10, 10, 10, 50, "text"),   # zero width -> discarded
        _raw_box(2, 0.85, 0, 0, 100, 90, "image"),   # covers the whole page -> filtered
    ]

    kept = apply_layout_postprocess(
        mixed,
        img_size=(100, 90),
        layout_nms=False,
        layout_merge_bboxes_mode=None,
    )

    assert kept == [{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}]
|
||||
124
tests/services/test_ocr_service.py
Normal file
124
tests/services/test_ocr_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
|
||||
|
||||
class _FakeConverter:
|
||||
def convert_to_formats(self, markdown):
|
||||
return SimpleNamespace(
|
||||
latex=f"LATEX::{markdown}",
|
||||
mathml=f"MATHML::{markdown}",
|
||||
mml=f"MML::{markdown}",
|
||||
)
|
||||
|
||||
|
||||
class _FakeImageProcessor:
|
||||
def add_padding(self, image):
|
||||
return image
|
||||
|
||||
|
||||
class _FakeLayoutDetector:
    """Layout-detector double that returns a fixed region list for every image."""

    def __init__(self, regions):
        self._fixed_regions = regions

    def detect(self, image):
        """Wrap the canned regions; MixedRecognition is true iff any region exists,
        mirroring how the real detector flags mixed-content pages."""
        return LayoutInfo(regions=self._fixed_regions, MixedRecognition=bool(self._fixed_regions))
|
||||
|
||||
|
||||
def _build_service(regions=None):
    """Assemble a GLMOCREndToEndService wired with lightweight fakes.

    regions: optional list handed to the fake layout detector; defaults to none.
    """
    detector = _FakeLayoutDetector(regions or [])
    return GLMOCREndToEndService(
        vl_server_url="http://127.0.0.1:8002/v1",
        image_processor=_FakeImageProcessor(),
        converter=_FakeConverter(),
        layout_detector=detector,
        max_workers=2,
    )
|
||||
|
||||
|
||||
def test_encode_region_returns_decodable_base64_jpeg():
    """_encode_region must emit base64 that decodes back into a same-sized JPEG."""
    service = _build_service()
    region = np.zeros((8, 12, 3), dtype=np.uint8)
    region[:, :] = [0, 128, 255]

    payload = base64.b64decode(service._encode_region(region))
    round_tripped = cv2.imdecode(np.frombuffer(payload, dtype=np.uint8), cv2.IMREAD_COLOR)

    assert round_tripped.shape[:2] == region.shape[:2]
|
||||
|
||||
|
||||
def test_call_vllm_builds_messages_and_returns_content():
    """_call_vllm sends an image+prompt message pair and returns the stripped reply."""
    service = _build_service()
    seen_kwargs = {}

    def fake_create(**kwargs):
        seen_kwargs.update(kwargs)
        message = SimpleNamespace(content=" recognized content \n")
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

    # Swap in a stub OpenAI client exposing only chat.completions.create.
    service.openai_client = SimpleNamespace(
        chat=SimpleNamespace(completions=SimpleNamespace(create=fake_create))
    )

    reply = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")

    assert reply == "recognized content"
    assert seen_kwargs["model"] == "glm-ocr"
    assert seen_kwargs["max_tokens"] == 1024
    content = seen_kwargs["messages"][0]["content"]
    assert content[0]["type"] == "image_url"
    assert content[0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
    assert content[1] == {"type": "text", "text": "Formula Recognition:"}
|
||||
|
||||
|
||||
def test_normalize_bbox_scales_coordinates_to_1000():
    """Pixel-space boxes are rescaled into the model's 0-1000 coordinate space."""
    service = _build_service()

    full_frame = service._normalize_bbox([0, 0, 200, 100], 200, 100)
    centered = service._normalize_bbox([50, 25, 150, 75], 200, 100)

    assert full_frame == [0, 0, 1000, 1000]
    assert centered == [250, 250, 750, 750]
|
||||
|
||||
|
||||
def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
    """With no layout regions the whole image is OCR'd once and every format derived."""
    service = _build_service(regions=[])
    monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")

    result = service.recognize(np.zeros((20, 30, 3), dtype=np.uint8))

    assert result["markdown"] == "raw text"
    assert result["latex"] == "LATEX::raw text"
    assert result["mathml"] == "MATHML::raw text"
    assert result["mml"] == "MML::raw text"
|
||||
|
||||
|
||||
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
    """Figure regions are skipped; text/formula regions are recognized in order and
    assembled into postprocessed markdown plus the derived formats."""
    regions = [
        LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
        LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
        LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
    ]
    service = _build_service(regions=regions)

    prompts_seen = []

    def fake_call_vllm(cropped, prompt):
        prompts_seen.append(prompt)
        return "Title" if prompt == "Text Recognition:" else "x + y"

    monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)

    result = service.recognize(np.zeros((40, 40, 3), dtype=np.uint8))

    expected_markdown = "# Title\n\n$$\nx + y\n$$"
    # The figure region must not trigger a VLM call at all.
    assert prompts_seen == ["Text Recognition:", "Formula Recognition:"]
    assert result["markdown"] == expected_markdown
    assert result["latex"] == f"LATEX::{expected_markdown}"
    assert result["mathml"] == f"MATHML::{expected_markdown}"
    assert result["mml"] == f"MML::{expected_markdown}"
|
||||
Reference in New Issue
Block a user