From 6dfaf9668bb8d6be11853cc599a36f3241d2e196 Mon Sep 17 00:00:00 2001 From: liuyuanchuang Date: Mon, 9 Mar 2026 16:51:06 +0800 Subject: [PATCH] feat add glm-ocr core --- .claude/settings.local.json | 14 + app/api/v1/endpoints/image.py | 93 +--- app/core/config.py | 13 +- app/core/dependencies.py | 35 +- app/schemas/image.py | 1 + app/services/glm_postprocess.py | 412 ++++++++++++++++++ app/services/layout_detector.py | 15 +- app/services/layout_postprocess.py | 343 +++++++++++++++ app/services/ocr_service.py | 241 +++++++++- pyproject.toml | 1 + tests/api/v1/endpoints/test_image_endpoint.py | 98 +++++ tests/core/test_dependencies.py | 10 + tests/schemas/test_image.py | 31 ++ tests/services/test_glm_postprocess.py | 199 +++++++++ tests/services/test_layout_detector.py | 46 ++ tests/services/test_layout_postprocess.py | 151 +++++++ tests/services/test_ocr_service.py | 124 ++++++ 17 files changed, 1687 insertions(+), 140 deletions(-) create mode 100644 .claude/settings.local.json create mode 100644 app/services/glm_postprocess.py create mode 100644 app/services/layout_postprocess.py create mode 100644 tests/api/v1/endpoints/test_image_endpoint.py create mode 100644 tests/core/test_dependencies.py create mode 100644 tests/schemas/test_image.py create mode 100644 tests/services/test_glm_postprocess.py create mode 100644 tests/services/test_layout_detector.py create mode 100644 tests/services/test_layout_postprocess.py create mode 100644 tests/services/test_ocr_service.py diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..9d3ce31 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,14 @@ +{ + "permissions": { + "allow": [ + "WebFetch(domain:deepwiki.com)", + "WebFetch(domain:github.com)", + "Read(//private/tmp/**)", + "Bash(gh api repos/zai-org/GLM-OCR/contents/glmocr --jq '.[].name')", + "WebFetch(domain:raw.githubusercontent.com)", + "Bash(python -c \"\nfrom app.services.glm_postprocess import 
GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")", + "Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)", + "Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . 
-name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)" + ] + } +} diff --git a/app/api/v1/endpoints/image.py b/app/api/v1/endpoints/image.py index b992009..1c55fd6 100644 --- a/app/api/v1/endpoints/image.py +++ b/app/api/v1/endpoints/image.py @@ -2,26 +2,17 @@ import time import uuid -import cv2 -from io import BytesIO from fastapi import APIRouter, Depends, HTTPException, Request, Response from app.core.dependencies import ( get_image_processor, - get_layout_detector, - get_ocr_service, - get_mineru_ocr_service, - get_glmocr_service, + get_glmocr_endtoend_service, ) -from app.core.config import get_settings from app.core.logging_config import get_logger, RequestIDAdapter from app.schemas.image import ImageOCRRequest, ImageOCRResponse from app.services.image_processor import ImageProcessor -from app.services.layout_detector import LayoutDetector -from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService - -settings = get_settings() +from app.services.ocr_service import GLMOCREndToEndService router = APIRouter() logger = get_logger() @@ -33,100 +24,38 @@ async def process_image_ocr( http_request: Request, response: Response, image_processor: ImageProcessor = Depends(get_image_processor), - layout_detector: LayoutDetector = Depends(get_layout_detector), - mineru_service: MineruOCRService = Depends(get_mineru_ocr_service), - paddle_service: OCRService = Depends(get_ocr_service), - glmocr_service: GLMOCRService = Depends(get_glmocr_service), + glmocr_service: GLMOCREndToEndService = Depends(get_glmocr_endtoend_service), ) -> ImageOCRResponse: """Process an image and extract content as LaTeX, Markdown, and MathML. The processing pipeline: - 1. Load and preprocess image (add 30% whitespace padding) - 2. Detect layout using DocLayout-YOLO - 3. Based on layout: - - If plain text exists: use PP-DocLayoutV2 for mixed recognition - - Otherwise: use PaddleOCR-VL with formula prompt - 4. 
Convert output to LaTeX, Markdown, and MathML formats + 1. Load and preprocess image + 2. Detect layout regions using PP-DocLayoutV3 + 3. Crop each region and recognize with GLM-OCR via vLLM (task-specific prompts) + 4. Aggregate region results into Markdown + 5. Convert to LaTeX, Markdown, and MathML formats Note: OMML conversion is not included due to performance overhead. Use the /convert/latex-to-omml endpoint to convert LaTeX to OMML separately. """ - # Get or generate request ID request_id = http_request.headers.get("x-request-id", str(uuid.uuid4())) response.headers["x-request-id"] = request_id - # Create logger adapter with request_id log = RequestIDAdapter(logger, {"request_id": request_id}) log.request_id = request_id try: log.info("Starting image OCR processing") + start = time.time() - # Preprocess image (load only, no padding yet) - preprocess_start = time.time() image = image_processor.preprocess( image_url=request.image_url, image_base64=request.image_base64, ) - # Apply padding only for layout detection - processed_image = image - if image_processor and settings.is_padding: - processed_image = image_processor.add_padding(image) + ocr_result = glmocr_service.recognize(image) - preprocess_time = time.time() - preprocess_start - log.debug(f"Image loading completed in {preprocess_time:.3f}s") - - # Layout detection (using padded image if padding is enabled) - layout_start = time.time() - layout_info = layout_detector.detect(processed_image) - layout_time = time.time() - layout_start - log.info(f"Layout detection completed in {layout_time:.3f}s") - - # OCR recognition (use original image without padding) - ocr_start = time.time() - if layout_info.MixedRecognition: - recognition_method = "MixedRecognition (MinerU)" - log.info(f"Using {recognition_method}") - - # Convert original image (without padding) to bytes - success, encoded_image = cv2.imencode(".png", image) - if not success: - raise RuntimeError("Failed to encode image") - - image_bytes = 
BytesIO(encoded_image.tobytes()) - image_bytes.seek(0) # Ensure position is at the beginning - ocr_result = mineru_service.recognize(image_bytes) - else: - recognition_method = "FormulaOnly (GLMOCR)" - log.info(f"Using {recognition_method}") - - # Try GLM-OCR first, fallback to MinerU if token limit exceeded - try: - ocr_result = glmocr_service.recognize(image) - except Exception as e: - error_msg = str(e) - # Check if error is due to token limit (max_model_len exceeded) - if "max_model_len" in error_msg or "decoder prompt" in error_msg or "BadRequestError" in error_msg: - log.warning(f"GLM-OCR failed due to token limit: {error_msg}") - log.info("Falling back to MinerU for recognition") - recognition_method = "FormulaOnly (MinerU fallback)" - - # Convert original image to bytes for MinerU - success, encoded_image = cv2.imencode(".png", image) - if not success: - raise RuntimeError("Failed to encode image") - - image_bytes = BytesIO(encoded_image.tobytes()) - image_bytes.seek(0) - ocr_result = mineru_service.recognize(image_bytes) - else: - # Re-raise other errors - raise - ocr_time = time.time() - ocr_start - - total_time = time.time() - preprocess_start - log.info(f"OCR processing completed - Method: {recognition_method}, " f"Layout time: {layout_time:.3f}s, OCR time: {ocr_time:.3f}s, " f"Total time: {total_time:.3f}s") + log.info(f"OCR completed in {time.time() - start:.3f}s") except RuntimeError as e: log.error(f"OCR processing failed: {str(e)}", exc_info=True) diff --git a/app/core/config.py b/app/core/config.py index a5391a7..e014cef 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -3,9 +3,8 @@ from functools import lru_cache from pathlib import Path -from pydantic_settings import BaseSettings, SettingsConfigDict import torch -from typing import Optional +from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): @@ -48,21 +47,25 @@ class Settings(BaseSettings): is_padding: bool = True padding_ratio: float = 0.1 + 
max_tokens: int = 4096 + # Model Paths - pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3" + pp_doclayout_model_dir: str | None = ( + "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3" + ) # Image Processing max_image_size_mb: int = 10 image_padding_ratio: float = 0.1 # 10% on each side = 20% total expansion - device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # cuda:0 or cpu + device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Server Settings host: str = "0.0.0.0" port: int = 8053 # Logging Settings - log_dir: Optional[str] = None # Defaults to /app/logs in container or ./logs locally + log_dir: str | None = None # Defaults to /app/logs in container or ./logs locally log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL @property diff --git a/app/core/dependencies.py b/app/core/dependencies.py index 494538b..50ddbdb 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -2,7 +2,7 @@ from app.services.image_processor import ImageProcessor from app.services.layout_detector import LayoutDetector -from app.services.ocr_service import OCRService, MineruOCRService, GLMOCRService +from app.services.ocr_service import GLMOCREndToEndService from app.services.converter import Converter from app.core.config import get_settings @@ -31,40 +31,17 @@ def get_image_processor() -> ImageProcessor: return ImageProcessor() -def get_ocr_service() -> OCRService: - """Get an OCR service instance.""" - return OCRService( - vl_server_url=get_settings().paddleocr_vl_url, - layout_detector=get_layout_detector(), - image_processor=get_image_processor(), - converter=get_converter(), - ) - - def get_converter() -> Converter: """Get a DOCX converter instance.""" return Converter() -def get_mineru_ocr_service() -> MineruOCRService: - """Get a MinerOCR service instance.""" +def get_glmocr_endtoend_service() -> 
GLMOCREndToEndService: + """Get end-to-end GLM-OCR service (layout detection + per-region OCR).""" settings = get_settings() - api_url = getattr(settings, "miner_ocr_api_url", "http://127.0.0.1:8000/file_parse") - glm_ocr_url = getattr(settings, "glm_ocr_url", "http://localhost:8002/v1") - return MineruOCRService( - api_url=api_url, - converter=get_converter(), - image_processor=get_image_processor(), - glm_ocr_url=glm_ocr_url, - ) - - -def get_glmocr_service() -> GLMOCRService: - """Get a GLM OCR service instance.""" - settings = get_settings() - glm_ocr_url = getattr(settings, "glm_ocr_url", "http://127.0.0.1:8002/v1") - return GLMOCRService( - vl_server_url=glm_ocr_url, + return GLMOCREndToEndService( + vl_server_url=settings.glm_ocr_url, image_processor=get_image_processor(), converter=get_converter(), + layout_detector=get_layout_detector(), ) diff --git a/app/schemas/image.py b/app/schemas/image.py index 3b46a18..f0d8f37 100644 --- a/app/schemas/image.py +++ b/app/schemas/image.py @@ -7,6 +7,7 @@ class LayoutRegion(BaseModel): """A detected layout region in the document.""" type: str = Field(..., description="Region type: text, formula, table, figure") + native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)") bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]") confidence: float = Field(..., description="Detection confidence score") score: float = Field(..., description="Detection score") diff --git a/app/services/glm_postprocess.py b/app/services/glm_postprocess.py new file mode 100644 index 0000000..d84e589 --- /dev/null +++ b/app/services/glm_postprocess.py @@ -0,0 +1,412 @@ +"""GLM-OCR postprocessing logic adapted for this project. + +Ported from glm-ocr/glmocr/postprocess/result_formatter.py and +glm-ocr/glmocr/utils/result_postprocess_utils.py. 
+ +Covers: +- Repeated-content / hallucination detection +- Per-region content cleaning and formatting (titles, bullets, formulas) +- formula_number merging (→ \\tag{}) +- Hyphenated text-block merging (via wordfreq) +- Missing bullet-point detection +""" + +from __future__ import annotations + +import re +import json +from collections import Counter +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +try: + from wordfreq import zipf_frequency + + _WORDFREQ_AVAILABLE = True +except ImportError: + _WORDFREQ_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# result_postprocess_utils (ported) +# --------------------------------------------------------------------------- + + +def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]: + """Detect and truncate a consecutively-repeated pattern. + + Returns the string with the repeat removed, or None if not found. + """ + n = len(s) + if n < min_unit_len * min_repeats: + return None + + max_unit_len = n // min_repeats + if max_unit_len < min_unit_len: + return None + + pattern = re.compile( + r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}", + re.DOTALL, + ) + match = pattern.search(s) + if match: + return s[: match.start()] + match.group(1) + return None + + +def clean_repeated_content( + content: str, + min_len: int = 10, + min_repeats: int = 10, + line_threshold: int = 10, +) -> str: + """Remove hallucination-style repeated content (consecutive or line-level).""" + stripped = content.strip() + if not stripped: + return content + + # 1. Consecutive repeat (multi-line aware) + if len(stripped) > min_len * min_repeats: + result = find_consecutive_repeat(stripped, min_unit_len=min_len, min_repeats=min_repeats) + if result is not None: + return result + + # 2. 
Line-level repeat + lines = [line.strip() for line in content.split("\n") if line.strip()] + total_lines = len(lines) + if total_lines >= line_threshold and lines: + common, count = Counter(lines).most_common(1)[0] + if count >= line_threshold and (count / total_lines) >= 0.8: + for i, line in enumerate(lines): + if line == common: + consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common) + if consecutive >= 3: + original_lines = content.split("\n") + non_empty_count = 0 + for idx, orig_line in enumerate(original_lines): + if orig_line.strip(): + non_empty_count += 1 + if non_empty_count == i + 1: + return "\n".join(original_lines[: idx + 1]) + break + return content + + +def clean_formula_number(number_content: str) -> str: + """Strip parentheses from a formula number string, e.g. '(1)' → '1'.""" + s = number_content.strip() + if s.startswith("(") and s.endswith(")"): + return s[1:-1] + if s.startswith("(") and s.endswith(")"): + return s[1:-1] + return s + + +# --------------------------------------------------------------------------- +# GLMResultFormatter +# --------------------------------------------------------------------------- + +# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping) +_LABEL_TO_CATEGORY: Dict[str, str] = { + # text + "abstract": "text", + "algorithm": "text", + "content": "text", + "doc_title": "text", + "figure_title": "text", + "paragraph_title": "text", + "reference_content": "text", + "text": "text", + "vertical_text": "text", + "vision_footnote": "text", + "seal": "text", + "formula_number": "text", + # table + "table": "table", + # formula + "display_formula": "formula", + "inline_formula": "formula", + # image (skip OCR) + "chart": "image", + "image": "image", +} + + +class GLMResultFormatter: + """Port of GLM-OCR's ResultFormatter for use in our pipeline. 
+ + Accepts a list of region dicts (each with label, native_label, content, + bbox_2d) and returns a final Markdown string. + """ + + # ------------------------------------------------------------------ # + # Public entry-point + # ------------------------------------------------------------------ # + + def process(self, regions: List[Dict[str, Any]]) -> str: + """Run the full postprocessing pipeline and return Markdown. + + Args: + regions: List of dicts with keys: + - index (int) reading order from layout detection + - label (str) mapped category: text/formula/table/figure + - native_label (str) raw PP-DocLayout label (e.g. doc_title) + - content (str) raw OCR output from vLLM + - bbox_2d (list) [x1, y1, x2, y2] in 0-1000 normalised coords + + Returns: + Markdown string. + """ + # Sort by reading order + items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0)) + + # Per-region cleaning + formatting + processed: List[Dict] = [] + for item in items: + item["native_label"] = item.get("native_label", item.get("label", "text")) + item["label"] = self._map_label(item.get("label", "text"), item["native_label"]) + + item["content"] = self._format_content( + item.get("content") or "", + item["label"], + item["native_label"], + ) + if not (item.get("content") or "").strip(): + continue + processed.append(item) + + # Re-index + for i, item in enumerate(processed): + item["index"] = i + + # Structural merges + processed = self._merge_formula_numbers(processed) + processed = self._merge_text_blocks(processed) + processed = self._format_bullet_points(processed) + + # Assemble Markdown + parts: List[str] = [] + for item in processed: + content = item.get("content") or "" + if item["label"] == "image": + parts.append(f"![](bbox={item.get('bbox_2d', [])})") + elif content.strip(): + parts.append(content) + + return "\n\n".join(parts) + + # ------------------------------------------------------------------ # + # Label mapping + # 
------------------------------------------------------------------ # + + def _map_label(self, label: str, native_label: str) -> str: + return _LABEL_TO_CATEGORY.get(native_label, _LABEL_TO_CATEGORY.get(label, "text")) + + # ------------------------------------------------------------------ # + # Content cleaning + # ------------------------------------------------------------------ # + + def _clean_content(self, content: str) -> str: + """Remove artefacts: leading/trailing \\t, repeated punctuation, long repeats.""" + if content is None: + return "" + + content = re.sub(r"^(\\t)+", "", content).lstrip() + content = re.sub(r"(\\t)+$", "", content).rstrip() + + content = re.sub(r"(\.)\1{2,}", r"\1\1\1", content) + content = re.sub(r"(·)\1{2,}", r"\1\1\1", content) + content = re.sub(r"(_)\1{2,}", r"\1\1\1", content) + content = re.sub(r"(\\_)\1{2,}", r"\1\1\1", content) + + if len(content) >= 2048: + content = clean_repeated_content(content) + + return content.strip() + + # ------------------------------------------------------------------ # + # Per-region content formatting + # ------------------------------------------------------------------ # + + def _format_content(self, content: Any, label: str, native_label: str) -> str: + """Clean and format a single region's content.""" + if content is None: + return "" + + content = self._clean_content(str(content)) + + # Heading formatting + if native_label == "doc_title": + content = re.sub(r"^#+\s*", "", content) + content = "# " + content + elif native_label == "paragraph_title": + if content.startswith("- ") or content.startswith("* "): + content = content[2:].lstrip() + content = re.sub(r"^#+\s*", "", content) + content = "## " + content.lstrip() + + # Formula wrapping + if label == "formula": + content = content.strip() + for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]: + if content.startswith(s) and content.endswith(e): + content = content[len(s) : -len(e)].strip() + break + content = "$$\n" + content + 
"\n$$" + + # Text formatting + if label == "text": + if content.startswith("·") or content.startswith("•") or content.startswith("* "): + content = "- " + content[1:].lstrip() + + match = re.match(r"^(\(|\()(\d+|[A-Za-z])(\)|\))(.*)$", content) + if match: + _, symbol, _, rest = match.groups() + content = f"({symbol}) {rest.lstrip()}" + + match = re.match(r"^(\d+|[A-Za-z])(\.|\)|\))(.*)$", content) + if match: + symbol, sep, rest = match.groups() + sep = ")" if sep == ")" else sep + content = f"{symbol}{sep} {rest.lstrip()}" + + # Single newline → double newline + content = re.sub(r"(? List[Dict]: + """Merge formula_number region into adjacent formula with \\tag{}.""" + if not items: + return items + + merged: List[Dict] = [] + skip: set = set() + + for i, block in enumerate(items): + if i in skip: + continue + + native = block.get("native_label", "") + + # Case 1: formula_number then formula + if native == "formula_number": + if i + 1 < len(items) and items[i + 1].get("label") == "formula": + num_clean = clean_formula_number(block.get("content", "").strip()) + formula_content = items[i + 1].get("content", "") + merged_block = deepcopy(items[i + 1]) + if formula_content.endswith("\n$$"): + merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" + merged.append(merged_block) + skip.add(i + 1) + continue # always skip the formula_number block itself + + # Case 2: formula then formula_number + if block.get("label") == "formula": + if i + 1 < len(items) and items[i + 1].get("native_label") == "formula_number": + num_clean = clean_formula_number(items[i + 1].get("content", "").strip()) + formula_content = block.get("content", "") + merged_block = deepcopy(block) + if formula_content.endswith("\n$$"): + merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" + merged.append(merged_block) + skip.add(i + 1) + continue + + merged.append(block) + + for i, block in enumerate(merged): + block["index"] = i + return merged + + def 
_merge_text_blocks(self, items: List[Dict]) -> List[Dict]: + """Merge hyphenated text blocks when the combined word is valid (wordfreq).""" + if not items or not _WORDFREQ_AVAILABLE: + return items + + merged: List[Dict] = [] + skip: set = set() + + for i, block in enumerate(items): + if i in skip: + continue + if block.get("label") != "text": + merged.append(block) + continue + + content = block.get("content", "") + if not isinstance(content, str) or not content.rstrip().endswith("-"): + merged.append(block) + continue + + content_stripped = content.rstrip() + did_merge = False + for j in range(i + 1, len(items)): + if items[j].get("label") != "text": + continue + next_content = items[j].get("content", "") + if not isinstance(next_content, str): + continue + next_stripped = next_content.lstrip() + if next_stripped and next_stripped[0].islower(): + words_before = content_stripped[:-1].split() + next_words = next_stripped.split() + if words_before and next_words: + merged_word = words_before[-1] + next_words[0] + if zipf_frequency(merged_word.lower(), "en") >= 2.5: + merged_block = deepcopy(block) + merged_block["content"] = content_stripped[:-1] + next_content.lstrip() + merged.append(merged_block) + skip.add(j) + did_merge = True + break + + if not did_merge: + merged.append(block) + + for i, block in enumerate(merged): + block["index"] = i + return merged + + def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]: + """Add missing bullet prefix when a text block is sandwiched between two bullet items.""" + if len(items) < 3: + return items + + for i in range(1, len(items) - 1): + cur = items[i] + prev = items[i - 1] + nxt = items[i + 1] + + if cur.get("native_label") != "text": + continue + if prev.get("native_label") != "text" or nxt.get("native_label") != "text": + continue + + cur_content = cur.get("content", "") + if cur_content.startswith("- "): + continue + + prev_content = prev.get("content", "") + nxt_content 
= nxt.get("content", "") + if not (prev_content.startswith("- ") and nxt_content.startswith("- ")): + continue + + cur_bbox = cur.get("bbox_2d", []) + prev_bbox = prev.get("bbox_2d", []) + nxt_bbox = nxt.get("bbox_2d", []) + if not (cur_bbox and prev_bbox and nxt_bbox): + continue + + if ( + abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold + and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold + ): + cur["content"] = "- " + cur_content + + return items diff --git a/app/services/layout_detector.py b/app/services/layout_detector.py index 0cdf75b..84b0647 100644 --- a/app/services/layout_detector.py +++ b/app/services/layout_detector.py @@ -1,9 +1,10 @@ -"""PP-DocLayoutV2 wrapper for document layout detection.""" +"""PP-DocLayoutV3 wrapper for document layout detection.""" import numpy as np from app.schemas.image import LayoutInfo, LayoutRegion from app.core.config import get_settings +from app.services.layout_postprocess import apply_layout_postprocess from paddleocr import LayoutDetection from typing import Optional @@ -116,6 +117,17 @@ class LayoutDetector: else: boxes = [] + # Apply GLM-OCR layout post-processing (NMS, containment, unclip, clamp) + if boxes: + h, w = image.shape[:2] + boxes = apply_layout_postprocess( + boxes, + img_size=(w, h), + layout_nms=True, + layout_unclip_ratio=None, + layout_merge_bboxes_mode="large", + ) + for box in boxes: cls_id = box.get("cls_id") label = box.get("label") or self.CLS_ID_TO_LABEL.get(cls_id, "other") @@ -128,6 +140,7 @@ class LayoutDetector: regions.append( LayoutRegion( type=region_type, + native_label=label, bbox=coordinate, confidence=score, score=score, diff --git a/app/services/layout_postprocess.py b/app/services/layout_postprocess.py new file mode 100644 index 0000000..2eb9a7a --- /dev/null +++ b/app/services/layout_postprocess.py @@ -0,0 +1,343 @@ +"""Layout post-processing utilities ported from GLM-OCR. 
+ +Source: glm-ocr/glmocr/utils/layout_postprocess_utils.py + +Algorithms applied after PaddleOCR LayoutDetection.predict(): + 1. NMS with dual IoU thresholds (same-class vs cross-class) + 2. Large-image-region filtering (remove image boxes that fill most of the page) + 3. Containment analysis (merge_bboxes_mode: keep large parent, remove contained child) + 4. Unclip ratio (optional bbox expansion) + 5. Invalid bbox skipping + +These steps run on top of PaddleOCR's built-in detection to replicate +the quality of the GLM-OCR SDK's layout pipeline. +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + + +# --------------------------------------------------------------------------- +# Primitive geometry helpers +# --------------------------------------------------------------------------- + +def iou(box1: List[float], box2: List[float]) -> float: + """Compute IoU of two bounding boxes [x1, y1, x2, y2].""" + x1, y1, x2, y2 = box1 + x1_p, y1_p, x2_p, y2_p = box2 + + x1_i = max(x1, x1_p) + y1_i = max(y1, y1_p) + x2_i = min(x2, x2_p) + y2_i = min(y2, y2_p) + + inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1) + box1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1) + + return inter_area / float(box1_area + box2_area - inter_area) + + +def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool: + """Return True if box1 is contained within box2 (overlap ratio >= threshold). 
+ + box format: [cls_id, score, x1, y1, x2, y2] + """ + _, _, x1, y1, x2, y2 = box1 + _, _, x1_p, y1_p, x2_p, y2_p = box2 + + box1_area = (x2 - x1) * (y2 - y1) + if box1_area <= 0: + return False + + xi1 = max(x1, x1_p) + yi1 = max(y1, y1_p) + xi2 = min(x2, x2_p) + yi2 = min(y2, y2_p) + inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1) + + return (inter_area / box1_area) >= overlap_threshold + + +# --------------------------------------------------------------------------- +# NMS +# --------------------------------------------------------------------------- + +def nms( + boxes: np.ndarray, + iou_same: float = 0.6, + iou_diff: float = 0.98, +) -> List[int]: + """NMS with separate IoU thresholds for same-class and cross-class overlaps. + + Args: + boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...]. + iou_same: Suppression threshold for boxes of the same class. + iou_diff: Suppression threshold for boxes of different classes. + + Returns: + List of kept row indices. + """ + scores = boxes[:, 1] + indices = np.argsort(scores)[::-1].tolist() + selected: List[int] = [] + + while indices: + current = indices[0] + selected.append(current) + current_class = int(boxes[current, 0]) + current_coords = boxes[current, 2:6].tolist() + indices = indices[1:] + + kept = [] + for i in indices: + box_class = int(boxes[i, 0]) + box_coords = boxes[i, 2:6].tolist() + threshold = iou_same if current_class == box_class else iou_diff + if iou(current_coords, box_coords) < threshold: + kept.append(i) + indices = kept + + return selected + + +# --------------------------------------------------------------------------- +# Containment analysis +# --------------------------------------------------------------------------- + +# Labels whose regions should never be removed even when contained in another box +_PRESERVE_LABELS = {"image", "seal", "chart"} + + +def check_containment( + boxes: np.ndarray, + preserve_cls_ids: Optional[set] = None, + category_index: Optional[int] = 
None, + mode: Optional[str] = None, +) -> Tuple[np.ndarray, np.ndarray]: + """Compute containment flags for each box. + + Args: + boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...]. + preserve_cls_ids: Class IDs that must never be marked as contained. + category_index: If set, apply mode only relative to this class. + mode: 'large' or 'small' (only used with category_index). + + Returns: + (contains_other, contained_by_other): boolean arrays of length N. + """ + n = len(boxes) + contains_other = np.zeros(n, dtype=int) + contained_by_other = np.zeros(n, dtype=int) + + for i in range(n): + for j in range(n): + if i == j: + continue + if preserve_cls_ids and int(boxes[i, 0]) in preserve_cls_ids: + continue + if category_index is not None and mode is not None: + if mode == "large" and int(boxes[j, 0]) == category_index: + if is_contained(boxes[i].tolist(), boxes[j].tolist()): + contained_by_other[i] = 1 + contains_other[j] = 1 + elif mode == "small" and int(boxes[i, 0]) == category_index: + if is_contained(boxes[i].tolist(), boxes[j].tolist()): + contained_by_other[i] = 1 + contains_other[j] = 1 + else: + if is_contained(boxes[i].tolist(), boxes[j].tolist()): + contained_by_other[i] = 1 + contains_other[j] = 1 + + return contains_other, contained_by_other + + +# --------------------------------------------------------------------------- +# Box expansion (unclip) +# --------------------------------------------------------------------------- + +def unclip_boxes( + boxes: np.ndarray, + unclip_ratio: Union[float, Tuple[float, float], Dict, List, None], +) -> np.ndarray: + """Expand bounding boxes by the given ratio. + + Args: + boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...]. + unclip_ratio: Scalar, (w_ratio, h_ratio) tuple, or dict mapping cls_id to ratio. + + Returns: + Expanded boxes array. 
+ """ + if unclip_ratio is None: + return boxes + + if isinstance(unclip_ratio, dict): + expanded = [] + for box in boxes: + cls_id = int(box[0]) + if cls_id in unclip_ratio: + w_ratio, h_ratio = unclip_ratio[cls_id] + x1, y1, x2, y2 = box[2], box[3], box[4], box[5] + cx, cy = (x1 + x2) / 2, (y1 + y2) / 2 + nw, nh = (x2 - x1) * w_ratio, (y2 - y1) * h_ratio + new_box = list(box) + new_box[2], new_box[3] = cx - nw / 2, cy - nh / 2 + new_box[4], new_box[5] = cx + nw / 2, cy + nh / 2 + expanded.append(new_box) + else: + expanded.append(list(box)) + return np.array(expanded) + + # Scalar or tuple + if isinstance(unclip_ratio, (int, float)): + unclip_ratio = (float(unclip_ratio), float(unclip_ratio)) + + w_ratio, h_ratio = unclip_ratio[0], unclip_ratio[1] + widths = boxes[:, 4] - boxes[:, 2] + heights = boxes[:, 5] - boxes[:, 3] + cx = boxes[:, 2] + widths / 2 + cy = boxes[:, 3] + heights / 2 + nw, nh = widths * w_ratio, heights * h_ratio + expanded = boxes.copy().astype(float) + expanded[:, 2] = cx - nw / 2 + expanded[:, 3] = cy - nh / 2 + expanded[:, 4] = cx + nw / 2 + expanded[:, 5] = cy + nh / 2 + return expanded + + +# --------------------------------------------------------------------------- +# Main entry-point +# --------------------------------------------------------------------------- + +def apply_layout_postprocess( + boxes: List[Dict], + img_size: Tuple[int, int], + layout_nms: bool = True, + layout_unclip_ratio: Union[float, Tuple, Dict, None] = None, + layout_merge_bboxes_mode: Union[str, Dict, None] = "large", +) -> List[Dict]: + """Apply GLM-OCR layout post-processing to PaddleOCR detection results. + + Args: + boxes: PaddleOCR output — list of dicts with keys: + cls_id, label, score, coordinate ([x1, y1, x2, y2]). + img_size: (width, height) of the image. + layout_nms: Apply dual-threshold NMS. + layout_unclip_ratio: Optional bbox expansion ratio. + layout_merge_bboxes_mode: Containment mode — 'large' (default), 'small', + 'union', or per-class dict. 
+ + Returns: + Filtered and ordered list of box dicts in the same PaddleOCR format. + """ + if not boxes: + return boxes + + img_width, img_height = img_size + + # --- Build working array [cls_id, score, x1, y1, x2, y2] -------------- # + arr_rows = [] + for b in boxes: + cls_id = b.get("cls_id", 0) + score = b.get("score", 0.0) + x1, y1, x2, y2 = b.get("coordinate", [0, 0, 0, 0]) + arr_rows.append([cls_id, score, x1, y1, x2, y2]) + boxes_array = np.array(arr_rows, dtype=float) + + all_labels: List[str] = [b.get("label", "") for b in boxes] + + # 1. NMS ---------------------------------------------------------------- # + if layout_nms and len(boxes_array) > 1: + kept = nms(boxes_array, iou_same=0.6, iou_diff=0.98) + boxes_array = boxes_array[kept] + all_labels = [all_labels[k] for k in kept] + + # 2. Filter large image regions ---------------------------------------- # + if len(boxes_array) > 1: + img_area = img_width * img_height + area_thres = 0.82 if img_width > img_height else 0.93 + image_cls_ids = { + int(boxes_array[i, 0]) + for i, lbl in enumerate(all_labels) + if lbl == "image" + } + keep_mask = np.ones(len(boxes_array), dtype=bool) + for i, lbl in enumerate(all_labels): + if lbl == "image": + x1, y1, x2, y2 = boxes_array[i, 2:6] + x1 = max(0.0, x1); y1 = max(0.0, y1) + x2 = min(float(img_width), x2); y2 = min(float(img_height), y2) + if (x2 - x1) * (y2 - y1) > area_thres * img_area: + keep_mask[i] = False + boxes_array = boxes_array[keep_mask] + all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k] + + # 3. 
Containment analysis (merge_bboxes_mode) -------------------------- # + if layout_merge_bboxes_mode and len(boxes_array) > 1: + preserve_cls_ids = { + int(boxes_array[i, 0]) + for i, lbl in enumerate(all_labels) + if lbl in _PRESERVE_LABELS + } + + if isinstance(layout_merge_bboxes_mode, str): + mode = layout_merge_bboxes_mode + if mode in ("large", "small"): + contains_other, contained_by_other = check_containment( + boxes_array, preserve_cls_ids + ) + if mode == "large": + keep_mask = contained_by_other == 0 + else: + keep_mask = (contains_other == 0) | (contained_by_other == 1) + boxes_array = boxes_array[keep_mask] + all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k] + + elif isinstance(layout_merge_bboxes_mode, dict): + keep_mask = np.ones(len(boxes_array), dtype=bool) + for category_index, mode in layout_merge_bboxes_mode.items(): + if mode in ("large", "small"): + contains_other, contained_by_other = check_containment( + boxes_array, preserve_cls_ids, int(category_index), mode + ) + if mode == "large": + keep_mask &= contained_by_other == 0 + else: + keep_mask &= (contains_other == 0) | (contained_by_other == 1) + boxes_array = boxes_array[keep_mask] + all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k] + + if len(boxes_array) == 0: + return [] + + # 4. Unclip (bbox expansion) ------------------------------------------- # + if layout_unclip_ratio is not None: + boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio) + + # 5. 
Clamp to image boundaries + skip invalid -------------------------- # + result: List[Dict] = [] + for i, row in enumerate(boxes_array): + cls_id = int(row[0]) + score = float(row[1]) + x1 = max(0.0, min(float(row[2]), img_width)) + y1 = max(0.0, min(float(row[3]), img_height)) + x2 = max(0.0, min(float(row[4]), img_width)) + y2 = max(0.0, min(float(row[5]), img_height)) + + if x1 >= x2 or y1 >= y2: + continue + + result.append({ + "cls_id": cls_id, + "label": all_labels[i], + "score": score, + "coordinate": [int(x1), int(y1), int(x2), int(y2)], + }) + + return result diff --git a/app/services/ocr_service.py b/app/services/ocr_service.py index 1465c93..28de285 100644 --- a/app/services/ocr_service.py +++ b/app/services/ocr_service.py @@ -1,19 +1,23 @@ """PaddleOCR-VL client service for text and formula recognition.""" -import re -import numpy as np -import cv2 -import requests -from io import BytesIO import base64 -from app.core.config import get_settings -from paddleocr import PaddleOCRVL -from typing import Optional -from app.services.layout_detector import LayoutDetector -from app.services.image_processor import ImageProcessor -from app.services.converter import Converter +import re from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from io import BytesIO + +import cv2 +import numpy as np +import requests from openai import OpenAI +from paddleocr import PaddleOCRVL +from PIL import Image as PILImage + +from app.core.config import get_settings +from app.services.converter import Converter +from app.services.glm_postprocess import GLMResultFormatter +from app.services.image_processor import ImageProcessor +from app.services.layout_detector import LayoutDetector settings = get_settings() @@ -144,7 +148,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str: # Strategy: remove spaces before \ and between non-command chars, # but preserve the space after \command when followed by a non-\ char cleaned = 
re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd - cleaned = re.sub(r"(? str: + def _recognize_formula_with_paddleocr_vl( + self, image: np.ndarray, prompt: str = "Formula Recognition:" + ) -> str: """Recognize formula using PaddleOCR-VL API. Args: @@ -634,7 +650,15 @@ class MineruOCRService(OCRServiceBase): image_url = f"data:image/png;base64,{image_base64}" # Call OpenAI-compatible API - messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": prompt}, + ], + } + ] response = self.openai_client.chat.completions.create( model="glm-ocr", @@ -647,7 +671,9 @@ class MineruOCRService(OCRServiceBase): except Exception as e: raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e - def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str: + def _extract_and_recognize_formulas( + self, markdown_content: str, original_image: np.ndarray + ) -> str: """Extract image references from markdown and recognize formulas. 
Args: @@ -712,7 +738,13 @@ class MineruOCRService(OCRServiceBase): } # Make API request - response = requests.post(self.api_url, files=files, data=data, headers={"accept": "application/json"}, timeout=30) + response = requests.post( + self.api_url, + files=files, + data=data, + headers={"accept": "application/json"}, + timeout=30, + ) response.raise_for_status() result = response.json() @@ -723,7 +755,9 @@ class MineruOCRService(OCRServiceBase): markdown_content = result["results"]["image"].get("md_content", "") if "![](images/" in markdown_content: - markdown_content = self._extract_and_recognize_formulas(markdown_content, original_image) + markdown_content = self._extract_and_recognize_formulas( + markdown_content, original_image + ) # Apply postprocessing to fix OCR errors markdown_content = _postprocess_markdown(markdown_content) @@ -751,6 +785,167 @@ class MineruOCRService(OCRServiceBase): raise RuntimeError(f"Recognition failed: {e}") from e +# Task-specific prompts (from GLM-OCR SDK config.yaml) +_TASK_PROMPTS: dict[str, str] = { + "text": "Text Recognition:", + "formula": "Formula Recognition:", + "table": "Table Recognition:", +} +_DEFAULT_PROMPT = ( + "Recognize the text in the image and output in Markdown format. " + "Preserve the original layout (headings/paragraphs/tables/formulas). " + "Do not fabricate content that does not exist in the image." +) + + +class GLMOCREndToEndService(OCRServiceBase): + """End-to-end OCR using GLM-OCR pipeline: layout detection → per-region OCR. + + Pipeline: + 1. Add padding (ImageProcessor) + 2. Detect layout regions (LayoutDetector → PP-DocLayoutV3) + 3. Crop each region and call vLLM with a task-specific prompt (parallel) + 4. GLMResultFormatter: clean, format titles/bullets/formulas, merge tags + 5. _postprocess_markdown: LaTeX math error correction + 6. Converter: markdown → latex/mathml/mml + + This replaces both GLMOCRService (formula-only) and MineruOCRService (mixed). 
+ """ + + def __init__( + self, + vl_server_url: str, + image_processor: ImageProcessor, + converter: Converter, + layout_detector: LayoutDetector, + max_workers: int = 8, + ): + self.vl_server_url = vl_server_url or settings.glm_ocr_url + self.image_processor = image_processor + self.converter = converter + self.layout_detector = layout_detector + self.max_workers = max_workers + self.openai_client = OpenAI(api_key="EMPTY", base_url=self.vl_server_url, timeout=3600) + self._formatter = GLMResultFormatter() + + def _encode_region(self, image: np.ndarray) -> str: + """Convert BGR numpy array to base64 JPEG string.""" + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + pil_img = PILImage.fromarray(rgb) + buf = BytesIO() + pil_img.save(buf, format="JPEG") + return base64.b64encode(buf.getvalue()).decode("utf-8") + + def _call_vllm(self, image: np.ndarray, prompt: str) -> str: + """Send image + prompt to vLLM and return raw content string.""" + img_b64 = self._encode_region(image) + data_url = f"data:image/jpeg;base64,{img_b64}" + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "text", "text": prompt}, + ], + } + ] + response = self.openai_client.chat.completions.create( + model="glm-ocr", + messages=messages, + temperature=0.01, + max_tokens=settings.max_tokens, + ) + return response.choices[0].message.content.strip() + + def _normalize_bbox(self, bbox: list[float], img_w: int, img_h: int) -> list[int]: + """Convert pixel bbox [x1,y1,x2,y2] to 0-1000 normalised coords.""" + x1, y1, x2, y2 = bbox + return [ + int(x1 / img_w * 1000), + int(y1 / img_h * 1000), + int(x2 / img_w * 1000), + int(y2 / img_h * 1000), + ] + + def recognize(self, image: np.ndarray) -> dict: + """Full pipeline: padding → layout → per-region OCR → postprocess → markdown. + + Args: + image: Input image as numpy array in BGR format. + + Returns: + Dict with 'markdown', 'latex', 'mathml', 'mml' keys. + """ + # 1. 
Padding + padded = self.image_processor.add_padding(image) + img_h, img_w = padded.shape[:2] + + # 2. Layout detection + layout_info = self.layout_detector.detect(padded) + + # 3. OCR: per-region (parallel) or full-image fallback + if not layout_info.regions: + raw_content = self._call_vllm(padded, _DEFAULT_PROMPT) + markdown_content = self._formatter._clean_content(raw_content) + else: + # Build task list for non-figure regions + tasks = [] + for idx, region in enumerate(layout_info.regions): + if region.type == "figure": + continue + x1, y1, x2, y2 = (int(c) for c in region.bbox) + cropped = padded[y1:y2, x1:x2] + if cropped.size == 0: + continue + prompt = _TASK_PROMPTS.get(region.type, _DEFAULT_PROMPT) + tasks.append((idx, region, cropped, prompt)) + + if not tasks: + raw_content = self._call_vllm(padded, _DEFAULT_PROMPT) + markdown_content = self._formatter._clean_content(raw_content) + else: + # Parallel OCR calls + raw_results: dict[int, str] = {} + with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex: + future_map = { + ex.submit(self._call_vllm, cropped, prompt): idx + for idx, region, cropped, prompt in tasks + } + for future in as_completed(future_map): + idx = future_map[future] + try: + raw_results[idx] = future.result() + except Exception: + raw_results[idx] = "" + + # Build structured region dicts for GLMResultFormatter + region_dicts = [] + for idx, region, _cropped, _prompt in tasks: + region_dicts.append( + { + "index": idx, + "label": region.type, + "native_label": region.native_label, + "content": raw_results.get(idx, ""), + "bbox_2d": self._normalize_bbox(region.bbox, img_w, img_h), + } + ) + + # 4. GLM-OCR postprocessing: clean, format, merge, bullets + markdown_content = self._formatter.process(region_dicts) + + # 5. LaTeX math error correction (our existing pipeline) + markdown_content = _postprocess_markdown(markdown_content) + + # 6. 
Format conversion + latex, mathml, mml = "", "", "" + if markdown_content and self.converter: + fmt = self.converter.convert_to_formats(markdown_content) + latex, mathml, mml = fmt.latex, fmt.mathml, fmt.mml + + return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml} + + if __name__ == "__main__": mineru_service = MineruOCRService() image = cv2.imread("test/formula2.jpg") diff --git a/pyproject.toml b/pyproject.toml index 9c57815..13d8cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "safetensors", "lxml>=5.0.0", "openai", + "wordfreq", ] # [tool.uv.sources] diff --git a/tests/api/v1/endpoints/test_image_endpoint.py b/tests/api/v1/endpoints/test_image_endpoint.py new file mode 100644 index 0000000..5868c05 --- /dev/null +++ b/tests/api/v1/endpoints/test_image_endpoint.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app.api.v1.endpoints.image import router +from app.core.dependencies import get_glmocr_endtoend_service, get_image_processor + + +class _FakeImageProcessor: + def preprocess(self, image_url=None, image_base64=None): + return np.zeros((8, 8, 3), dtype=np.uint8) + + +class _FakeOCRService: + def __init__(self, result=None, error=None): + self._result = result or {"markdown": "md", "latex": "tex", "mathml": "mml", "mml": "xml"} + self._error = error + + def recognize(self, image): + if self._error: + raise self._error + return self._result + + +def _build_client(image_processor=None, ocr_service=None): + app = FastAPI() + app.include_router(router) + app.dependency_overrides[get_image_processor] = lambda: image_processor or _FakeImageProcessor() + app.dependency_overrides[get_glmocr_endtoend_service] = lambda: ocr_service or _FakeOCRService() + return TestClient(app) + + +def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64(): + client = _build_client() + + missing = client.post("/ocr", 
json={}) + both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}) + + assert missing.status_code == 422 + assert both.status_code == 422 + + +def test_image_endpoint_returns_503_for_runtime_error(): + client = _build_client(ocr_service=_FakeOCRService(error=RuntimeError("backend unavailable"))) + + response = client.post("/ocr", json={"image_url": "https://example.com/a.png"}) + + assert response.status_code == 503 + assert response.json()["detail"] == "backend unavailable" + + +def test_image_endpoint_returns_500_for_unexpected_error(): + client = _build_client(ocr_service=_FakeOCRService(error=ValueError("boom"))) + + response = client.post("/ocr", json={"image_url": "https://example.com/a.png"}) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal server error" + + +def test_image_endpoint_returns_ocr_payload(): + client = _build_client() + + response = client.post("/ocr", json={"image_base64": "ZmFrZQ=="}) + + assert response.status_code == 200 + assert response.json() == { + "latex": "tex", + "markdown": "md", + "mathml": "mml", + "mml": "xml", + "layout_info": {"regions": [], "MixedRecognition": False}, + "recognition_mode": "", + } + + +def test_image_endpoint_real_e2e_with_env_services(): + from app.main import app + + image_url = ( + "https://static.texpixel.com/formula/012dab3e-fb31-4ecd-90fc-6957458ee309.png" + "?Expires=1773049821&OSSAccessKeyId=TMP.3KnrJUz7aXHoU9rLTAih4MAyPGd9zyGRHiqg9AyH6TY6NKtzqT2yr4qo7Vwf8fMRFCBrWXiCFrbBwC3vn7U6mspV2NeU1K" + "&Signature=oynhP0OLIgFI0Sv3z2CWeHPT2Ck%3D" + ) + + with TestClient(app) as client: + response = client.post( + "/doc_process/v1/image/ocr", + json={"image_url": image_url}, + headers={"x-request-id": "test-e2e"}, + ) + + assert response.status_code == 200, response.text + payload = response.json() + assert isinstance(payload["markdown"], str) + assert payload["markdown"].strip() + assert set(payload) >= {"markdown", "latex", "mathml", 
"mml"} diff --git a/tests/core/test_dependencies.py b/tests/core/test_dependencies.py new file mode 100644 index 0000000..0af1474 --- /dev/null +++ b/tests/core/test_dependencies.py @@ -0,0 +1,10 @@ +import pytest + +from app.core import dependencies + + +def test_get_glmocr_endtoend_service_raises_when_layout_detector_missing(monkeypatch): + monkeypatch.setattr(dependencies, "_layout_detector", None) + + with pytest.raises(RuntimeError, match="Layout detector not initialized"): + dependencies.get_glmocr_endtoend_service() diff --git a/tests/schemas/test_image.py b/tests/schemas/test_image.py new file mode 100644 index 0000000..1613b0c --- /dev/null +++ b/tests/schemas/test_image.py @@ -0,0 +1,31 @@ +from app.schemas.image import ImageOCRRequest, LayoutRegion + + +def test_layout_region_native_label_defaults_to_empty_string(): + region = LayoutRegion( + type="text", + bbox=[0, 0, 10, 10], + confidence=0.9, + score=0.9, + ) + + assert region.native_label == "" + + +def test_layout_region_exposes_native_label_when_provided(): + region = LayoutRegion( + type="text", + native_label="doc_title", + bbox=[0, 0, 10, 10], + confidence=0.9, + score=0.9, + ) + + assert region.native_label == "doc_title" + + +def test_image_ocr_request_requires_exactly_one_input(): + request = ImageOCRRequest(image_url="https://example.com/test.png") + + assert request.image_url == "https://example.com/test.png" + assert request.image_base64 is None diff --git a/tests/services/test_glm_postprocess.py b/tests/services/test_glm_postprocess.py new file mode 100644 index 0000000..1f241bc --- /dev/null +++ b/tests/services/test_glm_postprocess.py @@ -0,0 +1,199 @@ +from app.services.glm_postprocess import ( + GLMResultFormatter, + clean_formula_number, + clean_repeated_content, + find_consecutive_repeat, +) + + +def test_find_consecutive_repeat_truncates_when_threshold_met(): + repeated = "abcdefghij" * 10 + "tail" + + assert find_consecutive_repeat(repeated) == "abcdefghij" + + +def 
test_find_consecutive_repeat_returns_none_when_below_threshold(): + assert find_consecutive_repeat("abcdefghij" * 9) is None + + +def test_clean_repeated_content_handles_consecutive_and_line_level_repeats(): + assert clean_repeated_content("abcdefghij" * 10 + "tail") == "abcdefghij" + + line_repeated = "\n".join(["same line"] * 10 + ["other"]) + assert clean_repeated_content(line_repeated, line_threshold=10) == "same line\n" + + assert clean_repeated_content("normal text") == "normal text" + + +def test_clean_formula_number_strips_wrapping_parentheses(): + assert clean_formula_number("(1)") == "1" + assert clean_formula_number("(2.1)") == "2.1" + assert clean_formula_number("3") == "3" + + +def test_clean_content_removes_literal_tabs_and_long_repeat_noise(): + formatter = GLMResultFormatter() + noisy = r"\t\t" + ("·" * 5) + ("abcdefghij" * 205) + r"\t" + + cleaned = formatter._clean_content(noisy) + + assert cleaned.startswith("···") + assert cleaned.endswith("abcdefghij") + assert r"\t" not in cleaned + + +def test_format_content_handles_titles_formula_text_and_newlines(): + formatter = GLMResultFormatter() + + assert formatter._format_content("Intro", "text", "doc_title") == "# Intro" + assert formatter._format_content("- Section", "text", "paragraph_title") == "## Section" + assert formatter._format_content(r"\[x+y\]", "formula", "display_formula") == "$$\nx+y\n$$" + assert formatter._format_content("· item\nnext", "text", "text") == "- item\n\nnext" + + +def test_merge_formula_numbers_merges_before_and_after_formula(): + formatter = GLMResultFormatter() + + before = formatter._merge_formula_numbers( + [ + {"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"}, + {"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"}, + ] + ) + after = formatter._merge_formula_numbers( + [ + {"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"}, + {"index": 1, "label": 
"text", "native_label": "formula_number", "content": "(2)"}, + ] + ) + untouched = formatter._merge_formula_numbers( + [{"index": 0, "label": "text", "native_label": "formula_number", "content": "(3)"}] + ) + + assert before == [ + { + "index": 0, + "label": "formula", + "native_label": "display_formula", + "content": "$$\nx+y \\tag{1}\n$$", + } + ] + assert after == [ + { + "index": 0, + "label": "formula", + "native_label": "display_formula", + "content": "$$\nx+y \\tag{2}\n$$", + } + ] + assert untouched == [] + + +def test_merge_text_blocks_joins_hyphenated_words_when_wordfreq_accepts(monkeypatch): + formatter = GLMResultFormatter() + + monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True) + monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 3.0) + + merged = formatter._merge_text_blocks( + [ + {"index": 0, "label": "text", "native_label": "text", "content": "inter-"}, + {"index": 1, "label": "text", "native_label": "text", "content": "national"}, + ] + ) + + assert merged == [ + {"index": 0, "label": "text", "native_label": "text", "content": "international"} + ] + + +def test_merge_text_blocks_skips_invalid_merge(monkeypatch): + formatter = GLMResultFormatter() + + monkeypatch.setattr("app.services.glm_postprocess._WORDFREQ_AVAILABLE", True) + monkeypatch.setattr("app.services.glm_postprocess.zipf_frequency", lambda word, lang: 1.0) + + merged = formatter._merge_text_blocks( + [ + {"index": 0, "label": "text", "native_label": "text", "content": "inter-"}, + {"index": 1, "label": "text", "native_label": "text", "content": "National"}, + ] + ) + + assert merged == [ + {"index": 0, "label": "text", "native_label": "text", "content": "inter-"}, + {"index": 1, "label": "text", "native_label": "text", "content": "National"}, + ] + + +def test_format_bullet_points_infers_missing_middle_bullet(): + formatter = GLMResultFormatter() + items = [ + {"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 
50, 10]}, + {"native_label": "text", "content": "second", "bbox_2d": [12, 12, 52, 22]}, + {"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]}, + ] + + formatted = formatter._format_bullet_points(items) + + assert formatted[1]["content"] == "- second" + + +def test_format_bullet_points_skips_when_bbox_missing(): + formatter = GLMResultFormatter() + items = [ + {"native_label": "text", "content": "- first", "bbox_2d": [10, 0, 50, 10]}, + {"native_label": "text", "content": "second", "bbox_2d": []}, + {"native_label": "text", "content": "- third", "bbox_2d": [11, 24, 51, 34]}, + ] + + formatted = formatter._format_bullet_points(items) + + assert formatted[1]["content"] == "second" + + +def test_process_runs_full_pipeline_and_skips_empty_content(): + formatter = GLMResultFormatter() + regions = [ + { + "index": 0, + "label": "text", + "native_label": "doc_title", + "content": "Doc Title", + "bbox_2d": [0, 0, 100, 30], + }, + { + "index": 1, + "label": "text", + "native_label": "formula_number", + "content": "(1)", + "bbox_2d": [80, 50, 100, 60], + }, + { + "index": 2, + "label": "formula", + "native_label": "display_formula", + "content": "x+y", + "bbox_2d": [0, 40, 100, 80], + }, + { + "index": 3, + "label": "figure", + "native_label": "image", + "content": "figure placeholder", + "bbox_2d": [0, 80, 100, 120], + }, + { + "index": 4, + "label": "text", + "native_label": "text", + "content": "", + "bbox_2d": [0, 120, 100, 150], + }, + ] + + output = formatter.process(regions) + + assert "# Doc Title" in output + assert "$$\nx+y \\tag{1}\n$$" in output + assert "![](bbox=[0, 80, 100, 120])" in output diff --git a/tests/services/test_layout_detector.py b/tests/services/test_layout_detector.py new file mode 100644 index 0000000..db8584a --- /dev/null +++ b/tests/services/test_layout_detector.py @@ -0,0 +1,46 @@ +import numpy as np + +from app.services.layout_detector import LayoutDetector + + +class _FakePredictor: + def __init__(self, boxes): + 
self._boxes = boxes + + def predict(self, image): + return [{"boxes": self._boxes}] + + +def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch): + raw_boxes = [ + {"cls_id": 22, "label": "text", "score": 0.95, "coordinate": [0, 0, 100, 100]}, + {"cls_id": 22, "label": "text", "score": 0.90, "coordinate": [10, 10, 20, 20]}, + {"cls_id": 6, "label": "doc_title", "score": 0.99, "coordinate": [0, 0, 80, 20]}, + ] + + detector = LayoutDetector.__new__(LayoutDetector) + detector._get_layout_detector = lambda: _FakePredictor(raw_boxes) + + calls = {} + + def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode): + calls["args"] = { + "boxes": boxes, + "img_size": img_size, + "layout_nms": layout_nms, + "layout_unclip_ratio": layout_unclip_ratio, + "layout_merge_bboxes_mode": layout_merge_bboxes_mode, + } + return [boxes[0], boxes[2]] + + monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess) + + image = np.zeros((200, 100, 3), dtype=np.uint8) + info = detector.detect(image) + + assert calls["args"]["img_size"] == (100, 200) + assert calls["args"]["layout_nms"] is True + assert calls["args"]["layout_merge_bboxes_mode"] == "large" + assert [region.native_label for region in info.regions] == ["text", "doc_title"] + assert [region.type for region in info.regions] == ["text", "text"] + assert info.MixedRecognition is True diff --git a/tests/services/test_layout_postprocess.py b/tests/services/test_layout_postprocess.py new file mode 100644 index 0000000..be32f29 --- /dev/null +++ b/tests/services/test_layout_postprocess.py @@ -0,0 +1,151 @@ +import math + +import numpy as np + +from app.services.layout_postprocess import ( + apply_layout_postprocess, + check_containment, + iou, + is_contained, + nms, + unclip_boxes, +) + + +def _raw_box(cls_id, score, x1, y1, x2, y2, label="text"): + return { + "cls_id": cls_id, + "label": label, + "score": score, + 
"coordinate": [x1, y1, x2, y2], + } + + +def test_iou_handles_full_none_and_partial_overlap(): + assert iou([0, 0, 9, 9], [0, 0, 9, 9]) == 1.0 + assert iou([0, 0, 9, 9], [20, 20, 29, 29]) == 0.0 + assert math.isclose(iou([0, 0, 9, 9], [5, 5, 14, 14]), 1 / 7, rel_tol=1e-6) + + +def test_nms_keeps_highest_score_for_same_class_overlap(): + boxes = np.array( + [ + [0, 0.95, 0, 0, 10, 10], + [0, 0.80, 1, 1, 11, 11], + ], + dtype=float, + ) + + kept = nms(boxes, iou_same=0.6, iou_diff=0.98) + + assert kept == [0] + + +def test_nms_keeps_cross_class_overlap_boxes_below_diff_threshold(): + boxes = np.array( + [ + [0, 0.95, 0, 0, 10, 10], + [1, 0.90, 1, 1, 11, 11], + ], + dtype=float, + ) + + kept = nms(boxes, iou_same=0.6, iou_diff=0.98) + + assert kept == [0, 1] + + +def test_nms_returns_single_box_index(): + boxes = np.array([[0, 0.95, 0, 0, 10, 10]], dtype=float) + + assert nms(boxes) == [0] + + +def test_is_contained_uses_overlap_threshold(): + outer = [0, 0.9, 0, 0, 10, 10] + inner = [0, 0.9, 2, 2, 8, 8] + partial = [0, 0.9, 6, 6, 12, 12] + + assert is_contained(inner, outer) is True + assert is_contained(partial, outer) is False + assert is_contained(partial, outer, overlap_threshold=0.3) is True + + +def test_check_containment_respects_preserve_class_ids(): + boxes = np.array( + [ + [0, 0.9, 0, 0, 100, 100], + [1, 0.8, 10, 10, 30, 30], + [2, 0.7, 15, 15, 25, 25], + ], + dtype=float, + ) + + contains_other, contained_by_other = check_containment(boxes, preserve_cls_ids={1}) + + assert contains_other.tolist() == [1, 1, 0] + assert contained_by_other.tolist() == [0, 0, 1] + + +def test_unclip_boxes_supports_scalar_tuple_dict_and_none(): + boxes = np.array( + [ + [0, 0.9, 10, 10, 20, 20], + [1, 0.8, 30, 30, 50, 40], + ], + dtype=float, + ) + + scalar = unclip_boxes(boxes, 2.0) + assert scalar[:, 2:6].tolist() == [[5.0, 5.0, 25.0, 25.0], [20.0, 25.0, 60.0, 45.0]] + + tuple_ratio = unclip_boxes(boxes, (2.0, 3.0)) + assert tuple_ratio[:, 2:6].tolist() == [[5.0, 0.0, 25.0, 
30.0], [20.0, 20.0, 60.0, 50.0]]
+
+    per_class = unclip_boxes(boxes, {1: (1.5, 2.0)})
+    assert per_class[:, 2:6].tolist() == [[10.0, 10.0, 20.0, 20.0], [25.0, 25.0, 55.0, 45.0]]
+
+    assert np.array_equal(unclip_boxes(boxes, None), boxes)
+
+
+def test_apply_layout_postprocess_large_mode_removes_contained_small_box():
+    boxes = [
+        _raw_box(0, 0.95, 0, 0, 100, 100, "text"),  # outer box
+        _raw_box(0, 0.90, 10, 10, 20, 20, "text"),  # same label, coordinates inside the outer box
+    ]
+
+    result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")
+
+    assert [box["coordinate"] for box in result] == [[0, 0, 100, 100]]  # only the outer box remains
+
+
+def test_apply_layout_postprocess_preserves_contained_image_like_boxes():
+    boxes = [
+        _raw_box(0, 0.95, 0, 0, 100, 100, "text"),  # enclosing text box
+        _raw_box(1, 0.90, 10, 10, 20, 20, "image"),  # the next three lie inside the text box
+        _raw_box(2, 0.90, 25, 25, 35, 35, "seal"),
+        _raw_box(3, 0.90, 40, 40, 50, 50, "chart"),
+    ]
+
+    result = apply_layout_postprocess(boxes, img_size=(120, 120), layout_merge_bboxes_mode="large")
+
+    assert {box["label"] for box in result} == {"text", "image", "seal", "chart"}  # all four labels kept
+
+
+def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image():
+    boxes = [
+        _raw_box(0, 0.95, -10, -5, 40, 50, "text"),  # negative corner; expected back as [0, 0, 40, 50]
+        _raw_box(1, 0.90, 10, 10, 10, 50, "text"),  # x1 == x2 (zero width); absent from the expected result
+        _raw_box(2, 0.85, 0, 0, 100, 90, "image"),  # same extent as img_size; absent from the expected result
+    ]
+
+    result = apply_layout_postprocess(
+        boxes,
+        img_size=(100, 90),
+        layout_nms=False,  # NMS and box merging are both switched off for this case
+        layout_merge_bboxes_mode=None,
+    )
+
+    assert result == [  # exactly one box, compared as a full dict
+        {"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
+    ]
diff --git a/tests/services/test_ocr_service.py b/tests/services/test_ocr_service.py
new file mode 100644
index 0000000..d57b451
--- /dev/null
+++ b/tests/services/test_ocr_service.py
@@ -0,0 +1,124 @@
+import base64
+from types import SimpleNamespace
+
+import cv2
+import numpy as np
+
+from app.schemas.image import LayoutInfo, LayoutRegion
+from app.services.ocr_service import GLMOCREndToEndService
+
+
+class _FakeConverter:  # test double passed as `converter` in _build_service
+    def convert_to_formats(self, markdown):
+        return SimpleNamespace(  # each field echoes the input behind a distinct prefix
+            latex=f"LATEX::{markdown}",
+            mathml=f"MATHML::{markdown}",
+            mml=f"MML::{markdown}",
+        )
+
+
+class _FakeImageProcessor:  # test double passed as `image_processor` in _build_service
+    def add_padding(self, image):
+        return image  # hands the image back unchanged
+
+
+class _FakeLayoutDetector:  # returns the regions given at construction
+    def __init__(self, regions):
+        self._regions = regions
+
+    def detect(self, image):  # `image` is ignored; MixedRecognition is True only for non-empty regions
+        return LayoutInfo(regions=self._regions, MixedRecognition=bool(self._regions))
+
+
+def _build_service(regions=None):  # service under test, wired to the fakes defined above
+    return GLMOCREndToEndService(
+        vl_server_url="http://127.0.0.1:8002/v1",
+        image_processor=_FakeImageProcessor(),
+        converter=_FakeConverter(),
+        layout_detector=_FakeLayoutDetector(regions or []),  # None becomes an empty region list
+        max_workers=2,
+    )
+
+
+def test_encode_region_returns_decodable_base64_jpeg():
+    service = _build_service()
+    image = np.zeros((8, 12, 3), dtype=np.uint8)  # 8 rows x 12 columns, 3 channels
+    image[:, :] = [0, 128, 255]  # fill every pixel with one constant color
+
+    encoded = service._encode_region(image)
+    decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR)
+
+    assert decoded.shape[:2] == image.shape[:2]  # compares height and width only, not pixel values
+
+
+def test_call_vllm_builds_messages_and_returns_content():
+    service = _build_service()
+    captured = {}  # receives the keyword arguments passed to the stubbed create()
+
+    def create(**kwargs):  # stand-in for chat.completions.create
+        captured.update(kwargs)
+        return SimpleNamespace(  # exposes .choices[0].message.content
+            choices=[SimpleNamespace(message=SimpleNamespace(content=" recognized content \n"))]
+        )
+
+    service.openai_client = SimpleNamespace(  # assign the stub as the service's client
+        chat=SimpleNamespace(completions=SimpleNamespace(create=create))
+    )
+
+    result = service._call_vllm(np.zeros((4, 4, 3), dtype=np.uint8), "Formula Recognition:")
+
+    assert result == "recognized content"  # stub content minus its surrounding whitespace
+    assert captured["model"] == "glm-ocr"
+    assert captured["max_tokens"] == 1024
+    assert captured["messages"][0]["content"][0]["type"] == "image_url"  # image part comes first
+    assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
+    assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
+
+
+def test_normalize_bbox_scales_coordinates_to_1000():
+    service = _build_service()  # expected values below equal x * 1000 / 200 and y * 1000 / 100
+
+    assert service._normalize_bbox([0, 0, 200, 100], 200, 100) == [0, 0, 1000, 1000]
+    assert service._normalize_bbox([50, 25, 150, 75], 200, 100) == [250, 250, 750, 750]
+
+
+def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
+    service = _build_service(regions=[])  # fake layout detector reports no regions
+    image = np.zeros((20, 30, 3), dtype=np.uint8)
+
+    monkeypatch.setattr(service, "_call_vllm", lambda image, prompt: "raw text")  # no model call is made
+
+    result = service.recognize(image)
+
+    assert result["markdown"] == "raw text"  # the stubbed recognition text is returned as-is
+    assert result["latex"] == "LATEX::raw text"  # prefixes come from _FakeConverter
+    assert result["mathml"] == "MATHML::raw text"
+    assert result["mml"] == "MML::raw text"
+
+
+def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
+    regions = [  # one text, one figure, one formula region
+        LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
+        LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
+        LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
+    ]
+    service = _build_service(regions=regions)
+    image = np.zeros((40, 40, 3), dtype=np.uint8)
+
+    calls = []  # prompts seen by the stub, in call order
+
+    def fake_call_vllm(cropped, prompt):  # stand-in for service._call_vllm
+        calls.append(prompt)
+        if prompt == "Text Recognition:":
+            return "Title"
+        return "x + y"  # reply for any other prompt
+
+    monkeypatch.setattr(service, "_call_vllm", fake_call_vllm)
+
+    result = service.recognize(image)
+
+    assert calls == ["Text Recognition:", "Formula Recognition:"]  # two prompts, none for the figure
+    assert result["markdown"] == "# Title\n\n$$\nx + y\n$$"  # heading markup and $$ fences are expected
+    assert result["latex"] == "LATEX::# Title\n\n$$\nx + y\n$$"
+    assert result["mathml"] == "MATHML::# Title\n\n$$\nx + y\n$$"
+    assert result["mml"] == "MML::# Title\n\n$$\nx + y\n$$"