8 Commits

Author SHA1 Message Date
liuyuanchuang
11e9ed780d Merge branch 'main' of https://code.texpixel.com/YogeLiu/doc_processer 2026-03-12 12:41:43 +08:00
liuyuanchuang
d1050acbdc fix: looger path 2026-03-12 12:41:26 +08:00
16399f0929 fix: logger path 2026-03-12 12:38:18 +08:00
liuyuanchuang
92b56d61d8 feat: add log for export api 2026-03-12 11:40:19 +08:00
bb1cf66137 fix: optimize title to formula 2026-03-10 21:45:43 +08:00
a9d3a35dd7 chore: optimize prompt 2026-03-10 21:36:35 +08:00
d98fa7237c Merge pull request 'fix: remove padding from GLMOCREndToEndService and clean up ruff violations' (#2) from fix/tag into main
Reviewed-on: #2
2026-03-10 19:56:43 +08:00
liuyuanchuang
30d2c2f45b fix: remove padding from GLMOCREndToEndService and clean up ruff violations
- Drop image padding in GLMOCREndToEndService.recognize(); use raw image directly
- Fix F821 undefined `padded` references replaced with `image`
- Fix F601 duplicate dict key "≠" in converter
- Fix F841 unused `image_cls_ids` variable in layout_postprocess
- Fix E702 semicolon-separated statements in layout_postprocess
- Fix UP031 percent-format replaced with f-string in logging_config
- Auto-fix 44 additional ruff violations (import order, UP035/UP045/UP006, F401, F541)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 19:52:22 +08:00
22 changed files with 3112 additions and 167 deletions

View File

@@ -8,7 +8,8 @@
"WebFetch(domain:raw.githubusercontent.com)", "WebFetch(domain:raw.githubusercontent.com)",
"Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")", "Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")",
"Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)", "Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)",
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)" "Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)",
"Bash(ruff check:*)"
] ]
} }
} }

View File

@@ -4,9 +4,12 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import Response from fastapi.responses import Response
from app.core.dependencies import get_converter from app.core.dependencies import get_converter
from app.schemas.convert import MarkdownToDocxRequest, LatexToOmmlRequest, LatexToOmmlResponse from app.core.logging_config import get_logger
from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest
from app.services.converter import Converter from app.services.converter import Converter
logger = get_logger()
router = APIRouter() router = APIRouter()
@@ -19,14 +22,25 @@ async def convert_markdown_to_docx(
Returns the generated DOCX file as a binary response. Returns the generated DOCX file as a binary response.
""" """
logger.info(
"Converting markdown to DOCX, filename=%s, content_length=%d",
request.filename,
len(request.markdown),
)
try: try:
docx_bytes = converter.export_to_file(request.markdown, export_type="docx") docx_bytes = converter.export_to_file(request.markdown, export_type="docx")
logger.info(
"DOCX conversion successful, filename=%s, size=%d bytes",
request.filename,
len(docx_bytes),
)
return Response( return Response(
content=docx_bytes, content=docx_bytes,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f'attachment; filename="{request.filename}.docx"'}, headers={"Content-Disposition": f'attachment; filename="{request.filename}.docx"'},
) )
except Exception as e: except Exception as e:
logger.exception("DOCX conversion failed, filename=%s: %s", request.filename, e)
raise HTTPException(status_code=500, detail=f"Conversion failed: {e}") raise HTTPException(status_code=500, detail=f"Conversion failed: {e}")
@@ -55,12 +69,17 @@ async def convert_latex_to_omml(
``` ```
""" """
if not request.latex or not request.latex.strip(): if not request.latex or not request.latex.strip():
logger.warning("LaTeX to OMML request received with empty formula")
raise HTTPException(status_code=400, detail="LaTeX formula cannot be empty") raise HTTPException(status_code=400, detail="LaTeX formula cannot be empty")
logger.info("Converting LaTeX to OMML, latex=%r", request.latex)
try: try:
omml = converter.convert_to_omml(request.latex) omml = converter.convert_to_omml(request.latex)
logger.info("LaTeX to OMML conversion successful")
return LatexToOmmlResponse(omml=omml) return LatexToOmmlResponse(omml=omml)
except ValueError as e: except ValueError as e:
logger.warning("LaTeX to OMML conversion invalid input: %s", e)
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except RuntimeError as e: except RuntimeError as e:
logger.error("LaTeX to OMML conversion runtime error: %s", e)
raise HTTPException(status_code=503, detail=str(e)) raise HTTPException(status_code=503, detail=str(e))

View File

@@ -6,10 +6,10 @@ import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi import APIRouter, Depends, HTTPException, Request, Response
from app.core.dependencies import ( from app.core.dependencies import (
get_image_processor,
get_glmocr_endtoend_service, get_glmocr_endtoend_service,
get_image_processor,
) )
from app.core.logging_config import get_logger, RequestIDAdapter from app.core.logging_config import RequestIDAdapter, get_logger
from app.schemas.image import ImageOCRRequest, ImageOCRResponse from app.schemas.image import ImageOCRRequest, ImageOCRResponse
from app.services.image_processor import ImageProcessor from app.services.image_processor import ImageProcessor
from app.services.ocr_service import GLMOCREndToEndService from app.services.ocr_service import GLMOCREndToEndService

View File

@@ -1,10 +1,10 @@
"""Application dependencies.""" """Application dependencies."""
from app.core.config import get_settings
from app.services.converter import Converter
from app.services.image_processor import ImageProcessor from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import GLMOCREndToEndService from app.services.ocr_service import GLMOCREndToEndService
from app.services.converter import Converter
from app.core.config import get_settings
# Global instances (initialized on startup) # Global instances (initialized on startup)
_layout_detector: LayoutDetector | None = None _layout_detector: LayoutDetector | None = None

View File

@@ -2,11 +2,15 @@
import logging import logging
import logging.handlers import logging.handlers
from contextvars import ContextVar
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any
from app.core.config import get_settings from app.core.config import get_settings
# Context variable to hold the current request_id across async boundaries
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")
class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler): class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler):
"""File handler that rotates by both time (daily) and size (100MB).""" """File handler that rotates by both time (daily) and size (100MB)."""
@@ -18,10 +22,10 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
interval: int = 1, interval: int = 1,
backupCount: int = 30, backupCount: int = 30,
maxBytes: int = 100 * 1024 * 1024, # 100MB maxBytes: int = 100 * 1024 * 1024, # 100MB
encoding: Optional[str] = None, encoding: str | None = None,
delay: bool = False, delay: bool = False,
utc: bool = False, utc: bool = False,
atTime: Optional[Any] = None, atTime: Any | None = None,
): ):
"""Initialize handler with both time and size rotation. """Initialize handler with both time and size rotation.
@@ -58,14 +62,14 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
if self.stream is None: if self.stream is None:
self.stream = self._open() self.stream = self._open()
if self.maxBytes > 0: if self.maxBytes > 0:
msg = "%s\n" % self.format(record) msg = f"{self.format(record)}\n"
self.stream.seek(0, 2) # Seek to end self.stream.seek(0, 2) # Seek to end
if self.stream.tell() + len(msg) >= self.maxBytes: if self.stream.tell() + len(msg) >= self.maxBytes:
return True return True
return False return False
def setup_logging(log_dir: Optional[str] = None) -> logging.Logger: def setup_logging(log_dir: str | None = None) -> logging.Logger:
"""Setup application logging with rotation by day and size. """Setup application logging with rotation by day and size.
Args: Args:
@@ -92,14 +96,13 @@ def setup_logging(log_dir: Optional[str] = None) -> logging.Logger:
# Remove existing handlers to avoid duplicates # Remove existing handlers to avoid duplicates
logger.handlers.clear() logger.handlers.clear()
# Create custom formatter that handles missing request_id # Create custom formatter that automatically injects request_id from context
class RequestIDFormatter(logging.Formatter): class RequestIDFormatter(logging.Formatter):
"""Formatter that handles request_id in log records.""" """Formatter that injects request_id from ContextVar into log records."""
def format(self, record): def format(self, record):
# Add request_id if not present
if not hasattr(record, "request_id"): if not hasattr(record, "request_id"):
record.request_id = getattr(record, "request_id", "unknown") record.request_id = request_id_ctx.get()
return super().format(record) return super().format(record)
formatter = RequestIDFormatter( formatter = RequestIDFormatter(
@@ -134,11 +137,11 @@ def setup_logging(log_dir: Optional[str] = None) -> logging.Logger:
# Global logger instance # Global logger instance
_logger: Optional[logging.Logger] = None _logger: logging.Logger | None = None
def get_logger() -> logging.Logger: def get_logger() -> logging.Logger:
"""Get the global logger instance.""" """Get the global logger instance, initializing if needed."""
global _logger global _logger
if _logger is None: if _logger is None:
_logger = setup_logging() _logger = setup_logging()

View File

@@ -8,6 +8,7 @@ from app.api.v1.router import api_router
from app.core.config import get_settings from app.core.config import get_settings
from app.core.dependencies import init_layout_detector from app.core.dependencies import init_layout_detector
from app.core.logging_config import setup_logging from app.core.logging_config import setup_logging
from app.middleware.request_id import RequestIDMiddleware
settings = get_settings() settings = get_settings()
@@ -33,6 +34,8 @@ app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
) )
app.add_middleware(RequestIDMiddleware)
# Include API router # Include API router
app.include_router(api_router, prefix=settings.api_prefix) app.include_router(api_router, prefix=settings.api_prefix)

View File

View File

@@ -0,0 +1,34 @@
"""Middleware to propagate or generate request_id for every request."""
import uuid
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from app.core.logging_config import request_id_ctx
REQUEST_ID_HEADER = "X-Request-ID"
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a request identifier to every request handled by the app.

    The identifier is read from the incoming ``X-Request-ID`` header; when the
    header is absent or empty a fresh UUID4 string is generated instead.

    For the duration of the downstream call the identifier is published via
    the ``request_id_ctx`` context variable, so code running inside the
    request can read it without having it passed in explicitly.

    The identifier is also written to the ``X-Request-ID`` response header so
    that callers can correlate their request with server-side records.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream handler with the request id bound in context.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response with the ``X-Request-ID`` header set.
        """
        incoming_id = request.headers.get(REQUEST_ID_HEADER)
        # Only mint a new id when the client did not supply a usable one.
        current_id = incoming_id if incoming_id else str(uuid.uuid4())

        ctx_token = request_id_ctx.set(current_id)
        try:
            reply = await call_next(request)
        finally:
            # Always restore the previous value, even if the handler raised.
            request_id_ctx.reset(ctx_token)

        reply.headers[REQUEST_ID_HEADER] = current_id
        return reply

View File

@@ -36,4 +36,3 @@ class LatexToOmmlResponse(BaseModel):
"""Response body for LaTeX to OMML conversion endpoint.""" """Response body for LaTeX to OMML conversion endpoint."""
omml: str = Field("", description="OMML (Office Math Markup Language) representation") omml: str = Field("", description="OMML (Office Math Markup Language) representation")

View File

@@ -7,7 +7,9 @@ class LayoutRegion(BaseModel):
"""A detected layout region in the document.""" """A detected layout region in the document."""
type: str = Field(..., description="Region type: text, formula, table, figure") type: str = Field(..., description="Region type: text, formula, table, figure")
native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)") native_label: str = Field(
"", description="Raw label before type mapping (e.g. doc_title, formula_number)"
)
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]") bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
confidence: float = Field(..., description="Detection confidence score") confidence: float = Field(..., description="Detection confidence score")
score: float = Field(..., description="Detection score") score: float = Field(..., description="Detection score")
@@ -41,10 +43,15 @@ class ImageOCRRequest(BaseModel):
class ImageOCRResponse(BaseModel): class ImageOCRResponse(BaseModel):
"""Response body for image OCR endpoint.""" """Response body for image OCR endpoint."""
latex: str = Field("", description="LaTeX representation of the content (empty if mixed content)") latex: str = Field(
"", description="LaTeX representation of the content (empty if mixed content)"
)
markdown: str = Field("", description="Markdown representation of the content") markdown: str = Field("", description="Markdown representation of the content")
mathml: str = Field("", description="Standard MathML representation (empty if mixed content)") mathml: str = Field("", description="Standard MathML representation (empty if mixed content)")
mml: str = Field("", description="XML MathML with mml: namespace prefix (empty if mixed content)") mml: str = Field(
"", description="XML MathML with mml: namespace prefix (empty if mixed content)"
)
layout_info: LayoutInfo = Field(default_factory=LayoutInfo) layout_info: LayoutInfo = Field(default_factory=LayoutInfo)
recognition_mode: str = Field("", description="Recognition mode used: mixed_recognition or formula_recognition") recognition_mode: str = Field(
"", description="Recognition mode used: mixed_recognition or formula_recognition"
)

View File

@@ -112,14 +112,18 @@ class Converter:
# Pre-compiled regex patterns for preprocessing # Pre-compiled regex patterns for preprocessing
_RE_VSPACE = re.compile(r"\\\[1mm\]") _RE_VSPACE = re.compile(r"\\\[1mm\]")
_RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL) _RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL)
_RE_BLOCK_FORMULA_LINE = re.compile(r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL) _RE_BLOCK_FORMULA_LINE = re.compile(
r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL
)
_RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>') _RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>')
_RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)") _RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)")
_RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}") _RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}")
_RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+") _RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+")
_RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}") _RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}")
_RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL) _RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL)
_RE_ALIGNED_BRACE = re.compile(r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL) _RE_ALIGNED_BRACE = re.compile(
r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL
)
_RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL) _RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL)
_RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL) _RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL)
_RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL) _RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL)
@@ -368,7 +372,9 @@ class Converter:
mathml = latex_to_mathml(latex_formula) mathml = latex_to_mathml(latex_formula)
return Converter._postprocess_mathml_for_word(mathml) return Converter._postprocess_mathml_for_word(mathml)
except Exception as e: except Exception as e:
raise RuntimeError(f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}") from e raise RuntimeError(
f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}"
) from e
@staticmethod @staticmethod
def _postprocess_mathml_for_word(mathml: str) -> str: def _postprocess_mathml_for_word(mathml: str) -> str:
@@ -583,7 +589,6 @@ class Converter:
"&#x21D3;": "", # Downarrow "&#x21D3;": "", # Downarrow
"&#x2195;": "", # updownarrow "&#x2195;": "", # updownarrow
"&#x21D5;": "", # Updownarrow "&#x21D5;": "", # Updownarrow
"&#x2260;": "", # ne
"&#x226A;": "", # ll "&#x226A;": "", # ll
"&#x226B;": "", # gg "&#x226B;": "", # gg
"&#x2A7D;": "", # leqslant "&#x2A7D;": "", # leqslant
@@ -962,7 +967,7 @@ class Converter:
"""Export to DOCX format using pypandoc.""" """Export to DOCX format using pypandoc."""
extra_args = [ extra_args = [
"--highlight-style=pygments", "--highlight-style=pygments",
f"--reference-doc=app/pkg/reference.docx", "--reference-doc=app/pkg/reference.docx",
] ]
pypandoc.convert_file( pypandoc.convert_file(
input_path, input_path,

View File

@@ -1,26 +1,10 @@
"""GLM-OCR postprocessing logic adapted for this project.
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
glm-ocr/glmocr/utils/result_postprocess_utils.py.
Covers:
- Repeated-content / hallucination detection
- Per-region content cleaning and formatting (titles, bullets, formulas)
- formula_number merging (→ \\tag{})
- Hyphenated text-block merging (via wordfreq)
- Missing bullet-point detection
"""
from __future__ import annotations from __future__ import annotations
import logging import logging
import re import re
import json
logger = logging.getLogger(__name__)
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple from typing import Any
try: try:
from wordfreq import zipf_frequency from wordfreq import zipf_frequency
@@ -29,13 +13,14 @@ try:
except ImportError: except ImportError:
_WORDFREQ_AVAILABLE = False _WORDFREQ_AVAILABLE = False
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# result_postprocess_utils (ported) # result_postprocess_utils (ported)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]: def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None:
"""Detect and truncate a consecutively-repeated pattern. """Detect and truncate a consecutively-repeated pattern.
Returns the string with the repeat removed, or None if not found. Returns the string with the repeat removed, or None if not found.
@@ -106,7 +91,7 @@ def clean_formula_number(number_content: str) -> str:
# Strip display math delimiters # Strip display math delimiters
for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]: for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]:
if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end): if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end):
s = s[len(start):-len(end)].strip() s = s[len(start) : -len(end)].strip()
break break
# Strip CJK/ASCII parentheses # Strip CJK/ASCII parentheses
if s.startswith("(") and s.endswith(")"): if s.startswith("(") and s.endswith(")"):
@@ -120,8 +105,13 @@ def clean_formula_number(number_content: str) -> str:
# GLMResultFormatter # GLMResultFormatter
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Matches content that consists *entirely* of a display-math block and nothing else.
# Used to detect when a text/heading region was actually recognised as a formula by vLLM,
# so we can correct the label before heading prefixes (## …) are applied.
_PURE_DISPLAY_FORMULA_RE = re.compile(r"^\s*(?:\$\$[\s\S]+?\$\$|\\\[[\s\S]+?\\\])\s*$")
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping) # Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
_LABEL_TO_CATEGORY: Dict[str, str] = { _LABEL_TO_CATEGORY: dict[str, str] = {
# text # text
"abstract": "text", "abstract": "text",
"algorithm": "text", "algorithm": "text",
@@ -157,7 +147,7 @@ class GLMResultFormatter:
# Public entry-point # Public entry-point
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def process(self, regions: List[Dict[str, Any]]) -> str: def process(self, regions: list[dict[str, Any]]) -> str:
"""Run the full postprocessing pipeline and return Markdown. """Run the full postprocessing pipeline and return Markdown.
Args: Args:
@@ -175,11 +165,24 @@ class GLMResultFormatter:
items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0)) items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
# Per-region cleaning + formatting # Per-region cleaning + formatting
processed: List[Dict] = [] processed: list[dict] = []
for item in items: for item in items:
item["native_label"] = item.get("native_label", item.get("label", "text")) item["native_label"] = item.get("native_label", item.get("label", "text"))
item["label"] = self._map_label(item.get("label", "text"), item["native_label"]) item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
# Label correction: layout may say "text" (or a heading like "paragraph_title")
# but vLLM recognised the content as a formula and returned $$…$$. Without
# correction the heading prefix (##) would be prepended to the math block,
# producing broken output like "## $$ \mathbf{y}=… $$".
raw_content = (item.get("content") or "").strip()
if item["label"] == "text" and _PURE_DISPLAY_FORMULA_RE.match(raw_content):
logger.debug(
"Label corrected text (native=%s) → formula: pure display-formula detected",
item["native_label"],
)
item["label"] = "formula"
item["native_label"] = "display_formula"
item["content"] = self._format_content( item["content"] = self._format_content(
item.get("content") or "", item.get("content") or "",
item["label"], item["label"],
@@ -199,7 +202,7 @@ class GLMResultFormatter:
processed = self._format_bullet_points(processed) processed = self._format_bullet_points(processed)
# Assemble Markdown # Assemble Markdown
parts: List[str] = [] parts: list[str] = []
for item in processed: for item in processed:
content = item.get("content") or "" content = item.get("content") or ""
if item["label"] == "image": if item["label"] == "image":
@@ -263,8 +266,10 @@ class GLMResultFormatter:
if label == "formula": if label == "formula":
content = content.strip() content = content.strip()
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]: for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
if content.startswith(s) and content.endswith(e): if content.startswith(s):
content = content[len(s) : -len(e)].strip() content = content[len(s) :].strip()
if content.endswith(e):
content = content[: -len(e)].strip()
break break
if not content: if not content:
logger.warning("Skipping formula region with empty content after stripping delimiters") logger.warning("Skipping formula region with empty content after stripping delimiters")
@@ -296,12 +301,12 @@ class GLMResultFormatter:
# Structural merges # Structural merges
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]: def _merge_formula_numbers(self, items: list[dict]) -> list[dict]:
"""Merge formula_number region into adjacent formula with \\tag{}.""" """Merge formula_number region into adjacent formula with \\tag{}."""
if not items: if not items:
return items return items
merged: List[Dict] = [] merged: list[dict] = []
skip: set = set() skip: set = set()
for i, block in enumerate(items): for i, block in enumerate(items):
@@ -340,12 +345,12 @@ class GLMResultFormatter:
block["index"] = i block["index"] = i
return merged return merged
def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]: def _merge_text_blocks(self, items: list[dict]) -> list[dict]:
"""Merge hyphenated text blocks when the combined word is valid (wordfreq).""" """Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
if not items or not _WORDFREQ_AVAILABLE: if not items or not _WORDFREQ_AVAILABLE:
return items return items
merged: List[Dict] = [] merged: list[dict] = []
skip: set = set() skip: set = set()
for i, block in enumerate(items): for i, block in enumerate(items):
@@ -389,7 +394,7 @@ class GLMResultFormatter:
block["index"] = i block["index"] = i
return merged return merged
def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]: def _format_bullet_points(self, items: list[dict], left_align_threshold: float = 10.0) -> list[dict]:
"""Add missing bullet prefix when a text block is sandwiched between two bullet items.""" """Add missing bullet prefix when a text block is sandwiched between two bullet items."""
if len(items) < 3: if len(items) < 3:
return items return items
@@ -419,10 +424,7 @@ class GLMResultFormatter:
if not (cur_bbox and prev_bbox and nxt_bbox): if not (cur_bbox and prev_bbox and nxt_bbox):
continue continue
if ( if abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold:
abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold
and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold
):
cur["content"] = "- " + cur_content cur["content"] = "- " + cur_content
return items return items

View File

@@ -1,12 +1,11 @@
"""PP-DocLayoutV3 wrapper for document layout detection.""" """PP-DocLayoutV3 wrapper for document layout detection."""
import numpy as np import numpy as np
from app.schemas.image import LayoutInfo, LayoutRegion
from app.core.config import get_settings
from app.services.layout_postprocess import apply_layout_postprocess
from paddleocr import LayoutDetection from paddleocr import LayoutDetection
from typing import Optional
from app.core.config import get_settings
from app.schemas.image import LayoutInfo, LayoutRegion
from app.services.layout_postprocess import apply_layout_postprocess
settings = get_settings() settings = get_settings()
@@ -14,7 +13,7 @@ settings = get_settings()
class LayoutDetector: class LayoutDetector:
"""Layout detector for PP-DocLayoutV2.""" """Layout detector for PP-DocLayoutV2."""
_layout_detector: Optional[LayoutDetection] = None _layout_detector: LayoutDetection | None = None
# PP-DocLayoutV2 class ID to label mapping # PP-DocLayoutV2 class ID to label mapping
CLS_ID_TO_LABEL: dict[int, str] = { CLS_ID_TO_LABEL: dict[int, str] = {
@@ -156,10 +155,11 @@ class LayoutDetector:
if __name__ == "__main__": if __name__ == "__main__":
import cv2 import cv2
from app.core.config import get_settings from app.core.config import get_settings
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter from app.services.converter import Converter
from app.services.ocr_service import OCRService from app.services.image_processor import ImageProcessor
from app.services.ocr_service import GLMOCREndToEndService
settings = get_settings() settings = get_settings()
@@ -169,15 +169,15 @@ if __name__ == "__main__":
converter = Converter() converter = Converter()
# Initialize OCR service # Initialize OCR service
ocr_service = OCRService( ocr_service = GLMOCREndToEndService(
vl_server_url=settings.paddleocr_vl_url, vl_server_url=settings.glm_ocr_url,
layout_detector=layout_detector, layout_detector=layout_detector,
image_processor=image_processor, image_processor=image_processor,
converter=converter, converter=converter,
) )
# Load test image # Load test image
image_path = "test/timeout.jpg" image_path = "test/image2.png"
image = cv2.imread(image_path) image = cv2.imread(image_path)
if image is None: if image is None:

View File

@@ -15,16 +15,14 @@ the quality of the GLM-OCR SDK's layout pipeline.
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Primitive geometry helpers # Primitive geometry helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def iou(box1: List[float], box2: List[float]) -> float:
def iou(box1: list[float], box2: list[float]) -> float:
"""Compute IoU of two bounding boxes [x1, y1, x2, y2].""" """Compute IoU of two bounding boxes [x1, y1, x2, y2]."""
x1, y1, x2, y2 = box1 x1, y1, x2, y2 = box1
x1_p, y1_p, x2_p, y2_p = box2 x1_p, y1_p, x2_p, y2_p = box2
@@ -41,7 +39,7 @@ def iou(box1: List[float], box2: List[float]) -> float:
return inter_area / float(box1_area + box2_area - inter_area) return inter_area / float(box1_area + box2_area - inter_area)
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool: def is_contained(box1: list[float], box2: list[float], overlap_threshold: float = 0.8) -> bool:
"""Return True if box1 is contained within box2 (overlap ratio >= threshold). """Return True if box1 is contained within box2 (overlap ratio >= threshold).
box format: [cls_id, score, x1, y1, x2, y2] box format: [cls_id, score, x1, y1, x2, y2]
@@ -66,11 +64,12 @@ def is_contained(box1: List[float], box2: List[float], overlap_threshold: float
# NMS # NMS
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def nms( def nms(
boxes: np.ndarray, boxes: np.ndarray,
iou_same: float = 0.6, iou_same: float = 0.6,
iou_diff: float = 0.98, iou_diff: float = 0.98,
) -> List[int]: ) -> list[int]:
"""NMS with separate IoU thresholds for same-class and cross-class overlaps. """NMS with separate IoU thresholds for same-class and cross-class overlaps.
Args: Args:
@@ -83,7 +82,7 @@ def nms(
""" """
scores = boxes[:, 1] scores = boxes[:, 1]
indices = np.argsort(scores)[::-1].tolist() indices = np.argsort(scores)[::-1].tolist()
selected: List[int] = [] selected: list[int] = []
while indices: while indices:
current = indices[0] current = indices[0]
@@ -114,10 +113,10 @@ _PRESERVE_LABELS = {"image", "seal", "chart"}
def check_containment( def check_containment(
boxes: np.ndarray, boxes: np.ndarray,
preserve_cls_ids: Optional[set] = None, preserve_cls_ids: set | None = None,
category_index: Optional[int] = None, category_index: int | None = None,
mode: Optional[str] = None, mode: str | None = None,
) -> Tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
"""Compute containment flags for each box. """Compute containment flags for each box.
Args: Args:
@@ -160,9 +159,10 @@ def check_containment(
# Box expansion (unclip) # Box expansion (unclip)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def unclip_boxes( def unclip_boxes(
boxes: np.ndarray, boxes: np.ndarray,
unclip_ratio: Union[float, Tuple[float, float], Dict, List, None], unclip_ratio: float | tuple[float, float] | dict | list | None,
) -> np.ndarray: ) -> np.ndarray:
"""Expand bounding boxes by the given ratio. """Expand bounding boxes by the given ratio.
@@ -215,13 +215,14 @@ def unclip_boxes(
# Main entry-point # Main entry-point
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def apply_layout_postprocess( def apply_layout_postprocess(
boxes: List[Dict], boxes: list[dict],
img_size: Tuple[int, int], img_size: tuple[int, int],
layout_nms: bool = True, layout_nms: bool = True,
layout_unclip_ratio: Union[float, Tuple, Dict, None] = None, layout_unclip_ratio: float | tuple | dict | None = None,
layout_merge_bboxes_mode: Union[str, Dict, None] = "large", layout_merge_bboxes_mode: str | dict | None = "large",
) -> List[Dict]: ) -> list[dict]:
"""Apply GLM-OCR layout post-processing to PaddleOCR detection results. """Apply GLM-OCR layout post-processing to PaddleOCR detection results.
Args: Args:
@@ -250,7 +251,7 @@ def apply_layout_postprocess(
arr_rows.append([cls_id, score, x1, y1, x2, y2]) arr_rows.append([cls_id, score, x1, y1, x2, y2])
boxes_array = np.array(arr_rows, dtype=float) boxes_array = np.array(arr_rows, dtype=float)
all_labels: List[str] = [b.get("label", "") for b in boxes] all_labels: list[str] = [b.get("label", "") for b in boxes]
# 1. NMS ---------------------------------------------------------------- # # 1. NMS ---------------------------------------------------------------- #
if layout_nms and len(boxes_array) > 1: if layout_nms and len(boxes_array) > 1:
@@ -262,17 +263,14 @@ def apply_layout_postprocess(
if len(boxes_array) > 1: if len(boxes_array) > 1:
img_area = img_width * img_height img_area = img_width * img_height
area_thres = 0.82 if img_width > img_height else 0.93 area_thres = 0.82 if img_width > img_height else 0.93
image_cls_ids = {
int(boxes_array[i, 0])
for i, lbl in enumerate(all_labels)
if lbl == "image"
}
keep_mask = np.ones(len(boxes_array), dtype=bool) keep_mask = np.ones(len(boxes_array), dtype=bool)
for i, lbl in enumerate(all_labels): for i, lbl in enumerate(all_labels):
if lbl == "image": if lbl == "image":
x1, y1, x2, y2 = boxes_array[i, 2:6] x1, y1, x2, y2 = boxes_array[i, 2:6]
x1 = max(0.0, x1); y1 = max(0.0, y1) x1 = max(0.0, x1)
x2 = min(float(img_width), x2); y2 = min(float(img_height), y2) y1 = max(0.0, y1)
x2 = min(float(img_width), x2)
y2 = min(float(img_height), y2)
if (x2 - x1) * (y2 - y1) > area_thres * img_area: if (x2 - x1) * (y2 - y1) > area_thres * img_area:
keep_mask[i] = False keep_mask[i] = False
boxes_array = boxes_array[keep_mask] boxes_array = boxes_array[keep_mask]
@@ -281,9 +279,7 @@ def apply_layout_postprocess(
# 3. Containment analysis (merge_bboxes_mode) -------------------------- # # 3. Containment analysis (merge_bboxes_mode) -------------------------- #
if layout_merge_bboxes_mode and len(boxes_array) > 1: if layout_merge_bboxes_mode and len(boxes_array) > 1:
preserve_cls_ids = { preserve_cls_ids = {
int(boxes_array[i, 0]) int(boxes_array[i, 0]) for i, lbl in enumerate(all_labels) if lbl in _PRESERVE_LABELS
for i, lbl in enumerate(all_labels)
if lbl in _PRESERVE_LABELS
} }
if isinstance(layout_merge_bboxes_mode, str): if isinstance(layout_merge_bboxes_mode, str):
@@ -321,7 +317,7 @@ def apply_layout_postprocess(
boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio) boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)
# 5. Clamp to image boundaries + skip invalid -------------------------- # # 5. Clamp to image boundaries + skip invalid -------------------------- #
result: List[Dict] = [] result: list[dict] = []
for i, row in enumerate(boxes_array): for i, row in enumerate(boxes_array):
cls_id = int(row[0]) cls_id = int(row[0])
score = float(row[1]) score = float(row[1])
@@ -333,11 +329,13 @@ def apply_layout_postprocess(
if x1 >= x2 or y1 >= y2: if x1 >= x2 or y1 >= y2:
continue continue
result.append({ result.append(
"cls_id": cls_id, {
"label": all_labels[i], "cls_id": cls_id,
"score": score, "label": all_labels[i],
"coordinate": [int(x1), int(y1), int(x2), int(y2)], "score": score,
}) "coordinate": [int(x1), int(y1), int(x2), int(y2)],
}
)
return result return result

View File

@@ -150,9 +150,7 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
# Strategy: remove spaces before \ and between non-command chars, # Strategy: remove spaces before \ and between non-command chars,
# but preserve the space after \command when followed by a non-\ char # but preserve the space after \command when followed by a non-\ char
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
cleaned = re.sub( cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
) # remove space after non-letter non-\
return f"{operator}{{{cleaned}}}" return f"{operator}{{{cleaned}}}"
# Match _{ ... } or ^{ ... } # Match _{ ... } or ^{ ... }
@@ -630,9 +628,7 @@ class MineruOCRService(OCRServiceBase):
self.glm_ocr_url = glm_ocr_url self.glm_ocr_url = glm_ocr_url
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600) self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
def _recognize_formula_with_paddleocr_vl( def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
self, image: np.ndarray, prompt: str = "Formula Recognition:"
) -> str:
"""Recognize formula using PaddleOCR-VL API. """Recognize formula using PaddleOCR-VL API.
Args: Args:
@@ -673,9 +669,7 @@ class MineruOCRService(OCRServiceBase):
except Exception as e: except Exception as e:
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
def _extract_and_recognize_formulas( def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
self, markdown_content: str, original_image: np.ndarray
) -> str:
"""Extract image references from markdown and recognize formulas. """Extract image references from markdown and recognize formulas.
Args: Args:
@@ -757,9 +751,7 @@ class MineruOCRService(OCRServiceBase):
markdown_content = result["results"]["image"].get("md_content", "") markdown_content = result["results"]["image"].get("md_content", "")
if "![](images/" in markdown_content: if "![](images/" in markdown_content:
markdown_content = self._extract_and_recognize_formulas( markdown_content = self._extract_and_recognize_formulas(markdown_content, original_image)
markdown_content, original_image
)
# Apply postprocessing to fix OCR errors # Apply postprocessing to fix OCR errors
markdown_content = _postprocess_markdown(markdown_content) markdown_content = _postprocess_markdown(markdown_content)
@@ -789,15 +781,11 @@ class MineruOCRService(OCRServiceBase):
# Task-specific prompts (from GLM-OCR SDK config.yaml) # Task-specific prompts (from GLM-OCR SDK config.yaml)
_TASK_PROMPTS: dict[str, str] = { _TASK_PROMPTS: dict[str, str] = {
"text": "Text Recognition:", "text": "Text Recognition. If the content is a formula, please ouput latex code, else output text",
"formula": "Formula Recognition:", "formula": "Formula Recognition:",
"table": "Table Recognition:", "table": "Table Recognition:",
} }
_DEFAULT_PROMPT = ( _DEFAULT_PROMPT = "Text Recognition. If the content is a formula, please ouput latex code, else output text"
"Recognize the text in the image and output in Markdown format. "
"Preserve the original layout (headings/paragraphs/tables/formulas). "
"Do not fabricate content that does not exist in the image."
)
class GLMOCREndToEndService(OCRServiceBase): class GLMOCREndToEndService(OCRServiceBase):
@@ -878,12 +866,9 @@ class GLMOCREndToEndService(OCRServiceBase):
Returns: Returns:
Dict with 'markdown', 'latex', 'mathml', 'mml' keys. Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
""" """
# 1. Padding # 1. Layout detection
padded = self.image_processor.add_padding(image) img_h, img_w = image.shape[:2]
img_h, img_w = padded.shape[:2] layout_info = self.layout_detector.detect(image)
# 2. Layout detection
layout_info = self.layout_detector.detect(padded)
# Sort regions in reading order: top-to-bottom, left-to-right # Sort regions in reading order: top-to-bottom, left-to-right
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0])) layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
@@ -892,7 +877,7 @@ class GLMOCREndToEndService(OCRServiceBase):
if not layout_info.regions: if not layout_info.regions:
# No layout detected → assume it's a formula, use formula recognition # No layout detected → assume it's a formula, use formula recognition
logger.info("No layout regions detected, treating image as formula") logger.info("No layout regions detected, treating image as formula")
raw_content = self._call_vllm(padded, _TASK_PROMPTS["formula"]) raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"])
# Format as display formula markdown # Format as display formula markdown
formatted_content = raw_content.strip() formatted_content = raw_content.strip()
if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")): if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")):
@@ -905,7 +890,7 @@ class GLMOCREndToEndService(OCRServiceBase):
if region.type == "figure": if region.type == "figure":
continue continue
x1, y1, x2, y2 = (int(c) for c in region.bbox) x1, y1, x2, y2 = (int(c) for c in region.bbox)
cropped = padded[y1:y2, x1:x2] cropped = image[y1:y2, x1:x2]
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10: if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
logger.warning( logger.warning(
"Skipping region idx=%d (label=%s): crop too small %s", "Skipping region idx=%d (label=%s): crop too small %s",
@@ -918,16 +903,13 @@ class GLMOCREndToEndService(OCRServiceBase):
tasks.append((idx, region, cropped, prompt)) tasks.append((idx, region, cropped, prompt))
if not tasks: if not tasks:
raw_content = self._call_vllm(padded, _DEFAULT_PROMPT) raw_content = self._call_vllm(image, _DEFAULT_PROMPT)
markdown_content = self._formatter._clean_content(raw_content) markdown_content = self._formatter._clean_content(raw_content)
else: else:
# Parallel OCR calls # Parallel OCR calls
raw_results: dict[int, str] = {} raw_results: dict[int, str] = {}
with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex: with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex:
future_map = { future_map = {ex.submit(self._call_vllm, cropped, prompt): idx for idx, region, cropped, prompt in tasks}
ex.submit(self._call_vllm, cropped, prompt): idx
for idx, region, cropped, prompt in tasks
}
for future in as_completed(future_map): for future in as_completed(future_map):
idx = future_map[future] idx = future_map[future]
try: try:
@@ -965,17 +947,3 @@ class GLMOCREndToEndService(OCRServiceBase):
logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e) logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e)
return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml} return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
if __name__ == "__main__":
mineru_service = MineruOCRService()
image = cv2.imread("test/formula2.jpg")
image_numpy = np.array(image)
# Encode image to bytes (as done in API layer)
success, encoded_image = cv2.imencode(".png", image_numpy)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
image_bytes.seek(0)
ocr_result = mineru_service.recognize(image_bytes)
print(ocr_result)

2844
nohup.out Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,4 @@
import numpy as np import numpy as np
import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -35,7 +34,9 @@ def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
client = _build_client() client = _build_client()
missing = client.post("/ocr", json={}) missing = client.post("/ocr", json={})
both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}) both = client.post(
"/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}
)
assert missing.status_code == 422 assert missing.status_code == 422
assert both.status_code == 422 assert both.status_code == 422

View File

@@ -57,12 +57,22 @@ def test_merge_formula_numbers_merges_before_and_after_formula():
before = formatter._merge_formula_numbers( before = formatter._merge_formula_numbers(
[ [
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"}, {"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
{"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"}, {
"index": 1,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y\n$$",
},
] ]
) )
after = formatter._merge_formula_numbers( after = formatter._merge_formula_numbers(
[ [
{"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"}, {
"index": 0,
"label": "formula",
"native_label": "display_formula",
"content": "$$\nx+y\n$$",
},
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"}, {"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
] ]
) )

View File

@@ -23,7 +23,9 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
calls = {} calls = {}
def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode): def fake_apply_layout_postprocess(
boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode
):
calls["args"] = { calls["args"] = {
"boxes": boxes, "boxes": boxes,
"img_size": img_size, "img_size": img_size,
@@ -33,7 +35,9 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
} }
return [boxes[0], boxes[2]] return [boxes[0], boxes[2]]
monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess) monkeypatch.setattr(
"app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess
)
image = np.zeros((200, 100, 3), dtype=np.uint8) image = np.zeros((200, 100, 3), dtype=np.uint8)
info = detector.detect(image) info = detector.detect(image)

View File

@@ -146,6 +146,4 @@ def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image()
layout_merge_bboxes_mode=None, layout_merge_bboxes_mode=None,
) )
assert result == [ assert result == [{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}]
{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
]

View File

@@ -46,7 +46,9 @@ def test_encode_region_returns_decodable_base64_jpeg():
image[:, :] = [0, 128, 255] image[:, :] = [0, 128, 255]
encoded = service._encode_region(image) encoded = service._encode_region(image)
decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR) decoded = cv2.imdecode(
np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR
)
assert decoded.shape[:2] == image.shape[:2] assert decoded.shape[:2] == image.shape[:2]
@@ -71,7 +73,9 @@ def test_call_vllm_builds_messages_and_returns_content():
assert captured["model"] == "glm-ocr" assert captured["model"] == "glm-ocr"
assert captured["max_tokens"] == 1024 assert captured["max_tokens"] == 1024
assert captured["messages"][0]["content"][0]["type"] == "image_url" assert captured["messages"][0]["content"][0]["type"] == "image_url"
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,") assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith(
"data:image/jpeg;base64,"
)
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"} assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
@@ -98,9 +102,19 @@ def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch): def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
regions = [ regions = [
LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9), LayoutRegion(
LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8), type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9
LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95), ),
LayoutRegion(
type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8
),
LayoutRegion(
type="formula",
native_label="display_formula",
bbox=[20, 20, 40, 40],
confidence=0.95,
score=0.95,
),
] ]
service = _build_service(regions=regions) service = _build_service(regions=regions)
image = np.zeros((40, 40, 3), dtype=np.uint8) image = np.zeros((40, 40, 3), dtype=np.uint8)

35
tests/tools/layout.py Normal file
View File

@@ -0,0 +1,35 @@
import cv2
from app.core.config import get_settings
from app.services.layout_detector import LayoutDetector
settings = get_settings()
def debug_layout_detector(
    image_path: str = "test/image2.png",
    output_path: str = "test/layout_debug.png",
) -> None:
    """Run layout detection on one image and save an annotated copy for inspection.

    Every detected region is outlined with a rectangle and tagged with its
    ``native_label`` at the region's top-left corner. The raw image is passed
    to the detector as loaded (no padding is applied).

    Args:
        image_path: Image file to load with ``cv2.imread``.
        output_path: Where the annotated image is written with ``cv2.imwrite``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` for ``image_path``.
        OSError: If ``cv2.imwrite`` reports that the output was not written.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising, so fail here with a clear message rather than on `image.shape`.
    if image is None:
        raise FileNotFoundError(f"Failed to load image: {image_path}")
    print(f"Image shape: {image.shape}")

    layout_detector = LayoutDetector()
    layout_info = layout_detector.detect(image)

    # Draw each region's box and label directly onto the loaded image.
    color = (0, 0, 255)
    for region in layout_info.regions:
        x1, y1, x2, y2 = (int(c) for c in region.bbox)
        cv2.putText(
            image,
            region.native_label,
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            2,
        )
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

    # cv2.imwrite returns False when the file could not be written.
    if not cv2.imwrite(output_path, image):
        raise OSError(f"Failed to write debug image: {output_path}")
if __name__ == "__main__":
debug_layout_detector()