Compare commits
5 Commits
main
...
optimize/d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ba835ab44 | ||
|
|
7c7d4bf36a | ||
|
|
ef98f37525 | ||
|
|
95c497829f | ||
|
|
6579cf55f5 |
@@ -8,8 +8,7 @@
|
||||
"WebFetch(domain:raw.githubusercontent.com)",
|
||||
"Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")",
|
||||
"Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)",
|
||||
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)",
|
||||
"Bash(ruff check:*)"
|
||||
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
123
Dockerfile
123
Dockerfile
@@ -1,82 +1,103 @@
|
||||
# DocProcesser Dockerfile
|
||||
# Optimized for RTX 5080 GPU deployment
|
||||
# DocProcesser Dockerfile - Production optimized
|
||||
# Ultra-lean multi-stage build for PPDocLayoutV3
|
||||
# Final image: ~3GB (from 17GB)
|
||||
|
||||
# Use NVIDIA CUDA base image with Python 3.10
|
||||
FROM nvidia/cuda:12.9.0-runtime-ubuntu24.04
|
||||
# =============================================================================
|
||||
# STAGE 1: Builder
|
||||
# =============================================================================
|
||||
FROM nvidia/cuda:12.9.0-devel-ubuntu24.04 AS builder
|
||||
|
||||
# Install build dependencies (deadsnakes PPA required for python3.10 on Ubuntu 24.04)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.10 python3.10-venv python3.10-dev python3.10-distutils \
|
||||
build-essential curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Setup Python
|
||||
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
|
||||
curl -sS https://bootstrap.pypa.io/get-pip.py | python
|
||||
|
||||
# Install uv
|
||||
RUN pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# Copy dependencies
|
||||
COPY pyproject.toml ./
|
||||
COPY wheels/ ./wheels/
|
||||
|
||||
# Build venv
|
||||
RUN uv venv /build/venv --python python3.10 && \
|
||||
. /build/venv/bin/activate && \
|
||||
uv pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -e . && \
|
||||
rm -rf ./wheels
|
||||
|
||||
# Aggressive optimization: strip debug symbols from .so files (~300-800MB saved)
|
||||
RUN find /build/venv -name "*.so" -exec strip --strip-unneeded {} + || true
|
||||
|
||||
# Remove paddle C++ headers (~22MB saved)
|
||||
RUN rm -rf /build/venv/lib/python*/site-packages/paddle/include
|
||||
|
||||
# Clean Python cache and build artifacts
|
||||
RUN find /build/venv -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true && \
|
||||
find /build/venv -type f -name "*.pyc" -delete && \
|
||||
find /build/venv -type f -name "*.pyo" -delete && \
|
||||
find /build/venv -type d -name "tests" -exec rm -rf {} + 2>/dev/null || true && \
|
||||
find /build/venv -type d -name "test" -exec rm -rf {} + 2>/dev/null || true && \
|
||||
rm -rf /build/venv/lib/*/site-packages/pip* \
|
||||
/build/venv/lib/*/site-packages/setuptools* \
|
||||
/build/venv/include \
|
||||
/build/venv/share && \
|
||||
rm -rf /root/.cache 2>/dev/null || true
|
||||
|
||||
# =============================================================================
|
||||
# STAGE 2: Runtime - CUDA base (~400MB, not ~3.4GB from runtime)
|
||||
# =============================================================================
|
||||
FROM nvidia/cuda:12.9.0-base-ubuntu24.04
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||
# Model cache directories - mount these at runtime
|
||||
MODELSCOPE_CACHE=/root/.cache/modelscope \
|
||||
HF_HOME=/root/.cache/huggingface \
|
||||
# Application config (override defaults for container)
|
||||
# Use 127.0.0.1 for --network host mode, or override with -e for bridge mode
|
||||
PP_DOCLAYOUT_MODEL_DIR=/root/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2 \
|
||||
PADDLEOCR_VL_URL=http://127.0.0.1:8001/v1
|
||||
PADDLEOCR_VL_URL=http://127.0.0.1:8001/v1 \
|
||||
PATH="/app/.venv/bin:$PATH" \
|
||||
VIRTUAL_ENV="/app/.venv"
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies and Python 3.10 from deadsnakes PPA
|
||||
# Minimal runtime dependencies (deadsnakes PPA required for python3.10 on Ubuntu 24.04)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.10 \
|
||||
python3.10-venv \
|
||||
python3.10-dev \
|
||||
python3.10-distutils \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender-dev \
|
||||
libgomp1 \
|
||||
curl \
|
||||
pandoc \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& ln -sf /usr/bin/python3.10 /usr/bin/python \
|
||||
&& ln -sf /usr/bin/python3.10 /usr/bin/python3 \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
|
||||
libgl1 libglib2.0-0 libgomp1 \
|
||||
curl pandoc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv via pip (more reliable than install script)
|
||||
RUN python3.10 -m pip install uv -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/app/.venv"
|
||||
RUN ln -sf /usr/bin/python3.10 /usr/bin/python
|
||||
|
||||
# Copy dependency files first for better caching
|
||||
COPY pyproject.toml ./
|
||||
COPY wheels/ ./wheels/
|
||||
# Copy optimized venv from builder
|
||||
COPY --from=builder /build/venv /app/.venv
|
||||
|
||||
# Create virtual environment and install dependencies
|
||||
RUN uv venv /app/.venv --python python3.10 \
|
||||
&& uv pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -e . \
|
||||
&& rm -rf ./wheels
|
||||
|
||||
# Copy application code
|
||||
# Copy app code
|
||||
COPY app/ ./app/
|
||||
|
||||
# Create model cache directories (mount from host at runtime)
|
||||
RUN mkdir -p /root/.cache/modelscope \
|
||||
/root/.cache/huggingface \
|
||||
/root/.paddlex \
|
||||
/app/app/model/DocLayout \
|
||||
/app/app/model/PP-DocLayout
|
||||
# Create cache mount points (DO NOT include model files)
|
||||
RUN mkdir -p /root/.cache/modelscope /root/.cache/huggingface /root/.paddlex && \
|
||||
rm -rf /app/app/model/*
|
||||
|
||||
# Declare volumes for model cache (mount at runtime to avoid re-downloading)
|
||||
VOLUME ["/root/.cache/modelscope", "/root/.cache/huggingface", "/root/.paddlex"]
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8053
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8053/health || exit 1
|
||||
|
||||
# Run the application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8053", "--workers", "1"]
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
"""Format conversion endpoints."""
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from app.core.dependencies import get_converter
|
||||
from app.core.logging_config import get_logger
|
||||
from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest
|
||||
from app.schemas.convert import MarkdownToDocxRequest, LatexToOmmlRequest, LatexToOmmlResponse
|
||||
from app.services.converter import Converter
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -24,26 +19,14 @@ async def convert_markdown_to_docx(
|
||||
|
||||
Returns the generated DOCX file as a binary response.
|
||||
"""
|
||||
logger.info(
|
||||
"Converting markdown to DOCX, filename=%s, content_length=%d",
|
||||
request.filename,
|
||||
len(request.markdown),
|
||||
)
|
||||
try:
|
||||
docx_bytes = converter.export_to_file(request.markdown, export_type="docx")
|
||||
logger.info(
|
||||
"DOCX conversion successful, filename=%s, size=%d bytes",
|
||||
request.filename,
|
||||
len(docx_bytes),
|
||||
)
|
||||
encoded_name = quote(f"{request.filename}.docx")
|
||||
return Response(
|
||||
content=docx_bytes,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
|
||||
headers={"Content-Disposition": f'attachment; filename="{request.filename}.docx"'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("DOCX conversion failed, filename=%s: %s", request.filename, e)
|
||||
raise HTTPException(status_code=500, detail=f"Conversion failed: {e}")
|
||||
|
||||
|
||||
@@ -72,17 +55,12 @@ async def convert_latex_to_omml(
|
||||
```
|
||||
"""
|
||||
if not request.latex or not request.latex.strip():
|
||||
logger.warning("LaTeX to OMML request received with empty formula")
|
||||
raise HTTPException(status_code=400, detail="LaTeX formula cannot be empty")
|
||||
|
||||
logger.info("Converting LaTeX to OMML, latex=%r", request.latex)
|
||||
try:
|
||||
omml = converter.convert_to_omml(request.latex)
|
||||
logger.info("LaTeX to OMML conversion successful")
|
||||
return LatexToOmmlResponse(omml=omml)
|
||||
except ValueError as e:
|
||||
logger.warning("LaTeX to OMML conversion invalid input: %s", e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error("LaTeX to OMML conversion runtime error: %s", e)
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
@@ -6,10 +6,10 @@ import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
|
||||
from app.core.dependencies import (
|
||||
get_glmocr_endtoend_service,
|
||||
get_image_processor,
|
||||
get_glmocr_endtoend_service,
|
||||
)
|
||||
from app.core.logging_config import RequestIDAdapter, get_logger
|
||||
from app.core.logging_config import get_logger, RequestIDAdapter
|
||||
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
|
||||
@@ -50,7 +50,9 @@ class Settings(BaseSettings):
|
||||
max_tokens: int = 4096
|
||||
|
||||
# Model Paths
|
||||
pp_doclayout_model_dir: str | None = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
|
||||
pp_doclayout_model_dir: str | None = (
|
||||
"/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
|
||||
)
|
||||
|
||||
# Image Processing
|
||||
max_image_size_mb: int = 10
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Application dependencies."""
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.converter import Converter
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
from app.services.converter import Converter
|
||||
from app.core.config import get_settings
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
_layout_detector: LayoutDetector | None = None
|
||||
|
||||
@@ -2,15 +2,11 @@
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
# Context variable to hold the current request_id across async boundaries
|
||||
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")
|
||||
|
||||
|
||||
class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler):
|
||||
"""File handler that rotates by both time (daily) and size (100MB)."""
|
||||
@@ -22,10 +18,10 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
|
||||
interval: int = 1,
|
||||
backupCount: int = 30,
|
||||
maxBytes: int = 100 * 1024 * 1024, # 100MB
|
||||
encoding: str | None = None,
|
||||
encoding: Optional[str] = None,
|
||||
delay: bool = False,
|
||||
utc: bool = False,
|
||||
atTime: Any | None = None,
|
||||
atTime: Optional[Any] = None,
|
||||
):
|
||||
"""Initialize handler with both time and size rotation.
|
||||
|
||||
@@ -62,14 +58,14 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
|
||||
if self.stream is None:
|
||||
self.stream = self._open()
|
||||
if self.maxBytes > 0:
|
||||
msg = f"{self.format(record)}\n"
|
||||
msg = "%s\n" % self.format(record)
|
||||
self.stream.seek(0, 2) # Seek to end
|
||||
if self.stream.tell() + len(msg) >= self.maxBytes:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def setup_logging(log_dir: str | None = None) -> logging.Logger:
|
||||
def setup_logging(log_dir: Optional[str] = None) -> logging.Logger:
|
||||
"""Setup application logging with rotation by day and size.
|
||||
|
||||
Args:
|
||||
@@ -96,13 +92,14 @@ def setup_logging(log_dir: str | None = None) -> logging.Logger:
|
||||
# Remove existing handlers to avoid duplicates
|
||||
logger.handlers.clear()
|
||||
|
||||
# Create custom formatter that automatically injects request_id from context
|
||||
# Create custom formatter that handles missing request_id
|
||||
class RequestIDFormatter(logging.Formatter):
|
||||
"""Formatter that injects request_id from ContextVar into log records."""
|
||||
"""Formatter that handles request_id in log records."""
|
||||
|
||||
def format(self, record):
|
||||
# Add request_id if not present
|
||||
if not hasattr(record, "request_id"):
|
||||
record.request_id = request_id_ctx.get()
|
||||
record.request_id = getattr(record, "request_id", "unknown")
|
||||
return super().format(record)
|
||||
|
||||
formatter = RequestIDFormatter(
|
||||
@@ -137,11 +134,11 @@ def setup_logging(log_dir: str | None = None) -> logging.Logger:
|
||||
|
||||
|
||||
# Global logger instance
|
||||
_logger: logging.Logger | None = None
|
||||
_logger: Optional[logging.Logger] = None
|
||||
|
||||
|
||||
def get_logger() -> logging.Logger:
|
||||
"""Get the global logger instance, initializing if needed."""
|
||||
"""Get the global logger instance."""
|
||||
global _logger
|
||||
if _logger is None:
|
||||
_logger = setup_logging()
|
||||
|
||||
@@ -8,7 +8,6 @@ from app.api.v1.router import api_router
|
||||
from app.core.config import get_settings
|
||||
from app.core.dependencies import init_layout_detector
|
||||
from app.core.logging_config import setup_logging
|
||||
from app.middleware.request_id import RequestIDMiddleware
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -34,8 +33,6 @@ app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
# Include API router
|
||||
app.include_router(api_router, prefix=settings.api_prefix)
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Middleware to propagate or generate request_id for every request."""
|
||||
|
||||
import uuid
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.core.logging_config import request_id_ctx
|
||||
|
||||
REQUEST_ID_HEADER = "X-Request-ID"
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""Extract X-Request-ID from incoming request headers or generate one.
|
||||
|
||||
The request_id is stored in a ContextVar so that all log records emitted
|
||||
during the request are automatically annotated with it, without needing to
|
||||
pass it explicitly through every call.
|
||||
|
||||
The same request_id is also echoed back in the response header so that
|
||||
callers can correlate logs.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid.uuid4())
|
||||
token = request_id_ctx.set(request_id)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
finally:
|
||||
request_id_ctx.reset(token)
|
||||
|
||||
response.headers[REQUEST_ID_HEADER] = request_id
|
||||
return response
|
||||
@@ -36,3 +36,4 @@ class LatexToOmmlResponse(BaseModel):
|
||||
"""Response body for LaTeX to OMML conversion endpoint."""
|
||||
|
||||
omml: str = Field("", description="OMML (Office Math Markup Language) representation")
|
||||
|
||||
|
||||
@@ -7,9 +7,7 @@ class LayoutRegion(BaseModel):
|
||||
"""A detected layout region in the document."""
|
||||
|
||||
type: str = Field(..., description="Region type: text, formula, table, figure")
|
||||
native_label: str = Field(
|
||||
"", description="Raw label before type mapping (e.g. doc_title, formula_number)"
|
||||
)
|
||||
native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)")
|
||||
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
|
||||
confidence: float = Field(..., description="Detection confidence score")
|
||||
score: float = Field(..., description="Detection score")
|
||||
@@ -43,15 +41,10 @@ class ImageOCRRequest(BaseModel):
|
||||
class ImageOCRResponse(BaseModel):
|
||||
"""Response body for image OCR endpoint."""
|
||||
|
||||
latex: str = Field(
|
||||
"", description="LaTeX representation of the content (empty if mixed content)"
|
||||
)
|
||||
latex: str = Field("", description="LaTeX representation of the content (empty if mixed content)")
|
||||
markdown: str = Field("", description="Markdown representation of the content")
|
||||
mathml: str = Field("", description="Standard MathML representation (empty if mixed content)")
|
||||
mml: str = Field(
|
||||
"", description="XML MathML with mml: namespace prefix (empty if mixed content)"
|
||||
)
|
||||
mml: str = Field("", description="XML MathML with mml: namespace prefix (empty if mixed content)")
|
||||
layout_info: LayoutInfo = Field(default_factory=LayoutInfo)
|
||||
recognition_mode: str = Field(
|
||||
"", description="Recognition mode used: mixed_recognition or formula_recognition"
|
||||
)
|
||||
recognition_mode: str = Field("", description="Recognition mode used: mixed_recognition or formula_recognition")
|
||||
|
||||
|
||||
@@ -112,18 +112,14 @@ class Converter:
|
||||
# Pre-compiled regex patterns for preprocessing
|
||||
_RE_VSPACE = re.compile(r"\\\[1mm\]")
|
||||
_RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL)
|
||||
_RE_BLOCK_FORMULA_LINE = re.compile(
|
||||
r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL
|
||||
)
|
||||
_RE_BLOCK_FORMULA_LINE = re.compile(r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL)
|
||||
_RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>')
|
||||
_RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)")
|
||||
_RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}")
|
||||
_RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+")
|
||||
_RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}")
|
||||
_RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL)
|
||||
_RE_ALIGNED_BRACE = re.compile(
|
||||
r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL
|
||||
)
|
||||
_RE_ALIGNED_BRACE = re.compile(r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL)
|
||||
_RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL)
|
||||
_RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL)
|
||||
_RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL)
|
||||
@@ -372,9 +368,7 @@ class Converter:
|
||||
mathml = latex_to_mathml(latex_formula)
|
||||
return Converter._postprocess_mathml_for_word(mathml)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}"
|
||||
) from e
|
||||
raise RuntimeError(f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _postprocess_mathml_for_word(mathml: str) -> str:
|
||||
@@ -589,6 +583,7 @@ class Converter:
|
||||
"⇓": "⇓", # Downarrow
|
||||
"↕": "↕", # updownarrow
|
||||
"⇕": "⇕", # Updownarrow
|
||||
"≠": "≠", # ne
|
||||
"≪": "≪", # ll
|
||||
"≫": "≫", # gg
|
||||
"⩽": "⩽", # leqslant
|
||||
@@ -967,7 +962,7 @@ class Converter:
|
||||
"""Export to DOCX format using pypandoc."""
|
||||
extra_args = [
|
||||
"--highlight-style=pygments",
|
||||
"--reference-doc=app/pkg/reference.docx",
|
||||
f"--reference-doc=app/pkg/reference.docx",
|
||||
]
|
||||
pypandoc.convert_file(
|
||||
input_path,
|
||||
|
||||
@@ -1,10 +1,26 @@
|
||||
"""GLM-OCR postprocessing logic adapted for this project.
|
||||
|
||||
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
|
||||
glm-ocr/glmocr/utils/result_postprocess_utils.py.
|
||||
|
||||
Covers:
|
||||
- Repeated-content / hallucination detection
|
||||
- Per-region content cleaning and formatting (titles, bullets, formulas)
|
||||
- formula_number merging (→ \\tag{})
|
||||
- Hyphenated text-block merging (via wordfreq)
|
||||
- Missing bullet-point detection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
from wordfreq import zipf_frequency
|
||||
@@ -13,14 +29,13 @@ try:
|
||||
except ImportError:
|
||||
_WORDFREQ_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# result_postprocess_utils (ported)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None:
|
||||
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
|
||||
"""Detect and truncate a consecutively-repeated pattern.
|
||||
|
||||
Returns the string with the repeat removed, or None if not found.
|
||||
@@ -105,13 +120,8 @@ def clean_formula_number(number_content: str) -> str:
|
||||
# GLMResultFormatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Matches content that consists *entirely* of a display-math block and nothing else.
|
||||
# Used to detect when a text/heading region was actually recognised as a formula by vLLM,
|
||||
# so we can correct the label before heading prefixes (## …) are applied.
|
||||
_PURE_DISPLAY_FORMULA_RE = re.compile(r"^\s*(?:\$\$[\s\S]+?\$\$|\\\[[\s\S]+?\\\])\s*$")
|
||||
|
||||
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
||||
_LABEL_TO_CATEGORY: dict[str, str] = {
|
||||
_LABEL_TO_CATEGORY: Dict[str, str] = {
|
||||
# text
|
||||
"abstract": "text",
|
||||
"algorithm": "text",
|
||||
@@ -147,7 +157,7 @@ class GLMResultFormatter:
|
||||
# Public entry-point
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def process(self, regions: list[dict[str, Any]]) -> str:
|
||||
def process(self, regions: List[Dict[str, Any]]) -> str:
|
||||
"""Run the full postprocessing pipeline and return Markdown.
|
||||
|
||||
Args:
|
||||
@@ -165,24 +175,11 @@ class GLMResultFormatter:
|
||||
items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
|
||||
|
||||
# Per-region cleaning + formatting
|
||||
processed: list[dict] = []
|
||||
processed: List[Dict] = []
|
||||
for item in items:
|
||||
item["native_label"] = item.get("native_label", item.get("label", "text"))
|
||||
item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
|
||||
|
||||
# Label correction: layout may say "text" (or a heading like "paragraph_title")
|
||||
# but vLLM recognised the content as a formula and returned $$…$$. Without
|
||||
# correction the heading prefix (##) would be prepended to the math block,
|
||||
# producing broken output like "## $$ \mathbf{y}=… $$".
|
||||
raw_content = (item.get("content") or "").strip()
|
||||
if item["label"] == "text" and _PURE_DISPLAY_FORMULA_RE.match(raw_content):
|
||||
logger.debug(
|
||||
"Label corrected text (native=%s) → formula: pure display-formula detected",
|
||||
item["native_label"],
|
||||
)
|
||||
item["label"] = "formula"
|
||||
item["native_label"] = "display_formula"
|
||||
|
||||
item["content"] = self._format_content(
|
||||
item.get("content") or "",
|
||||
item["label"],
|
||||
@@ -202,7 +199,7 @@ class GLMResultFormatter:
|
||||
processed = self._format_bullet_points(processed)
|
||||
|
||||
# Assemble Markdown
|
||||
parts: list[str] = []
|
||||
parts: List[str] = []
|
||||
for item in processed:
|
||||
content = item.get("content") or ""
|
||||
if item["label"] == "image":
|
||||
@@ -265,11 +262,9 @@ class GLMResultFormatter:
|
||||
# Formula wrapping
|
||||
if label == "formula":
|
||||
content = content.strip()
|
||||
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)"), ("$", "$")]:
|
||||
if content.startswith(s):
|
||||
content = content[len(s) :].strip()
|
||||
if content.endswith(e):
|
||||
content = content[: -len(e)].strip()
|
||||
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
|
||||
if content.startswith(s) and content.endswith(e):
|
||||
content = content[len(s) : -len(e)].strip()
|
||||
break
|
||||
if not content:
|
||||
logger.warning("Skipping formula region with empty content after stripping delimiters")
|
||||
@@ -301,12 +296,12 @@ class GLMResultFormatter:
|
||||
# Structural merges
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _merge_formula_numbers(self, items: list[dict]) -> list[dict]:
|
||||
def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
|
||||
"""Merge formula_number region into adjacent formula with \\tag{}."""
|
||||
if not items:
|
||||
return items
|
||||
|
||||
merged: list[dict] = []
|
||||
merged: List[Dict] = []
|
||||
skip: set = set()
|
||||
|
||||
for i, block in enumerate(items):
|
||||
@@ -345,12 +340,12 @@ class GLMResultFormatter:
|
||||
block["index"] = i
|
||||
return merged
|
||||
|
||||
def _merge_text_blocks(self, items: list[dict]) -> list[dict]:
|
||||
def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
|
||||
"""Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
|
||||
if not items or not _WORDFREQ_AVAILABLE:
|
||||
return items
|
||||
|
||||
merged: list[dict] = []
|
||||
merged: List[Dict] = []
|
||||
skip: set = set()
|
||||
|
||||
for i, block in enumerate(items):
|
||||
@@ -394,7 +389,7 @@ class GLMResultFormatter:
|
||||
block["index"] = i
|
||||
return merged
|
||||
|
||||
def _format_bullet_points(self, items: list[dict], left_align_threshold: float = 10.0) -> list[dict]:
|
||||
def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
|
||||
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
|
||||
if len(items) < 3:
|
||||
return items
|
||||
@@ -424,7 +419,10 @@ class GLMResultFormatter:
|
||||
if not (cur_bbox and prev_bbox and nxt_bbox):
|
||||
continue
|
||||
|
||||
if abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold:
|
||||
if (
|
||||
abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold
|
||||
and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold
|
||||
):
|
||||
cur["content"] = "- " + cur_content
|
||||
|
||||
return items
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""PP-DocLayoutV3 wrapper for document layout detection."""
|
||||
|
||||
import numpy as np
|
||||
from paddleocr import LayoutDetection
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||
from app.core.config import get_settings
|
||||
from app.services.layout_postprocess import apply_layout_postprocess
|
||||
from paddleocr import LayoutDetection
|
||||
from typing import Optional
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -13,7 +14,7 @@ settings = get_settings()
|
||||
class LayoutDetector:
|
||||
"""Layout detector for PP-DocLayoutV2."""
|
||||
|
||||
_layout_detector: LayoutDetection | None = None
|
||||
_layout_detector: Optional[LayoutDetection] = None
|
||||
|
||||
# PP-DocLayoutV2 class ID to label mapping
|
||||
CLS_ID_TO_LABEL: dict[int, str] = {
|
||||
@@ -148,18 +149,17 @@ class LayoutDetector:
|
||||
)
|
||||
)
|
||||
|
||||
mixed_recognition = any(region.type == "text" and region.score > 0.85 for region in regions)
|
||||
mixed_recognition = any(region.type == "text" and region.score > 0.3 for region in regions)
|
||||
|
||||
return LayoutInfo(regions=regions, MixedRecognition=mixed_recognition)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import cv2
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.converter import Converter
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.ocr_service import GLMOCREndToEndService
|
||||
from app.services.converter import Converter
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -169,15 +169,15 @@ if __name__ == "__main__":
|
||||
converter = Converter()
|
||||
|
||||
# Initialize OCR service
|
||||
ocr_service = GLMOCREndToEndService(
|
||||
vl_server_url=settings.glm_ocr_url,
|
||||
ocr_service = OCRService(
|
||||
vl_server_url=settings.paddleocr_vl_url,
|
||||
layout_detector=layout_detector,
|
||||
image_processor=image_processor,
|
||||
converter=converter,
|
||||
)
|
||||
|
||||
# Load test image
|
||||
image_path = "test/image2.png"
|
||||
image_path = "test/timeout.jpg"
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
if image is None:
|
||||
|
||||
@@ -15,14 +15,16 @@ the quality of the GLM-OCR SDK's layout pipeline.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Primitive geometry helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def iou(box1: list[float], box2: list[float]) -> float:
|
||||
def iou(box1: List[float], box2: List[float]) -> float:
|
||||
"""Compute IoU of two bounding boxes [x1, y1, x2, y2]."""
|
||||
x1, y1, x2, y2 = box1
|
||||
x1_p, y1_p, x2_p, y2_p = box2
|
||||
@@ -39,7 +41,7 @@ def iou(box1: list[float], box2: list[float]) -> float:
|
||||
return inter_area / float(box1_area + box2_area - inter_area)
|
||||
|
||||
|
||||
def is_contained(box1: list[float], box2: list[float], overlap_threshold: float = 0.8) -> bool:
|
||||
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
|
||||
"""Return True if box1 is contained within box2 (overlap ratio >= threshold).
|
||||
|
||||
box format: [cls_id, score, x1, y1, x2, y2]
|
||||
@@ -64,12 +66,11 @@ def is_contained(box1: list[float], box2: list[float], overlap_threshold: float
|
||||
# NMS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def nms(
|
||||
boxes: np.ndarray,
|
||||
iou_same: float = 0.6,
|
||||
iou_diff: float = 0.98,
|
||||
) -> list[int]:
|
||||
) -> List[int]:
|
||||
"""NMS with separate IoU thresholds for same-class and cross-class overlaps.
|
||||
|
||||
Args:
|
||||
@@ -82,7 +83,7 @@ def nms(
|
||||
"""
|
||||
scores = boxes[:, 1]
|
||||
indices = np.argsort(scores)[::-1].tolist()
|
||||
selected: list[int] = []
|
||||
selected: List[int] = []
|
||||
|
||||
while indices:
|
||||
current = indices[0]
|
||||
@@ -113,10 +114,10 @@ _PRESERVE_LABELS = {"image", "seal", "chart"}
|
||||
|
||||
def check_containment(
|
||||
boxes: np.ndarray,
|
||||
preserve_cls_ids: set | None = None,
|
||||
category_index: int | None = None,
|
||||
mode: str | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
preserve_cls_ids: Optional[set] = None,
|
||||
category_index: Optional[int] = None,
|
||||
mode: Optional[str] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Compute containment flags for each box.
|
||||
|
||||
Args:
|
||||
@@ -159,10 +160,9 @@ def check_containment(
|
||||
# Box expansion (unclip)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def unclip_boxes(
|
||||
boxes: np.ndarray,
|
||||
unclip_ratio: float | tuple[float, float] | dict | list | None,
|
||||
unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
|
||||
) -> np.ndarray:
|
||||
"""Expand bounding boxes by the given ratio.
|
||||
|
||||
@@ -215,14 +215,13 @@ def unclip_boxes(
|
||||
# Main entry-point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_layout_postprocess(
|
||||
boxes: list[dict],
|
||||
img_size: tuple[int, int],
|
||||
boxes: List[Dict],
|
||||
img_size: Tuple[int, int],
|
||||
layout_nms: bool = True,
|
||||
layout_unclip_ratio: float | tuple | dict | None = None,
|
||||
layout_merge_bboxes_mode: str | dict | None = "large",
|
||||
) -> list[dict]:
|
||||
layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
|
||||
layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
|
||||
) -> List[Dict]:
|
||||
"""Apply GLM-OCR layout post-processing to PaddleOCR detection results.
|
||||
|
||||
Args:
|
||||
@@ -251,7 +250,7 @@ def apply_layout_postprocess(
|
||||
arr_rows.append([cls_id, score, x1, y1, x2, y2])
|
||||
boxes_array = np.array(arr_rows, dtype=float)
|
||||
|
||||
all_labels: list[str] = [b.get("label", "") for b in boxes]
|
||||
all_labels: List[str] = [b.get("label", "") for b in boxes]
|
||||
|
||||
# 1. NMS ---------------------------------------------------------------- #
|
||||
if layout_nms and len(boxes_array) > 1:
|
||||
@@ -263,14 +262,17 @@ def apply_layout_postprocess(
|
||||
if len(boxes_array) > 1:
|
||||
img_area = img_width * img_height
|
||||
area_thres = 0.82 if img_width > img_height else 0.93
|
||||
image_cls_ids = {
|
||||
int(boxes_array[i, 0])
|
||||
for i, lbl in enumerate(all_labels)
|
||||
if lbl == "image"
|
||||
}
|
||||
keep_mask = np.ones(len(boxes_array), dtype=bool)
|
||||
for i, lbl in enumerate(all_labels):
|
||||
if lbl == "image":
|
||||
x1, y1, x2, y2 = boxes_array[i, 2:6]
|
||||
x1 = max(0.0, x1)
|
||||
y1 = max(0.0, y1)
|
||||
x2 = min(float(img_width), x2)
|
||||
y2 = min(float(img_height), y2)
|
||||
x1 = max(0.0, x1); y1 = max(0.0, y1)
|
||||
x2 = min(float(img_width), x2); y2 = min(float(img_height), y2)
|
||||
if (x2 - x1) * (y2 - y1) > area_thres * img_area:
|
||||
keep_mask[i] = False
|
||||
boxes_array = boxes_array[keep_mask]
|
||||
@@ -279,7 +281,9 @@ def apply_layout_postprocess(
|
||||
# 3. Containment analysis (merge_bboxes_mode) -------------------------- #
|
||||
if layout_merge_bboxes_mode and len(boxes_array) > 1:
|
||||
preserve_cls_ids = {
|
||||
int(boxes_array[i, 0]) for i, lbl in enumerate(all_labels) if lbl in _PRESERVE_LABELS
|
||||
int(boxes_array[i, 0])
|
||||
for i, lbl in enumerate(all_labels)
|
||||
if lbl in _PRESERVE_LABELS
|
||||
}
|
||||
|
||||
if isinstance(layout_merge_bboxes_mode, str):
|
||||
@@ -317,7 +321,7 @@ def apply_layout_postprocess(
|
||||
boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)
|
||||
|
||||
# 5. Clamp to image boundaries + skip invalid -------------------------- #
|
||||
result: list[dict] = []
|
||||
result: List[Dict] = []
|
||||
for i, row in enumerate(boxes_array):
|
||||
cls_id = int(row[0])
|
||||
score = float(row[1])
|
||||
@@ -329,13 +333,11 @@ def apply_layout_postprocess(
|
||||
if x1 >= x2 or y1 >= y2:
|
||||
continue
|
||||
|
||||
result.append(
|
||||
{
|
||||
result.append({
|
||||
"cls_id": cls_id,
|
||||
"label": all_labels[i],
|
||||
"score": score,
|
||||
"coordinate": [int(x1), int(y1), int(x2), int(y2)],
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@@ -150,7 +150,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
|
||||
# Strategy: remove spaces before \ and between non-command chars,
|
||||
# but preserve the space after \command when followed by a non-\ char
|
||||
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
|
||||
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
|
||||
cleaned = re.sub(
|
||||
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
|
||||
) # remove space after non-letter non-\
|
||||
return f"{operator}{{{cleaned}}}"
|
||||
|
||||
# Match _{ ... } or ^{ ... }
|
||||
@@ -628,7 +630,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
self.glm_ocr_url = glm_ocr_url
|
||||
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
|
||||
|
||||
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
|
||||
def _recognize_formula_with_paddleocr_vl(
|
||||
self, image: np.ndarray, prompt: str = "Formula Recognition:"
|
||||
) -> str:
|
||||
"""Recognize formula using PaddleOCR-VL API.
|
||||
|
||||
Args:
|
||||
@@ -669,7 +673,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
|
||||
|
||||
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
|
||||
def _extract_and_recognize_formulas(
|
||||
self, markdown_content: str, original_image: np.ndarray
|
||||
) -> str:
|
||||
"""Extract image references from markdown and recognize formulas.
|
||||
|
||||
Args:
|
||||
@@ -751,7 +757,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
markdown_content = result["results"]["image"].get("md_content", "")
|
||||
|
||||
if "
|
||||
markdown_content = self._extract_and_recognize_formulas(
|
||||
markdown_content, original_image
|
||||
)
|
||||
|
||||
# Apply postprocessing to fix OCR errors
|
||||
markdown_content = _postprocess_markdown(markdown_content)
|
||||
@@ -781,11 +789,15 @@ class MineruOCRService(OCRServiceBase):
|
||||
|
||||
# Task-specific prompts (from GLM-OCR SDK config.yaml)
|
||||
_TASK_PROMPTS: dict[str, str] = {
|
||||
"text": "Text Recognition. If the content is a formula, please output display latex code, else output text",
|
||||
"text": "Text Recognition:",
|
||||
"formula": "Formula Recognition:",
|
||||
"table": "Table Recognition:",
|
||||
}
|
||||
_DEFAULT_PROMPT = "Text Recognition. If the content is a formula, please output display latex code, else output text"
|
||||
_DEFAULT_PROMPT = (
|
||||
"Recognize the text in the image and output in Markdown format. "
|
||||
"Preserve the original layout (headings/paragraphs/tables/formulas). "
|
||||
"Do not fabricate content that does not exist in the image."
|
||||
)
|
||||
|
||||
|
||||
class GLMOCREndToEndService(OCRServiceBase):
|
||||
@@ -866,19 +878,21 @@ class GLMOCREndToEndService(OCRServiceBase):
|
||||
Returns:
|
||||
Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
|
||||
"""
|
||||
# 1. Layout detection
|
||||
img_h, img_w = image.shape[:2]
|
||||
padded_image = self.image_processor.add_padding(image)
|
||||
layout_info = self.layout_detector.detect(padded_image)
|
||||
# 1. Padding
|
||||
padded = self.image_processor.add_padding(image)
|
||||
img_h, img_w = padded.shape[:2]
|
||||
|
||||
# 2. Layout detection
|
||||
layout_info = self.layout_detector.detect(padded)
|
||||
|
||||
# Sort regions in reading order: top-to-bottom, left-to-right
|
||||
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
|
||||
|
||||
# 3. OCR: per-region (parallel) or full-image fallback
|
||||
if not layout_info.regions or (len(layout_info.regions) == 1 and not layout_info.MixedRecognition):
|
||||
if not layout_info.regions:
|
||||
# No layout detected → assume it's a formula, use formula recognition
|
||||
logger.info("No layout regions detected, treating image as formula")
|
||||
raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"])
|
||||
raw_content = self._call_vllm(padded, _TASK_PROMPTS["formula"])
|
||||
# Format as display formula markdown
|
||||
formatted_content = raw_content.strip()
|
||||
if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")):
|
||||
@@ -891,7 +905,7 @@ class GLMOCREndToEndService(OCRServiceBase):
|
||||
if region.type == "figure":
|
||||
continue
|
||||
x1, y1, x2, y2 = (int(c) for c in region.bbox)
|
||||
cropped = padded_image[y1:y2, x1:x2]
|
||||
cropped = padded[y1:y2, x1:x2]
|
||||
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
|
||||
logger.warning(
|
||||
"Skipping region idx=%d (label=%s): crop too small %s",
|
||||
@@ -904,13 +918,16 @@ class GLMOCREndToEndService(OCRServiceBase):
|
||||
tasks.append((idx, region, cropped, prompt))
|
||||
|
||||
if not tasks:
|
||||
raw_content = self._call_vllm(image, _DEFAULT_PROMPT)
|
||||
raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
|
||||
markdown_content = self._formatter._clean_content(raw_content)
|
||||
else:
|
||||
# Parallel OCR calls
|
||||
raw_results: dict[int, str] = {}
|
||||
with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex:
|
||||
future_map = {ex.submit(self._call_vllm, cropped, prompt): idx for idx, region, cropped, prompt in tasks}
|
||||
future_map = {
|
||||
ex.submit(self._call_vllm, cropped, prompt): idx
|
||||
for idx, region, cropped, prompt in tasks
|
||||
}
|
||||
for future in as_completed(future_map):
|
||||
idx = future_map[future]
|
||||
try:
|
||||
@@ -948,3 +965,17 @@ class GLMOCREndToEndService(OCRServiceBase):
|
||||
logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e)
|
||||
|
||||
return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mineru_service = MineruOCRService()
|
||||
image = cv2.imread("test/formula2.jpg")
|
||||
image_numpy = np.array(image)
|
||||
# Encode image to bytes (as done in API layer)
|
||||
success, encoded_image = cv2.imencode(".png", image_numpy)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to encode image")
|
||||
image_bytes = BytesIO(encoded_image.tobytes())
|
||||
image_bytes.seek(0)
|
||||
ocr_result = mineru_service.recognize(image_bytes)
|
||||
print(ocr_result)
|
||||
|
||||
@@ -11,7 +11,7 @@ authors = [
|
||||
dependencies = [
|
||||
"fastapi==0.128.0",
|
||||
"uvicorn[standard]==0.40.0",
|
||||
"opencv-python==4.12.0.88",
|
||||
"opencv-python-headless==4.12.0.88", # headless: no Qt/FFmpeg GUI, server-only
|
||||
"python-multipart==0.0.21",
|
||||
"pydantic==2.12.5",
|
||||
"pydantic-settings==2.12.0",
|
||||
@@ -20,7 +20,6 @@ dependencies = [
|
||||
"pillow==12.0.0",
|
||||
"python-docx==1.2.0",
|
||||
"paddleocr==3.4.0",
|
||||
"doclayout-yolo==0.0.4",
|
||||
"latex2mathml==3.78.1",
|
||||
"paddle==1.2.0",
|
||||
"pypandoc==1.16.2",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@@ -34,9 +35,7 @@ def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
|
||||
client = _build_client()
|
||||
|
||||
missing = client.post("/ocr", json={})
|
||||
both = client.post(
|
||||
"/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}
|
||||
)
|
||||
both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"})
|
||||
|
||||
assert missing.status_code == 422
|
||||
assert both.status_code == 422
|
||||
|
||||
@@ -57,22 +57,12 @@ def test_merge_formula_numbers_merges_before_and_after_formula():
|
||||
before = formatter._merge_formula_numbers(
|
||||
[
|
||||
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
|
||||
{
|
||||
"index": 1,
|
||||
"label": "formula",
|
||||
"native_label": "display_formula",
|
||||
"content": "$$\nx+y\n$$",
|
||||
},
|
||||
{"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
|
||||
]
|
||||
)
|
||||
after = formatter._merge_formula_numbers(
|
||||
[
|
||||
{
|
||||
"index": 0,
|
||||
"label": "formula",
|
||||
"native_label": "display_formula",
|
||||
"content": "$$\nx+y\n$$",
|
||||
},
|
||||
{"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
|
||||
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
|
||||
]
|
||||
)
|
||||
|
||||
@@ -23,9 +23,7 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
|
||||
|
||||
calls = {}
|
||||
|
||||
def fake_apply_layout_postprocess(
|
||||
boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode
|
||||
):
|
||||
def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode):
|
||||
calls["args"] = {
|
||||
"boxes": boxes,
|
||||
"img_size": img_size,
|
||||
@@ -35,9 +33,7 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
|
||||
}
|
||||
return [boxes[0], boxes[2]]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess
|
||||
)
|
||||
monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess)
|
||||
|
||||
image = np.zeros((200, 100, 3), dtype=np.uint8)
|
||||
info = detector.detect(image)
|
||||
|
||||
@@ -146,4 +146,6 @@ def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image()
|
||||
layout_merge_bboxes_mode=None,
|
||||
)
|
||||
|
||||
assert result == [{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}]
|
||||
assert result == [
|
||||
{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
|
||||
]
|
||||
|
||||
@@ -46,9 +46,7 @@ def test_encode_region_returns_decodable_base64_jpeg():
|
||||
image[:, :] = [0, 128, 255]
|
||||
|
||||
encoded = service._encode_region(image)
|
||||
decoded = cv2.imdecode(
|
||||
np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR
|
||||
)
|
||||
decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
assert decoded.shape[:2] == image.shape[:2]
|
||||
|
||||
@@ -73,9 +71,7 @@ def test_call_vllm_builds_messages_and_returns_content():
|
||||
assert captured["model"] == "glm-ocr"
|
||||
assert captured["max_tokens"] == 1024
|
||||
assert captured["messages"][0]["content"][0]["type"] == "image_url"
|
||||
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith(
|
||||
"data:image/jpeg;base64,"
|
||||
)
|
||||
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
|
||||
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
|
||||
|
||||
|
||||
@@ -102,19 +98,9 @@ def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
|
||||
|
||||
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
|
||||
regions = [
|
||||
LayoutRegion(
|
||||
type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9
|
||||
),
|
||||
LayoutRegion(
|
||||
type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8
|
||||
),
|
||||
LayoutRegion(
|
||||
type="formula",
|
||||
native_label="display_formula",
|
||||
bbox=[20, 20, 40, 40],
|
||||
confidence=0.95,
|
||||
score=0.95,
|
||||
),
|
||||
LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
|
||||
LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
|
||||
LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
|
||||
]
|
||||
service = _build_service(regions=regions)
|
||||
image = np.zeros((40, 40, 3), dtype=np.uint8)
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def debug_layout_detector():
|
||||
layout_detector = LayoutDetector()
|
||||
image = cv2.imread("test/image2.png")
|
||||
|
||||
print(f"Image shape: {image.shape}")
|
||||
|
||||
# padded_image = ImageProcessor(padding_ratio=0.15).add_padding(image)
|
||||
layout_info = layout_detector.detect(image)
|
||||
|
||||
# draw the layout info and label
|
||||
for region in layout_info.regions:
|
||||
x1, y1, x2, y2 = region.bbox
|
||||
cv2.putText(
|
||||
image,
|
||||
region.native_label,
|
||||
(int(x1), int(y1)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 0, 255),
|
||||
2,
|
||||
)
|
||||
cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
|
||||
cv2.imwrite("test/layout_debug.png", image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
debug_layout_detector()
|
||||
Reference in New Issue
Block a user