Compare commits
9 Commits
d98fa7237c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39e72a5743 | ||
| aee1a1bf3b | |||
| ff82021467 | |||
|
|
11e9ed780d | ||
|
|
d1050acbdc | ||
| 16399f0929 | |||
|
|
92b56d61d8 | ||
| bb1cf66137 | |||
| a9d3a35dd7 |
@@ -8,7 +8,8 @@
|
|||||||
"WebFetch(domain:raw.githubusercontent.com)",
|
"WebFetch(domain:raw.githubusercontent.com)",
|
||||||
"Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")",
|
"Bash(python -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(md\\)\n\" 2>&1 | grep -v \"^$\")",
|
||||||
"Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)",
|
"Bash(python3 -c \"\nfrom app.services.glm_postprocess import GLMResultFormatter, clean_repeated_content, clean_formula_number\nf = GLMResultFormatter\\(\\)\nprint\\('GLMResultFormatter OK'\\)\nprint\\('clean_formula_number:', clean_formula_number\\('\\(2.1\\)'\\)\\)\nregions = [\n {'index': 0, 'label': 'text', 'native_label': 'doc_title', 'content': 'Introduction', 'bbox_2d': [10,10,990,50]},\n {'index': 1, 'label': 'formula', 'native_label': 'display_formula', 'content': r'\\\\frac{a}{b}', 'bbox_2d': [10,60,990,200]},\n {'index': 2, 'label': 'text', 'native_label': 'formula_number', 'content': '\\(1\\)', 'bbox_2d': [900,60,990,200]},\n]\nmd = f.process\\(regions\\)\nprint\\('process output:'\\)\nprint\\(repr\\(md\\)\\)\n\" 2>&1)",
|
||||||
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)"
|
"Bash(ls .venv 2>/dev/null || ls venv 2>/dev/null || echo \"no venv found\" && find . -name \"activate\" -path \"*/bin/activate\" 2>/dev/null | head -3)",
|
||||||
|
"Bash(ruff check:*)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
"""Format conversion endpoints."""
|
"""Format conversion endpoints."""
|
||||||
|
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
|
|
||||||
from app.core.dependencies import get_converter
|
from app.core.dependencies import get_converter
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest
|
from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest
|
||||||
from app.services.converter import Converter
|
from app.services.converter import Converter
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@@ -19,14 +24,26 @@ async def convert_markdown_to_docx(
|
|||||||
|
|
||||||
Returns the generated DOCX file as a binary response.
|
Returns the generated DOCX file as a binary response.
|
||||||
"""
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Converting markdown to DOCX, filename=%s, content_length=%d",
|
||||||
|
request.filename,
|
||||||
|
len(request.markdown),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
docx_bytes = converter.export_to_file(request.markdown, export_type="docx")
|
docx_bytes = converter.export_to_file(request.markdown, export_type="docx")
|
||||||
|
logger.info(
|
||||||
|
"DOCX conversion successful, filename=%s, size=%d bytes",
|
||||||
|
request.filename,
|
||||||
|
len(docx_bytes),
|
||||||
|
)
|
||||||
|
encoded_name = quote(f"{request.filename}.docx")
|
||||||
return Response(
|
return Response(
|
||||||
content=docx_bytes,
|
content=docx_bytes,
|
||||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
headers={"Content-Disposition": f'attachment; filename="{request.filename}.docx"'},
|
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.exception("DOCX conversion failed, filename=%s: %s", request.filename, e)
|
||||||
raise HTTPException(status_code=500, detail=f"Conversion failed: {e}")
|
raise HTTPException(status_code=500, detail=f"Conversion failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
@@ -55,12 +72,17 @@ async def convert_latex_to_omml(
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if not request.latex or not request.latex.strip():
|
if not request.latex or not request.latex.strip():
|
||||||
|
logger.warning("LaTeX to OMML request received with empty formula")
|
||||||
raise HTTPException(status_code=400, detail="LaTeX formula cannot be empty")
|
raise HTTPException(status_code=400, detail="LaTeX formula cannot be empty")
|
||||||
|
|
||||||
|
logger.info("Converting LaTeX to OMML, latex=%r", request.latex)
|
||||||
try:
|
try:
|
||||||
omml = converter.convert_to_omml(request.latex)
|
omml = converter.convert_to_omml(request.latex)
|
||||||
|
logger.info("LaTeX to OMML conversion successful")
|
||||||
return LatexToOmmlResponse(omml=omml)
|
return LatexToOmmlResponse(omml=omml)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
logger.warning("LaTeX to OMML conversion invalid input: %s", e)
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
logger.error("LaTeX to OMML conversion runtime error: %s", e)
|
||||||
raise HTTPException(status_code=503, detail=str(e))
|
raise HTTPException(status_code=503, detail=str(e))
|
||||||
|
|||||||
@@ -50,9 +50,7 @@ class Settings(BaseSettings):
|
|||||||
max_tokens: int = 4096
|
max_tokens: int = 4096
|
||||||
|
|
||||||
# Model Paths
|
# Model Paths
|
||||||
pp_doclayout_model_dir: str | None = (
|
pp_doclayout_model_dir: str | None = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
|
||||||
"/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV3"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Image Processing
|
# Image Processing
|
||||||
max_image_size_mb: int = 10
|
max_image_size_mb: int = 10
|
||||||
|
|||||||
@@ -2,11 +2,15 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
|
from contextvars import ContextVar
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
# Context variable to hold the current request_id across async boundaries
|
||||||
|
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")
|
||||||
|
|
||||||
|
|
||||||
class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler):
|
class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler):
|
||||||
"""File handler that rotates by both time (daily) and size (100MB)."""
|
"""File handler that rotates by both time (daily) and size (100MB)."""
|
||||||
@@ -92,14 +96,13 @@ def setup_logging(log_dir: str | None = None) -> logging.Logger:
|
|||||||
# Remove existing handlers to avoid duplicates
|
# Remove existing handlers to avoid duplicates
|
||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
|
|
||||||
# Create custom formatter that handles missing request_id
|
# Create custom formatter that automatically injects request_id from context
|
||||||
class RequestIDFormatter(logging.Formatter):
|
class RequestIDFormatter(logging.Formatter):
|
||||||
"""Formatter that handles request_id in log records."""
|
"""Formatter that injects request_id from ContextVar into log records."""
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
# Add request_id if not present
|
|
||||||
if not hasattr(record, "request_id"):
|
if not hasattr(record, "request_id"):
|
||||||
record.request_id = getattr(record, "request_id", "unknown")
|
record.request_id = request_id_ctx.get()
|
||||||
return super().format(record)
|
return super().format(record)
|
||||||
|
|
||||||
formatter = RequestIDFormatter(
|
formatter = RequestIDFormatter(
|
||||||
@@ -138,7 +141,7 @@ _logger: logging.Logger | None = None
|
|||||||
|
|
||||||
|
|
||||||
def get_logger() -> logging.Logger:
|
def get_logger() -> logging.Logger:
|
||||||
"""Get the global logger instance."""
|
"""Get the global logger instance, initializing if needed."""
|
||||||
global _logger
|
global _logger
|
||||||
if _logger is None:
|
if _logger is None:
|
||||||
_logger = setup_logging()
|
_logger = setup_logging()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.api.v1.router import api_router
|
|||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.core.dependencies import init_layout_detector
|
from app.core.dependencies import init_layout_detector
|
||||||
from app.core.logging_config import setup_logging
|
from app.core.logging_config import setup_logging
|
||||||
|
from app.middleware.request_id import RequestIDMiddleware
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
@@ -33,6 +34,8 @@ app = FastAPI(
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.add_middleware(RequestIDMiddleware)
|
||||||
|
|
||||||
# Include API router
|
# Include API router
|
||||||
app.include_router(api_router, prefix=settings.api_prefix)
|
app.include_router(api_router, prefix=settings.api_prefix)
|
||||||
|
|
||||||
|
|||||||
0
app/middleware/__init__.py
Normal file
0
app/middleware/__init__.py
Normal file
34
app/middleware/request_id.py
Normal file
34
app/middleware/request_id.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Middleware to propagate or generate request_id for every request."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
from app.core.logging_config import request_id_ctx
|
||||||
|
|
||||||
|
REQUEST_ID_HEADER = "X-Request-ID"
|
||||||
|
|
||||||
|
|
||||||
|
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a request ID to every request handled by the application.

    The ID is taken from the incoming ``X-Request-ID`` header when that header
    is present and non-empty; otherwise a fresh UUID4 string is generated.

    While the downstream handler runs, the ID is published through the
    ``request_id_ctx`` context variable, so code that reads that variable can
    pick it up without the ID being passed through every call.

    The ID is also written to the ``X-Request-ID`` response header, letting
    callers match their request against server-side records.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Bind the request ID for the duration of the downstream call.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request to the next handler
                and returns its response.

        Returns:
            The downstream response with the ``X-Request-ID`` header set.
        """
        incoming_id = request.headers.get(REQUEST_ID_HEADER)
        # A missing or empty header falls back to a newly generated UUID4.
        request_id = incoming_id if incoming_id else str(uuid.uuid4())

        ctx_token = request_id_ctx.set(request_id)
        try:
            response = await call_next(request)
        finally:
            # Restore the previous value even when the handler raises, so the
            # ID does not stay bound once this request is finished.
            request_id_ctx.reset(ctx_token)

        response.headers[REQUEST_ID_HEADER] = request_id
        return response
|
||||||
@@ -34,13 +34,7 @@ def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 1
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
r"(.{"
|
r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}",
|
||||||
+ str(min_unit_len)
|
|
||||||
+ ","
|
|
||||||
+ str(max_unit_len)
|
|
||||||
+ r"}?)\1{"
|
|
||||||
+ str(min_repeats - 1)
|
|
||||||
+ ",}",
|
|
||||||
re.DOTALL,
|
re.DOTALL,
|
||||||
)
|
)
|
||||||
match = pattern.search(s)
|
match = pattern.search(s)
|
||||||
@@ -74,9 +68,7 @@ def clean_repeated_content(
|
|||||||
if count >= line_threshold and (count / total_lines) >= 0.8:
|
if count >= line_threshold and (count / total_lines) >= 0.8:
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
if line == common:
|
if line == common:
|
||||||
consecutive = sum(
|
consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common)
|
||||||
1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common
|
|
||||||
)
|
|
||||||
if consecutive >= 3:
|
if consecutive >= 3:
|
||||||
original_lines = content.split("\n")
|
original_lines = content.split("\n")
|
||||||
non_empty_count = 0
|
non_empty_count = 0
|
||||||
@@ -113,6 +105,11 @@ def clean_formula_number(number_content: str) -> str:
|
|||||||
# GLMResultFormatter
|
# GLMResultFormatter
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Matches content that consists *entirely* of a display-math block and nothing else.
|
||||||
|
# Used to detect when a text/heading region was actually recognised as a formula by vLLM,
|
||||||
|
# so we can correct the label before heading prefixes (## …) are applied.
|
||||||
|
_PURE_DISPLAY_FORMULA_RE = re.compile(r"^\s*(?:\$\$[\s\S]+?\$\$|\\\[[\s\S]+?\\\])\s*$")
|
||||||
|
|
||||||
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
||||||
_LABEL_TO_CATEGORY: dict[str, str] = {
|
_LABEL_TO_CATEGORY: dict[str, str] = {
|
||||||
# text
|
# text
|
||||||
@@ -173,6 +170,19 @@ class GLMResultFormatter:
|
|||||||
item["native_label"] = item.get("native_label", item.get("label", "text"))
|
item["native_label"] = item.get("native_label", item.get("label", "text"))
|
||||||
item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
|
item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
|
||||||
|
|
||||||
|
# Label correction: layout may say "text" (or a heading like "paragraph_title")
|
||||||
|
# but vLLM recognised the content as a formula and returned $$…$$. Without
|
||||||
|
# correction the heading prefix (##) would be prepended to the math block,
|
||||||
|
# producing broken output like "## $$ \mathbf{y}=… $$".
|
||||||
|
raw_content = (item.get("content") or "").strip()
|
||||||
|
if item["label"] == "text" and _PURE_DISPLAY_FORMULA_RE.match(raw_content):
|
||||||
|
logger.debug(
|
||||||
|
"Label corrected text (native=%s) → formula: pure display-formula detected",
|
||||||
|
item["native_label"],
|
||||||
|
)
|
||||||
|
item["label"] = "formula"
|
||||||
|
item["native_label"] = "display_formula"
|
||||||
|
|
||||||
item["content"] = self._format_content(
|
item["content"] = self._format_content(
|
||||||
item.get("content") or "",
|
item.get("content") or "",
|
||||||
item["label"],
|
item["label"],
|
||||||
@@ -255,16 +265,14 @@ class GLMResultFormatter:
|
|||||||
# Formula wrapping
|
# Formula wrapping
|
||||||
if label == "formula":
|
if label == "formula":
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
|
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)"), ("$", "$")]:
|
||||||
if content.startswith(s):
|
if content.startswith(s):
|
||||||
content = content[len(s) :].strip()
|
content = content[len(s) :].strip()
|
||||||
if content.endswith(e):
|
if content.endswith(e):
|
||||||
content = content[: -len(e)].strip()
|
content = content[: -len(e)].strip()
|
||||||
break
|
break
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(
|
logger.warning("Skipping formula region with empty content after stripping delimiters")
|
||||||
"Skipping formula region with empty content after stripping delimiters"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
content = "$$\n" + content + "\n$$"
|
content = "$$\n" + content + "\n$$"
|
||||||
|
|
||||||
@@ -314,9 +322,7 @@ class GLMResultFormatter:
|
|||||||
formula_content = items[i + 1].get("content", "")
|
formula_content = items[i + 1].get("content", "")
|
||||||
merged_block = deepcopy(items[i + 1])
|
merged_block = deepcopy(items[i + 1])
|
||||||
if formula_content.endswith("\n$$"):
|
if formula_content.endswith("\n$$"):
|
||||||
merged_block["content"] = (
|
merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||||
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
|
||||||
)
|
|
||||||
merged.append(merged_block)
|
merged.append(merged_block)
|
||||||
skip.add(i + 1)
|
skip.add(i + 1)
|
||||||
continue # always skip the formula_number block itself
|
continue # always skip the formula_number block itself
|
||||||
@@ -328,9 +334,7 @@ class GLMResultFormatter:
|
|||||||
formula_content = block.get("content", "")
|
formula_content = block.get("content", "")
|
||||||
merged_block = deepcopy(block)
|
merged_block = deepcopy(block)
|
||||||
if formula_content.endswith("\n$$"):
|
if formula_content.endswith("\n$$"):
|
||||||
merged_block["content"] = (
|
merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||||
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
|
||||||
)
|
|
||||||
merged.append(merged_block)
|
merged.append(merged_block)
|
||||||
skip.add(i + 1)
|
skip.add(i + 1)
|
||||||
continue
|
continue
|
||||||
@@ -390,9 +394,7 @@ class GLMResultFormatter:
|
|||||||
block["index"] = i
|
block["index"] = i
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
def _format_bullet_points(
|
def _format_bullet_points(self, items: list[dict], left_align_threshold: float = 10.0) -> list[dict]:
|
||||||
self, items: list[dict], left_align_threshold: float = 10.0
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
|
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
|
||||||
if len(items) < 3:
|
if len(items) < 3:
|
||||||
return items
|
return items
|
||||||
@@ -422,10 +424,7 @@ class GLMResultFormatter:
|
|||||||
if not (cur_bbox and prev_bbox and nxt_bbox):
|
if not (cur_bbox and prev_bbox and nxt_bbox):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold:
|
||||||
abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold
|
|
||||||
and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold
|
|
||||||
):
|
|
||||||
cur["content"] = "- " + cur_content
|
cur["content"] = "- " + cur_content
|
||||||
|
|
||||||
return items
|
return items
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ class LayoutDetector:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
mixed_recognition = any(region.type == "text" and region.score > 0.3 for region in regions)
|
mixed_recognition = any(region.type == "text" and region.score > 0.85 for region in regions)
|
||||||
|
|
||||||
return LayoutInfo(regions=regions, MixedRecognition=mixed_recognition)
|
return LayoutInfo(regions=regions, MixedRecognition=mixed_recognition)
|
||||||
|
|
||||||
|
|||||||
@@ -150,9 +150,7 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
|
|||||||
# Strategy: remove spaces before \ and between non-command chars,
|
# Strategy: remove spaces before \ and between non-command chars,
|
||||||
# but preserve the space after \command when followed by a non-\ char
|
# but preserve the space after \command when followed by a non-\ char
|
||||||
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
|
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
|
||||||
cleaned = re.sub(
|
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
|
||||||
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
|
|
||||||
) # remove space after non-letter non-\
|
|
||||||
return f"{operator}{{{cleaned}}}"
|
return f"{operator}{{{cleaned}}}"
|
||||||
|
|
||||||
# Match _{ ... } or ^{ ... }
|
# Match _{ ... } or ^{ ... }
|
||||||
@@ -630,9 +628,7 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
self.glm_ocr_url = glm_ocr_url
|
self.glm_ocr_url = glm_ocr_url
|
||||||
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
|
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
|
||||||
|
|
||||||
def _recognize_formula_with_paddleocr_vl(
|
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
|
||||||
self, image: np.ndarray, prompt: str = "Formula Recognition:"
|
|
||||||
) -> str:
|
|
||||||
"""Recognize formula using PaddleOCR-VL API.
|
"""Recognize formula using PaddleOCR-VL API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -673,9 +669,7 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
|
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
|
||||||
|
|
||||||
def _extract_and_recognize_formulas(
|
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
|
||||||
self, markdown_content: str, original_image: np.ndarray
|
|
||||||
) -> str:
|
|
||||||
"""Extract image references from markdown and recognize formulas.
|
"""Extract image references from markdown and recognize formulas.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -757,9 +751,7 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
markdown_content = result["results"]["image"].get("md_content", "")
|
markdown_content = result["results"]["image"].get("md_content", "")
|
||||||
|
|
||||||
if "
|
||||||
markdown_content, original_image
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply postprocessing to fix OCR errors
|
# Apply postprocessing to fix OCR errors
|
||||||
markdown_content = _postprocess_markdown(markdown_content)
|
markdown_content = _postprocess_markdown(markdown_content)
|
||||||
@@ -789,15 +781,11 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
|
|
||||||
# Task-specific prompts (from GLM-OCR SDK config.yaml)
|
# Task-specific prompts (from GLM-OCR SDK config.yaml)
|
||||||
_TASK_PROMPTS: dict[str, str] = {
|
_TASK_PROMPTS: dict[str, str] = {
|
||||||
"text": "Text Recognition:",
|
"text": "Text Recognition. If the content is a formula, please output display latex code, else output text",
|
||||||
"formula": "Formula Recognition:",
|
"formula": "Formula Recognition:",
|
||||||
"table": "Table Recognition:",
|
"table": "Table Recognition:",
|
||||||
}
|
}
|
||||||
_DEFAULT_PROMPT = (
|
_DEFAULT_PROMPT = "Text Recognition. If the content is a formula, please output display latex code, else output text"
|
||||||
"Recognize the text in the image and output in Markdown format. "
|
|
||||||
"Preserve the original layout (headings/paragraphs/tables/formulas). "
|
|
||||||
"Do not fabricate content that does not exist in the image."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GLMOCREndToEndService(OCRServiceBase):
|
class GLMOCREndToEndService(OCRServiceBase):
|
||||||
@@ -880,13 +868,14 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
"""
|
"""
|
||||||
# 1. Layout detection
|
# 1. Layout detection
|
||||||
img_h, img_w = image.shape[:2]
|
img_h, img_w = image.shape[:2]
|
||||||
layout_info = self.layout_detector.detect(image)
|
padded_image = self.image_processor.add_padding(image)
|
||||||
|
layout_info = self.layout_detector.detect(padded_image)
|
||||||
|
|
||||||
# Sort regions in reading order: top-to-bottom, left-to-right
|
# Sort regions in reading order: top-to-bottom, left-to-right
|
||||||
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
|
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
|
||||||
|
|
||||||
# 3. OCR: per-region (parallel) or full-image fallback
|
# 3. OCR: per-region (parallel) or full-image fallback
|
||||||
if not layout_info.regions:
|
if not layout_info.regions or (len(layout_info.regions) == 1 and not layout_info.MixedRecognition):
|
||||||
# No layout detected → assume it's a formula, use formula recognition
|
# No layout detected → assume it's a formula, use formula recognition
|
||||||
logger.info("No layout regions detected, treating image as formula")
|
logger.info("No layout regions detected, treating image as formula")
|
||||||
raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"])
|
raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"])
|
||||||
@@ -902,7 +891,7 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
if region.type == "figure":
|
if region.type == "figure":
|
||||||
continue
|
continue
|
||||||
x1, y1, x2, y2 = (int(c) for c in region.bbox)
|
x1, y1, x2, y2 = (int(c) for c in region.bbox)
|
||||||
cropped = image[y1:y2, x1:x2]
|
cropped = padded_image[y1:y2, x1:x2]
|
||||||
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
|
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping region idx=%d (label=%s): crop too small %s",
|
"Skipping region idx=%d (label=%s): crop too small %s",
|
||||||
@@ -921,10 +910,7 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
# Parallel OCR calls
|
# Parallel OCR calls
|
||||||
raw_results: dict[int, str] = {}
|
raw_results: dict[int, str] = {}
|
||||||
with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex:
|
with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tasks))) as ex:
|
||||||
future_map = {
|
future_map = {ex.submit(self._call_vllm, cropped, prompt): idx for idx, region, cropped, prompt in tasks}
|
||||||
ex.submit(self._call_vllm, cropped, prompt): idx
|
|
||||||
for idx, region, cropped, prompt in tasks
|
|
||||||
}
|
|
||||||
for future in as_completed(future_map):
|
for future in as_completed(future_map):
|
||||||
idx = future_map[future]
|
idx = future_map[future]
|
||||||
try:
|
try:
|
||||||
|
|||||||
35
tests/tools/layout.py
Normal file
35
tests/tools/layout.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import cv2
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.services.layout_detector import LayoutDetector
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
def debug_layout_detector(
    image_path: str = "test/image2.png",
    output_path: str = "test/layout_debug.png",
) -> None:
    """Run layout detection on one image and save an annotated copy.

    Every detected region is drawn onto the image as a red rectangle with its
    ``native_label`` written at the top-left corner of the box.

    Args:
        image_path: Image to analyse. Defaults to the sample path this script
            originally hard-coded.
        output_path: Where the annotated image is written. Defaults to the
            path this script originally hard-coded.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
        OSError: If OpenCV reports that the annotated image was not written.
    """
    layout_detector = LayoutDetector()

    image = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of on image.shape.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    print(f"Image shape: {image.shape}")

    # NOTE: detection runs on the unpadded image; the padding step
    # (ImageProcessor.add_padding) was left disabled in this script.
    layout_info = layout_detector.detect(image)

    # Draw each region's bounding box and label (red in BGR).
    color = (0, 0, 255)
    for region in layout_info.regions:
        # OpenCV drawing functions need integer pixel coordinates.
        x1, y1, x2, y2 = (int(c) for c in region.bbox)
        cv2.putText(
            image,
            region.native_label,
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            2,
        )
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

    # cv2.imwrite returns False when the file could not be written.
    if not cv2.imwrite(output_path, image):
        raise OSError(f"Could not write annotated image: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
debug_layout_detector()
|
||||||
Reference in New Issue
Block a user