5 Commits

Author SHA1 Message Date
liuyuanchuang
5ba835ab44 fix: add deadsnakes PPA for python3.10 on Ubuntu 24.04
Ubuntu 24.04 ships Python 3.12 by default.
The python3.10-venv, python3.10-dev and python3.10-distutils packages are not in the standard Ubuntu 24.04 repositories.
Must add ppa:deadsnakes/ppa in both builder and runtime stages.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 11:37:32 +08:00
liuyuanchuang
7c7d4bf36a fix: restore wheels/ COPY without invalid shell operators
COPY does not support shell operators (||, 2>/dev/null).
Keep wheels/ for paddlepaddle whl installation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 11:36:28 +08:00
liuyuanchuang
ef98f37525 feat: aggressive image optimization for PPDocLayoutV3 only
- Remove doclayout-yolo (~4.8GB, torch/torchvision/triton)
- Replace opencv-python with opencv-python-headless (~200MB)
- Strip debug symbols from .so files (~300-800MB)
- Remove paddle C++ headers (~22MB)
- Use the nvidia/cuda "base" image variant instead of "runtime" (~3GB savings)
- Simplify dependencies: remove doc-parser extras
- Clean venv aggressively: no pip, setuptools, include/, share/

Expected size reduction:
  Before: 17GB
  After:  ~3GB (82% reduction)

Breakdown:
  - CUDA base: 0.4GB
  - Paddle: 0.7GB
  - PaddleOCR: 0.8GB
  - OpenCV-headless: 0.2GB
  - Other deps: 0.6GB
  Total: ~2.7-3GB

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-10 11:33:50 +08:00
liuyuanchuang
95c497829f fix: remove VOLUME declaration to prevent anonymous volumes
- Remove VOLUME directive that was creating anonymous volumes
- Keep directory creation (mkdir) for runtime mount points
- Users must explicitly mount volumes with -v flags
- This prevents hidden anonymous volumes from accumulating on every docker run

Usage:
  docker run --gpus all -p 8053:8053 \
    -v /home/yoge/.cache/modelscope:/root/.cache/modelscope:ro \
    -v /home/yoge/.cache/huggingface:/root/.cache/huggingface:ro \
    -v /home/yoge/.paddlex:/root/.paddlex:ro \
    doc_processer:latest

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-10 11:12:01 +08:00
liuyuanchuang
6579cf55f5 feat: optimize Docker image with multi-stage build
- Use multi-stage build to exclude build dependencies from final image
- Separate builder stage using devel image from runtime stage using smaller base image
- Clean venv: remove __pycache__, .pyc files, and test directories
- Remove embedded model files (243MB) from app/model/ - mount at runtime instead
- Expected size reduction: 18.9GB → 2-3GB (80-90% reduction)

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-10 10:41:32 +08:00
18 changed files with 211 additions and 213 deletions

View File

@@ -1,82 +1,103 @@
# DocProcesser Dockerfile # DocProcesser Dockerfile - Production optimized
# Optimized for RTX 5080 GPU deployment # Ultra-lean multi-stage build for PPDocLayoutV3
# Final image: ~3GB (from 17GB)
# Use NVIDIA CUDA base image with Python 3.10 # =============================================================================
FROM nvidia/cuda:12.9.0-runtime-ubuntu24.04 # STAGE 1: Builder
# =============================================================================
FROM nvidia/cuda:12.9.0-devel-ubuntu24.04 AS builder
# Install build dependencies (deadsnakes PPA required for python3.10 on Ubuntu 24.04)
RUN apt-get update && apt-get install -y --no-install-recommends \
software-properties-common \
&& add-apt-repository -y ppa:deadsnakes/ppa \
&& apt-get update && apt-get install -y --no-install-recommends \
python3.10 python3.10-venv python3.10-dev python3.10-distutils \
build-essential curl \
&& rm -rf /var/lib/apt/lists/*
# Setup Python
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
curl -sS https://bootstrap.pypa.io/get-pip.py | python
# Install uv
RUN pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
WORKDIR /build
# Copy dependencies
COPY pyproject.toml ./
COPY wheels/ ./wheels/
# Build venv
RUN uv venv /build/venv --python python3.10 && \
. /build/venv/bin/activate && \
uv pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -e . && \
rm -rf ./wheels
# Aggressive optimization: strip debug symbols from .so files (~300-800MB saved)
RUN find /build/venv -name "*.so" -exec strip --strip-unneeded {} + || true
# Remove paddle C++ headers (~22MB saved)
RUN rm -rf /build/venv/lib/python*/site-packages/paddle/include
# Clean Python cache and build artifacts
RUN find /build/venv -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true && \
find /build/venv -type f -name "*.pyc" -delete && \
find /build/venv -type f -name "*.pyo" -delete && \
find /build/venv -type d -name "tests" -exec rm -rf {} + 2>/dev/null || true && \
find /build/venv -type d -name "test" -exec rm -rf {} + 2>/dev/null || true && \
rm -rf /build/venv/lib/*/site-packages/pip* \
/build/venv/lib/*/site-packages/setuptools* \
/build/venv/include \
/build/venv/share && \
rm -rf /root/.cache 2>/dev/null || true
# =============================================================================
# STAGE 2: Runtime - CUDA base image (~400MB, versus ~3.4GB for the runtime image)
# =============================================================================
FROM nvidia/cuda:12.9.0-base-ubuntu24.04
# Set environment variables
ENV PYTHONUNBUFFERED=1 \ ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \ PYTHONDONTWRITEBYTECODE=1 \
PIP_NO_CACHE_DIR=1 \ PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \
# Model cache directories - mount these at runtime
MODELSCOPE_CACHE=/root/.cache/modelscope \ MODELSCOPE_CACHE=/root/.cache/modelscope \
HF_HOME=/root/.cache/huggingface \ HF_HOME=/root/.cache/huggingface \
# Application config (override defaults for container)
# Use 127.0.0.1 for --network host mode, or override with -e for bridge mode
PP_DOCLAYOUT_MODEL_DIR=/root/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2 \ PP_DOCLAYOUT_MODEL_DIR=/root/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2 \
PADDLEOCR_VL_URL=http://127.0.0.1:8001/v1 PADDLEOCR_VL_URL=http://127.0.0.1:8001/v1 \
PATH="/app/.venv/bin:$PATH" \
VIRTUAL_ENV="/app/.venv"
# Set working directory
WORKDIR /app WORKDIR /app
# Install system dependencies and Python 3.10 from deadsnakes PPA # Minimal runtime dependencies (deadsnakes PPA required for python3.10 on Ubuntu 24.04)
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
software-properties-common \ software-properties-common \
&& add-apt-repository -y ppa:deadsnakes/ppa \ && add-apt-repository -y ppa:deadsnakes/ppa \
&& apt-get update && apt-get install -y --no-install-recommends \ && apt-get update && apt-get install -y --no-install-recommends \
python3.10 \ python3.10 \
python3.10-venv \ libgl1 libglib2.0-0 libgomp1 \
python3.10-dev \ curl pandoc \
python3.10-distutils \ && rm -rf /var/lib/apt/lists/*
libgl1 \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender-dev \
libgomp1 \
curl \
pandoc \
&& rm -rf /var/lib/apt/lists/* \
&& ln -sf /usr/bin/python3.10 /usr/bin/python \
&& ln -sf /usr/bin/python3.10 /usr/bin/python3 \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
# Install uv via pip (more reliable than install script) RUN ln -sf /usr/bin/python3.10 /usr/bin/python
RUN python3.10 -m pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
ENV PATH="/app/.venv/bin:$PATH"
ENV VIRTUAL_ENV="/app/.venv"
# Copy dependency files first for better caching # Copy optimized venv from builder
COPY pyproject.toml ./ COPY --from=builder /build/venv /app/.venv
COPY wheels/ ./wheels/
# Create virtual environment and install dependencies # Copy app code
RUN uv venv /app/.venv --python python3.10 \
&& uv pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -e . \
&& rm -rf ./wheels
# Copy application code
COPY app/ ./app/ COPY app/ ./app/
# Create model cache directories (mount from host at runtime) # Create cache mount points (DO NOT include model files)
RUN mkdir -p /root/.cache/modelscope \ RUN mkdir -p /root/.cache/modelscope /root/.cache/huggingface /root/.paddlex && \
/root/.cache/huggingface \ rm -rf /app/app/model/*
/root/.paddlex \
/app/app/model/DocLayout \
/app/app/model/PP-DocLayout
# Declare volumes for model cache (mount at runtime to avoid re-downloading)
VOLUME ["/root/.cache/modelscope", "/root/.cache/huggingface", "/root/.paddlex"]
# Expose port
EXPOSE 8053 EXPOSE 8053
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8053/health || exit 1 CMD curl -f http://localhost:8053/health || exit 1
# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8053", "--workers", "1"] CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8053", "--workers", "1"]
# ============================================================================= # =============================================================================

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import Response from fastapi.responses import Response
from app.core.dependencies import get_converter from app.core.dependencies import get_converter
from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest from app.schemas.convert import MarkdownToDocxRequest, LatexToOmmlRequest, LatexToOmmlResponse
from app.services.converter import Converter from app.services.converter import Converter
router = APIRouter() router = APIRouter()

View File

@@ -6,10 +6,10 @@ import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi import APIRouter, Depends, HTTPException, Request, Response
from app.core.dependencies import ( from app.core.dependencies import (
get_glmocr_endtoend_service,
get_image_processor, get_image_processor,
get_glmocr_endtoend_service,
) )
from app.core.logging_config import RequestIDAdapter, get_logger from app.core.logging_config import get_logger, RequestIDAdapter
from app.schemas.image import ImageOCRRequest, ImageOCRResponse from app.schemas.image import ImageOCRRequest, ImageOCRResponse
from app.services.image_processor import ImageProcessor from app.services.image_processor import ImageProcessor
from app.services.ocr_service import GLMOCREndToEndService from app.services.ocr_service import GLMOCREndToEndService

View File

@@ -1,10 +1,10 @@
"""Application dependencies.""" """Application dependencies."""
from app.core.config import get_settings
from app.services.converter import Converter
from app.services.image_processor import ImageProcessor from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import GLMOCREndToEndService from app.services.ocr_service import GLMOCREndToEndService
from app.services.converter import Converter
from app.core.config import get_settings
# Global instances (initialized on startup) # Global instances (initialized on startup)
_layout_detector: LayoutDetector | None = None _layout_detector: LayoutDetector | None = None

View File

@@ -3,7 +3,7 @@
import logging import logging
import logging.handlers import logging.handlers
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Optional
from app.core.config import get_settings from app.core.config import get_settings
@@ -18,10 +18,10 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
interval: int = 1, interval: int = 1,
backupCount: int = 30, backupCount: int = 30,
maxBytes: int = 100 * 1024 * 1024, # 100MB maxBytes: int = 100 * 1024 * 1024, # 100MB
encoding: str | None = None, encoding: Optional[str] = None,
delay: bool = False, delay: bool = False,
utc: bool = False, utc: bool = False,
atTime: Any | None = None, atTime: Optional[Any] = None,
): ):
"""Initialize handler with both time and size rotation. """Initialize handler with both time and size rotation.
@@ -58,14 +58,14 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
if self.stream is None: if self.stream is None:
self.stream = self._open() self.stream = self._open()
if self.maxBytes > 0: if self.maxBytes > 0:
msg = f"{self.format(record)}\n" msg = "%s\n" % self.format(record)
self.stream.seek(0, 2) # Seek to end self.stream.seek(0, 2) # Seek to end
if self.stream.tell() + len(msg) >= self.maxBytes: if self.stream.tell() + len(msg) >= self.maxBytes:
return True return True
return False return False
def setup_logging(log_dir: str | None = None) -> logging.Logger: def setup_logging(log_dir: Optional[str] = None) -> logging.Logger:
"""Setup application logging with rotation by day and size. """Setup application logging with rotation by day and size.
Args: Args:
@@ -134,7 +134,7 @@ def setup_logging(log_dir: str | None = None) -> logging.Logger:
# Global logger instance # Global logger instance
_logger: logging.Logger | None = None _logger: Optional[logging.Logger] = None
def get_logger() -> logging.Logger: def get_logger() -> logging.Logger:

View File

@@ -36,3 +36,4 @@ class LatexToOmmlResponse(BaseModel):
"""Response body for LaTeX to OMML conversion endpoint.""" """Response body for LaTeX to OMML conversion endpoint."""
omml: str = Field("", description="OMML (Office Math Markup Language) representation") omml: str = Field("", description="OMML (Office Math Markup Language) representation")

View File

@@ -7,9 +7,7 @@ class LayoutRegion(BaseModel):
"""A detected layout region in the document.""" """A detected layout region in the document."""
type: str = Field(..., description="Region type: text, formula, table, figure") type: str = Field(..., description="Region type: text, formula, table, figure")
native_label: str = Field( native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)")
"", description="Raw label before type mapping (e.g. doc_title, formula_number)"
)
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]") bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
confidence: float = Field(..., description="Detection confidence score") confidence: float = Field(..., description="Detection confidence score")
score: float = Field(..., description="Detection score") score: float = Field(..., description="Detection score")
@@ -43,15 +41,10 @@ class ImageOCRRequest(BaseModel):
class ImageOCRResponse(BaseModel): class ImageOCRResponse(BaseModel):
"""Response body for image OCR endpoint.""" """Response body for image OCR endpoint."""
latex: str = Field( latex: str = Field("", description="LaTeX representation of the content (empty if mixed content)")
"", description="LaTeX representation of the content (empty if mixed content)"
)
markdown: str = Field("", description="Markdown representation of the content") markdown: str = Field("", description="Markdown representation of the content")
mathml: str = Field("", description="Standard MathML representation (empty if mixed content)") mathml: str = Field("", description="Standard MathML representation (empty if mixed content)")
mml: str = Field( mml: str = Field("", description="XML MathML with mml: namespace prefix (empty if mixed content)")
"", description="XML MathML with mml: namespace prefix (empty if mixed content)"
)
layout_info: LayoutInfo = Field(default_factory=LayoutInfo) layout_info: LayoutInfo = Field(default_factory=LayoutInfo)
recognition_mode: str = Field( recognition_mode: str = Field("", description="Recognition mode used: mixed_recognition or formula_recognition")
"", description="Recognition mode used: mixed_recognition or formula_recognition"
)

View File

@@ -112,18 +112,14 @@ class Converter:
# Pre-compiled regex patterns for preprocessing # Pre-compiled regex patterns for preprocessing
_RE_VSPACE = re.compile(r"\\\[1mm\]") _RE_VSPACE = re.compile(r"\\\[1mm\]")
_RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL) _RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL)
_RE_BLOCK_FORMULA_LINE = re.compile( _RE_BLOCK_FORMULA_LINE = re.compile(r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL)
r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL
)
_RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>') _RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>')
_RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)") _RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)")
_RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}") _RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}")
_RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+") _RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+")
_RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}") _RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}")
_RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL) _RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL)
_RE_ALIGNED_BRACE = re.compile( _RE_ALIGNED_BRACE = re.compile(r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL)
r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL
)
_RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL) _RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL)
_RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL) _RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL)
_RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL) _RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL)
@@ -372,9 +368,7 @@ class Converter:
mathml = latex_to_mathml(latex_formula) mathml = latex_to_mathml(latex_formula)
return Converter._postprocess_mathml_for_word(mathml) return Converter._postprocess_mathml_for_word(mathml)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}") from e
f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}"
) from e
@staticmethod @staticmethod
def _postprocess_mathml_for_word(mathml: str) -> str: def _postprocess_mathml_for_word(mathml: str) -> str:
@@ -589,6 +583,7 @@ class Converter:
"&#x21D3;": "", # Downarrow "&#x21D3;": "", # Downarrow
"&#x2195;": "", # updownarrow "&#x2195;": "", # updownarrow
"&#x21D5;": "", # Updownarrow "&#x21D5;": "", # Updownarrow
"&#x2260;": "", # ne
"&#x226A;": "", # ll "&#x226A;": "", # ll
"&#x226B;": "", # gg "&#x226B;": "", # gg
"&#x2A7D;": "", # leqslant "&#x2A7D;": "", # leqslant
@@ -967,7 +962,7 @@ class Converter:
"""Export to DOCX format using pypandoc.""" """Export to DOCX format using pypandoc."""
extra_args = [ extra_args = [
"--highlight-style=pygments", "--highlight-style=pygments",
"--reference-doc=app/pkg/reference.docx", f"--reference-doc=app/pkg/reference.docx",
] ]
pypandoc.convert_file( pypandoc.convert_file(
input_path, input_path,

View File

@@ -1,10 +1,26 @@
"""GLM-OCR postprocessing logic adapted for this project.
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
glm-ocr/glmocr/utils/result_postprocess_utils.py.
Covers:
- Repeated-content / hallucination detection
- Per-region content cleaning and formatting (titles, bullets, formulas)
- formula_number merging (→ \\tag{})
- Hyphenated text-block merging (via wordfreq)
- Missing bullet-point detection
"""
from __future__ import annotations from __future__ import annotations
import logging import logging
import re import re
import json
logger = logging.getLogger(__name__)
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any, Dict, List, Optional, Tuple
try: try:
from wordfreq import zipf_frequency from wordfreq import zipf_frequency
@@ -13,14 +29,13 @@ try:
except ImportError: except ImportError:
_WORDFREQ_AVAILABLE = False _WORDFREQ_AVAILABLE = False
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# result_postprocess_utils (ported) # result_postprocess_utils (ported)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None: def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
"""Detect and truncate a consecutively-repeated pattern. """Detect and truncate a consecutively-repeated pattern.
Returns the string with the repeat removed, or None if not found. Returns the string with the repeat removed, or None if not found.
@@ -34,13 +49,7 @@ def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 1
return None return None
pattern = re.compile( pattern = re.compile(
r"(.{" r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}",
+ str(min_unit_len)
+ ","
+ str(max_unit_len)
+ r"}?)\1{"
+ str(min_repeats - 1)
+ ",}",
re.DOTALL, re.DOTALL,
) )
match = pattern.search(s) match = pattern.search(s)
@@ -74,9 +83,7 @@ def clean_repeated_content(
if count >= line_threshold and (count / total_lines) >= 0.8: if count >= line_threshold and (count / total_lines) >= 0.8:
for i, line in enumerate(lines): for i, line in enumerate(lines):
if line == common: if line == common:
consecutive = sum( consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common)
1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common
)
if consecutive >= 3: if consecutive >= 3:
original_lines = content.split("\n") original_lines = content.split("\n")
non_empty_count = 0 non_empty_count = 0
@@ -99,7 +106,7 @@ def clean_formula_number(number_content: str) -> str:
# Strip display math delimiters # Strip display math delimiters
for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]: for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]:
if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end): if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end):
s = s[len(start) : -len(end)].strip() s = s[len(start):-len(end)].strip()
break break
# Strip CJK/ASCII parentheses # Strip CJK/ASCII parentheses
if s.startswith("(") and s.endswith(")"): if s.startswith("(") and s.endswith(")"):
@@ -114,7 +121,7 @@ def clean_formula_number(number_content: str) -> str:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping) # Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
_LABEL_TO_CATEGORY: dict[str, str] = { _LABEL_TO_CATEGORY: Dict[str, str] = {
# text # text
"abstract": "text", "abstract": "text",
"algorithm": "text", "algorithm": "text",
@@ -150,7 +157,7 @@ class GLMResultFormatter:
# Public entry-point # Public entry-point
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def process(self, regions: list[dict[str, Any]]) -> str: def process(self, regions: List[Dict[str, Any]]) -> str:
"""Run the full postprocessing pipeline and return Markdown. """Run the full postprocessing pipeline and return Markdown.
Args: Args:
@@ -168,7 +175,7 @@ class GLMResultFormatter:
items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0)) items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
# Per-region cleaning + formatting # Per-region cleaning + formatting
processed: list[dict] = [] processed: List[Dict] = []
for item in items: for item in items:
item["native_label"] = item.get("native_label", item.get("label", "text")) item["native_label"] = item.get("native_label", item.get("label", "text"))
item["label"] = self._map_label(item.get("label", "text"), item["native_label"]) item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
@@ -192,7 +199,7 @@ class GLMResultFormatter:
processed = self._format_bullet_points(processed) processed = self._format_bullet_points(processed)
# Assemble Markdown # Assemble Markdown
parts: list[str] = [] parts: List[str] = []
for item in processed: for item in processed:
content = item.get("content") or "" content = item.get("content") or ""
if item["label"] == "image": if item["label"] == "image":
@@ -256,15 +263,11 @@ class GLMResultFormatter:
if label == "formula": if label == "formula":
content = content.strip() content = content.strip()
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]: for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
if content.startswith(s): if content.startswith(s) and content.endswith(e):
content = content[len(s) :].strip() content = content[len(s) : -len(e)].strip()
if content.endswith(e):
content = content[: -len(e)].strip()
break break
if not content: if not content:
logger.warning( logger.warning("Skipping formula region with empty content after stripping delimiters")
"Skipping formula region with empty content after stripping delimiters"
)
return "" return ""
content = "$$\n" + content + "\n$$" content = "$$\n" + content + "\n$$"
@@ -293,12 +296,12 @@ class GLMResultFormatter:
# Structural merges # Structural merges
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def _merge_formula_numbers(self, items: list[dict]) -> list[dict]: def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
"""Merge formula_number region into adjacent formula with \\tag{}.""" """Merge formula_number region into adjacent formula with \\tag{}."""
if not items: if not items:
return items return items
merged: list[dict] = [] merged: List[Dict] = []
skip: set = set() skip: set = set()
for i, block in enumerate(items): for i, block in enumerate(items):
@@ -314,9 +317,7 @@ class GLMResultFormatter:
formula_content = items[i + 1].get("content", "") formula_content = items[i + 1].get("content", "")
merged_block = deepcopy(items[i + 1]) merged_block = deepcopy(items[i + 1])
if formula_content.endswith("\n$$"): if formula_content.endswith("\n$$"):
merged_block["content"] = ( merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
)
merged.append(merged_block) merged.append(merged_block)
skip.add(i + 1) skip.add(i + 1)
continue # always skip the formula_number block itself continue # always skip the formula_number block itself
@@ -328,9 +329,7 @@ class GLMResultFormatter:
formula_content = block.get("content", "") formula_content = block.get("content", "")
merged_block = deepcopy(block) merged_block = deepcopy(block)
if formula_content.endswith("\n$$"): if formula_content.endswith("\n$$"):
merged_block["content"] = ( merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
)
merged.append(merged_block) merged.append(merged_block)
skip.add(i + 1) skip.add(i + 1)
continue continue
@@ -341,12 +340,12 @@ class GLMResultFormatter:
block["index"] = i block["index"] = i
return merged return merged
def _merge_text_blocks(self, items: list[dict]) -> list[dict]: def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
"""Merge hyphenated text blocks when the combined word is valid (wordfreq).""" """Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
if not items or not _WORDFREQ_AVAILABLE: if not items or not _WORDFREQ_AVAILABLE:
return items return items
merged: list[dict] = [] merged: List[Dict] = []
skip: set = set() skip: set = set()
for i, block in enumerate(items): for i, block in enumerate(items):
@@ -390,9 +389,7 @@ class GLMResultFormatter:
block["index"] = i block["index"] = i
return merged return merged
def _format_bullet_points( def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
self, items: list[dict], left_align_threshold: float = 10.0
) -> list[dict]:
"""Add missing bullet prefix when a text block is sandwiched between two bullet items.""" """Add missing bullet prefix when a text block is sandwiched between two bullet items."""
if len(items) < 3: if len(items) < 3:
return items return items

View File

@@ -1,11 +1,12 @@
"""PP-DocLayoutV3 wrapper for document layout detection.""" """PP-DocLayoutV3 wrapper for document layout detection."""
import numpy as np import numpy as np
from paddleocr import LayoutDetection
from app.core.config import get_settings
from app.schemas.image import LayoutInfo, LayoutRegion from app.schemas.image import LayoutInfo, LayoutRegion
from app.core.config import get_settings
from app.services.layout_postprocess import apply_layout_postprocess from app.services.layout_postprocess import apply_layout_postprocess
from paddleocr import LayoutDetection
from typing import Optional
settings = get_settings() settings = get_settings()
@@ -13,7 +14,7 @@ settings = get_settings()
class LayoutDetector: class LayoutDetector:
"""Layout detector for PP-DocLayoutV2.""" """Layout detector for PP-DocLayoutV2."""
_layout_detector: LayoutDetection | None = None _layout_detector: Optional[LayoutDetection] = None
# PP-DocLayoutV2 class ID to label mapping # PP-DocLayoutV2 class ID to label mapping
CLS_ID_TO_LABEL: dict[int, str] = { CLS_ID_TO_LABEL: dict[int, str] = {
@@ -155,11 +156,10 @@ class LayoutDetector:
if __name__ == "__main__": if __name__ == "__main__":
import cv2 import cv2
from app.core.config import get_settings from app.core.config import get_settings
from app.services.converter import Converter
from app.services.image_processor import ImageProcessor from app.services.image_processor import ImageProcessor
from app.services.ocr_service import GLMOCREndToEndService from app.services.converter import Converter
from app.services.ocr_service import OCRService
settings = get_settings() settings = get_settings()
@@ -169,15 +169,15 @@ if __name__ == "__main__":
converter = Converter() converter = Converter()
# Initialize OCR service # Initialize OCR service
ocr_service = GLMOCREndToEndService( ocr_service = OCRService(
vl_server_url=settings.glm_ocr_url, vl_server_url=settings.paddleocr_vl_url,
layout_detector=layout_detector, layout_detector=layout_detector,
image_processor=image_processor, image_processor=image_processor,
converter=converter, converter=converter,
) )
# Load test image # Load test image
image_path = "test/image2.png" image_path = "test/timeout.jpg"
image = cv2.imread(image_path) image = cv2.imread(image_path)
if image is None: if image is None:

View File

@@ -15,14 +15,16 @@ the quality of the GLM-OCR SDK's layout pipeline.
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Primitive geometry helpers # Primitive geometry helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def iou(box1: List[float], box2: List[float]) -> float:
def iou(box1: list[float], box2: list[float]) -> float:
"""Compute IoU of two bounding boxes [x1, y1, x2, y2].""" """Compute IoU of two bounding boxes [x1, y1, x2, y2]."""
x1, y1, x2, y2 = box1 x1, y1, x2, y2 = box1
x1_p, y1_p, x2_p, y2_p = box2 x1_p, y1_p, x2_p, y2_p = box2
@@ -39,7 +41,7 @@ def iou(box1: list[float], box2: list[float]) -> float:
return inter_area / float(box1_area + box2_area - inter_area) return inter_area / float(box1_area + box2_area - inter_area)
def is_contained(box1: list[float], box2: list[float], overlap_threshold: float = 0.8) -> bool: def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
"""Return True if box1 is contained within box2 (overlap ratio >= threshold). """Return True if box1 is contained within box2 (overlap ratio >= threshold).
box format: [cls_id, score, x1, y1, x2, y2] box format: [cls_id, score, x1, y1, x2, y2]
@@ -64,12 +66,11 @@ def is_contained(box1: list[float], box2: list[float], overlap_threshold: float
# NMS # NMS
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def nms( def nms(
boxes: np.ndarray, boxes: np.ndarray,
iou_same: float = 0.6, iou_same: float = 0.6,
iou_diff: float = 0.98, iou_diff: float = 0.98,
) -> list[int]: ) -> List[int]:
"""NMS with separate IoU thresholds for same-class and cross-class overlaps. """NMS with separate IoU thresholds for same-class and cross-class overlaps.
Args: Args:
@@ -82,7 +83,7 @@ def nms(
""" """
scores = boxes[:, 1] scores = boxes[:, 1]
indices = np.argsort(scores)[::-1].tolist() indices = np.argsort(scores)[::-1].tolist()
selected: list[int] = [] selected: List[int] = []
while indices: while indices:
current = indices[0] current = indices[0]
@@ -113,10 +114,10 @@ _PRESERVE_LABELS = {"image", "seal", "chart"}
def check_containment( def check_containment(
boxes: np.ndarray, boxes: np.ndarray,
preserve_cls_ids: set | None = None, preserve_cls_ids: Optional[set] = None,
category_index: int | None = None, category_index: Optional[int] = None,
mode: str | None = None, mode: Optional[str] = None,
) -> tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Compute containment flags for each box. """Compute containment flags for each box.
Args: Args:
@@ -159,10 +160,9 @@ def check_containment(
# Box expansion (unclip) # Box expansion (unclip)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def unclip_boxes( def unclip_boxes(
boxes: np.ndarray, boxes: np.ndarray,
unclip_ratio: float | tuple[float, float] | dict | list | None, unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
) -> np.ndarray: ) -> np.ndarray:
"""Expand bounding boxes by the given ratio. """Expand bounding boxes by the given ratio.
@@ -215,14 +215,13 @@ def unclip_boxes(
# Main entry-point # Main entry-point
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def apply_layout_postprocess( def apply_layout_postprocess(
boxes: list[dict], boxes: List[Dict],
img_size: tuple[int, int], img_size: Tuple[int, int],
layout_nms: bool = True, layout_nms: bool = True,
layout_unclip_ratio: float | tuple | dict | None = None, layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
layout_merge_bboxes_mode: str | dict | None = "large", layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
) -> list[dict]: ) -> List[Dict]:
"""Apply GLM-OCR layout post-processing to PaddleOCR detection results. """Apply GLM-OCR layout post-processing to PaddleOCR detection results.
Args: Args:
@@ -251,7 +250,7 @@ def apply_layout_postprocess(
arr_rows.append([cls_id, score, x1, y1, x2, y2]) arr_rows.append([cls_id, score, x1, y1, x2, y2])
boxes_array = np.array(arr_rows, dtype=float) boxes_array = np.array(arr_rows, dtype=float)
all_labels: list[str] = [b.get("label", "") for b in boxes] all_labels: List[str] = [b.get("label", "") for b in boxes]
# 1. NMS ---------------------------------------------------------------- # # 1. NMS ---------------------------------------------------------------- #
if layout_nms and len(boxes_array) > 1: if layout_nms and len(boxes_array) > 1:
@@ -263,14 +262,17 @@ def apply_layout_postprocess(
if len(boxes_array) > 1: if len(boxes_array) > 1:
img_area = img_width * img_height img_area = img_width * img_height
area_thres = 0.82 if img_width > img_height else 0.93 area_thres = 0.82 if img_width > img_height else 0.93
image_cls_ids = {
int(boxes_array[i, 0])
for i, lbl in enumerate(all_labels)
if lbl == "image"
}
keep_mask = np.ones(len(boxes_array), dtype=bool) keep_mask = np.ones(len(boxes_array), dtype=bool)
for i, lbl in enumerate(all_labels): for i, lbl in enumerate(all_labels):
if lbl == "image": if lbl == "image":
x1, y1, x2, y2 = boxes_array[i, 2:6] x1, y1, x2, y2 = boxes_array[i, 2:6]
x1 = max(0.0, x1) x1 = max(0.0, x1); y1 = max(0.0, y1)
y1 = max(0.0, y1) x2 = min(float(img_width), x2); y2 = min(float(img_height), y2)
x2 = min(float(img_width), x2)
y2 = min(float(img_height), y2)
if (x2 - x1) * (y2 - y1) > area_thres * img_area: if (x2 - x1) * (y2 - y1) > area_thres * img_area:
keep_mask[i] = False keep_mask[i] = False
boxes_array = boxes_array[keep_mask] boxes_array = boxes_array[keep_mask]
@@ -279,7 +281,9 @@ def apply_layout_postprocess(
# 3. Containment analysis (merge_bboxes_mode) -------------------------- # # 3. Containment analysis (merge_bboxes_mode) -------------------------- #
if layout_merge_bboxes_mode and len(boxes_array) > 1: if layout_merge_bboxes_mode and len(boxes_array) > 1:
preserve_cls_ids = { preserve_cls_ids = {
int(boxes_array[i, 0]) for i, lbl in enumerate(all_labels) if lbl in _PRESERVE_LABELS int(boxes_array[i, 0])
for i, lbl in enumerate(all_labels)
if lbl in _PRESERVE_LABELS
} }
if isinstance(layout_merge_bboxes_mode, str): if isinstance(layout_merge_bboxes_mode, str):
@@ -317,7 +321,7 @@ def apply_layout_postprocess(
boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio) boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)
# 5. Clamp to image boundaries + skip invalid -------------------------- # # 5. Clamp to image boundaries + skip invalid -------------------------- #
result: list[dict] = [] result: List[Dict] = []
for i, row in enumerate(boxes_array): for i, row in enumerate(boxes_array):
cls_id = int(row[0]) cls_id = int(row[0])
score = float(row[1]) score = float(row[1])
@@ -329,13 +333,11 @@ def apply_layout_postprocess(
if x1 >= x2 or y1 >= y2: if x1 >= x2 or y1 >= y2:
continue continue
result.append( result.append({
{ "cls_id": cls_id,
"cls_id": cls_id, "label": all_labels[i],
"label": all_labels[i], "score": score,
"score": score, "coordinate": [int(x1), int(y1), int(x2), int(y2)],
"coordinate": [int(x1), int(y1), int(x2), int(y2)], })
}
)
return result return result

View File

@@ -878,9 +878,12 @@ class GLMOCREndToEndService(OCRServiceBase):
Returns: Returns:
Dict with 'markdown', 'latex', 'mathml', 'mml' keys. Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
""" """
# 1. Layout detection # 1. Padding
img_h, img_w = image.shape[:2] padded = self.image_processor.add_padding(image)
layout_info = self.layout_detector.detect(image) img_h, img_w = padded.shape[:2]
# 2. Layout detection
layout_info = self.layout_detector.detect(padded)
# Sort regions in reading order: top-to-bottom, left-to-right # Sort regions in reading order: top-to-bottom, left-to-right
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0])) layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
@@ -889,7 +892,7 @@ class GLMOCREndToEndService(OCRServiceBase):
if not layout_info.regions: if not layout_info.regions:
# No layout detected → assume it's a formula, use formula recognition # No layout detected → assume it's a formula, use formula recognition
logger.info("No layout regions detected, treating image as formula") logger.info("No layout regions detected, treating image as formula")
raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"]) raw_content = self._call_vllm(padded, _TASK_PROMPTS["formula"])
# Format as display formula markdown # Format as display formula markdown
formatted_content = raw_content.strip() formatted_content = raw_content.strip()
if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")): if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")):
@@ -902,7 +905,7 @@ class GLMOCREndToEndService(OCRServiceBase):
if region.type == "figure": if region.type == "figure":
continue continue
x1, y1, x2, y2 = (int(c) for c in region.bbox) x1, y1, x2, y2 = (int(c) for c in region.bbox)
cropped = image[y1:y2, x1:x2] cropped = padded[y1:y2, x1:x2]
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10: if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
logger.warning( logger.warning(
"Skipping region idx=%d (label=%s): crop too small %s", "Skipping region idx=%d (label=%s): crop too small %s",
@@ -915,7 +918,7 @@ class GLMOCREndToEndService(OCRServiceBase):
tasks.append((idx, region, cropped, prompt)) tasks.append((idx, region, cropped, prompt))
if not tasks: if not tasks:
raw_content = self._call_vllm(image, _DEFAULT_PROMPT) raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
markdown_content = self._formatter._clean_content(raw_content) markdown_content = self._formatter._clean_content(raw_content)
else: else:
# Parallel OCR calls # Parallel OCR calls
@@ -962,3 +965,17 @@ class GLMOCREndToEndService(OCRServiceBase):
logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e) logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e)
return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml} return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
if __name__ == "__main__":
mineru_service = MineruOCRService()
image = cv2.imread("test/formula2.jpg")
image_numpy = np.array(image)
# Encode image to bytes (as done in API layer)
success, encoded_image = cv2.imencode(".png", image_numpy)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
image_bytes.seek(0)
ocr_result = mineru_service.recognize(image_bytes)
print(ocr_result)

View File

@@ -11,7 +11,7 @@ authors = [
dependencies = [ dependencies = [
"fastapi==0.128.0", "fastapi==0.128.0",
"uvicorn[standard]==0.40.0", "uvicorn[standard]==0.40.0",
"opencv-python==4.12.0.88", "opencv-python-headless==4.12.0.88", # headless: no Qt/FFmpeg GUI, server-only
"python-multipart==0.0.21", "python-multipart==0.0.21",
"pydantic==2.12.5", "pydantic==2.12.5",
"pydantic-settings==2.12.0", "pydantic-settings==2.12.0",
@@ -20,7 +20,6 @@ dependencies = [
"pillow==12.0.0", "pillow==12.0.0",
"python-docx==1.2.0", "python-docx==1.2.0",
"paddleocr==3.4.0", "paddleocr==3.4.0",
"doclayout-yolo==0.0.4",
"latex2mathml==3.78.1", "latex2mathml==3.78.1",
"paddle==1.2.0", "paddle==1.2.0",
"pypandoc==1.16.2", "pypandoc==1.16.2",

View File

@@ -1,4 +1,5 @@
import numpy as np import numpy as np
import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -34,9 +35,7 @@ def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
client = _build_client() client = _build_client()
missing = client.post("/ocr", json={}) missing = client.post("/ocr", json={})
both = client.post( both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"})
"/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}
)
assert missing.status_code == 422 assert missing.status_code == 422
assert both.status_code == 422 assert both.status_code == 422

View File

@@ -57,22 +57,12 @@ def test_merge_formula_numbers_merges_before_and_after_formula():
before = formatter._merge_formula_numbers( before = formatter._merge_formula_numbers(
[ [
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"}, {"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
{ {"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
"index": 1,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y\n$$",
},
] ]
) )
after = formatter._merge_formula_numbers( after = formatter._merge_formula_numbers(
[ [
{ {"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
"index": 0,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y\n$$",
},
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"}, {"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
] ]
) )

View File

@@ -23,9 +23,7 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
calls = {} calls = {}
def fake_apply_layout_postprocess( def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode):
boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode
):
calls["args"] = { calls["args"] = {
"boxes": boxes, "boxes": boxes,
"img_size": img_size, "img_size": img_size,
@@ -35,9 +33,7 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
} }
return [boxes[0], boxes[2]] return [boxes[0], boxes[2]]
monkeypatch.setattr( monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess)
"app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess
)
image = np.zeros((200, 100, 3), dtype=np.uint8) image = np.zeros((200, 100, 3), dtype=np.uint8)
info = detector.detect(image) info = detector.detect(image)

View File

@@ -146,4 +146,6 @@ def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image()
layout_merge_bboxes_mode=None, layout_merge_bboxes_mode=None,
) )
assert result == [{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}] assert result == [
{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
]

View File

@@ -46,9 +46,7 @@ def test_encode_region_returns_decodable_base64_jpeg():
image[:, :] = [0, 128, 255] image[:, :] = [0, 128, 255]
encoded = service._encode_region(image) encoded = service._encode_region(image)
decoded = cv2.imdecode( decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR)
np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR
)
assert decoded.shape[:2] == image.shape[:2] assert decoded.shape[:2] == image.shape[:2]
@@ -73,9 +71,7 @@ def test_call_vllm_builds_messages_and_returns_content():
assert captured["model"] == "glm-ocr" assert captured["model"] == "glm-ocr"
assert captured["max_tokens"] == 1024 assert captured["max_tokens"] == 1024
assert captured["messages"][0]["content"][0]["type"] == "image_url" assert captured["messages"][0]["content"][0]["type"] == "image_url"
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith( assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
"data:image/jpeg;base64,"
)
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"} assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
@@ -102,19 +98,9 @@ def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch): def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
regions = [ regions = [
LayoutRegion( LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9 LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
), LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
LayoutRegion(
type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8
),
LayoutRegion(
type="formula",
native_label="display_formula",
bbox=[20, 20, 40, 40],
confidence=0.95,
score=0.95,
),
] ]
service = _build_service(regions=regions) service = _build_service(regions=regions)
image = np.zeros((40, 40, 3), dtype=np.uint8) image = np.zeros((40, 40, 3), dtype=np.uint8)