Merge pull request 'fix: remove padding from GLMOCREndToEndService and clean up ruff violations' (#2) from fix/tag into main
Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
|
|
||||||
from app.core.dependencies import get_converter
|
from app.core.dependencies import get_converter
|
||||||
from app.schemas.convert import MarkdownToDocxRequest, LatexToOmmlRequest, LatexToOmmlResponse
|
from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest
|
||||||
from app.services.converter import Converter
|
from app.services.converter import Converter
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ import uuid
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
|
|
||||||
from app.core.dependencies import (
|
from app.core.dependencies import (
|
||||||
get_image_processor,
|
|
||||||
get_glmocr_endtoend_service,
|
get_glmocr_endtoend_service,
|
||||||
|
get_image_processor,
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_logger, RequestIDAdapter
|
from app.core.logging_config import RequestIDAdapter, get_logger
|
||||||
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
|
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
|
||||||
from app.services.image_processor import ImageProcessor
|
from app.services.image_processor import ImageProcessor
|
||||||
from app.services.ocr_service import GLMOCREndToEndService
|
from app.services.ocr_service import GLMOCREndToEndService
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
"""Application dependencies."""
|
"""Application dependencies."""
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.services.converter import Converter
|
||||||
from app.services.image_processor import ImageProcessor
|
from app.services.image_processor import ImageProcessor
|
||||||
from app.services.layout_detector import LayoutDetector
|
from app.services.layout_detector import LayoutDetector
|
||||||
from app.services.ocr_service import GLMOCREndToEndService
|
from app.services.ocr_service import GLMOCREndToEndService
|
||||||
from app.services.converter import Converter
|
|
||||||
from app.core.config import get_settings
|
|
||||||
|
|
||||||
# Global instances (initialized on startup)
|
# Global instances (initialized on startup)
|
||||||
_layout_detector: LayoutDetector | None = None
|
_layout_detector: LayoutDetector | None = None
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
@@ -18,10 +18,10 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
|
|||||||
interval: int = 1,
|
interval: int = 1,
|
||||||
backupCount: int = 30,
|
backupCount: int = 30,
|
||||||
maxBytes: int = 100 * 1024 * 1024, # 100MB
|
maxBytes: int = 100 * 1024 * 1024, # 100MB
|
||||||
encoding: Optional[str] = None,
|
encoding: str | None = None,
|
||||||
delay: bool = False,
|
delay: bool = False,
|
||||||
utc: bool = False,
|
utc: bool = False,
|
||||||
atTime: Optional[Any] = None,
|
atTime: Any | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize handler with both time and size rotation.
|
"""Initialize handler with both time and size rotation.
|
||||||
|
|
||||||
@@ -58,14 +58,14 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler)
|
|||||||
if self.stream is None:
|
if self.stream is None:
|
||||||
self.stream = self._open()
|
self.stream = self._open()
|
||||||
if self.maxBytes > 0:
|
if self.maxBytes > 0:
|
||||||
msg = "%s\n" % self.format(record)
|
msg = f"{self.format(record)}\n"
|
||||||
self.stream.seek(0, 2) # Seek to end
|
self.stream.seek(0, 2) # Seek to end
|
||||||
if self.stream.tell() + len(msg) >= self.maxBytes:
|
if self.stream.tell() + len(msg) >= self.maxBytes:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(log_dir: Optional[str] = None) -> logging.Logger:
|
def setup_logging(log_dir: str | None = None) -> logging.Logger:
|
||||||
"""Setup application logging with rotation by day and size.
|
"""Setup application logging with rotation by day and size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -134,7 +134,7 @@ def setup_logging(log_dir: Optional[str] = None) -> logging.Logger:
|
|||||||
|
|
||||||
|
|
||||||
# Global logger instance
|
# Global logger instance
|
||||||
_logger: Optional[logging.Logger] = None
|
_logger: logging.Logger | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_logger() -> logging.Logger:
|
def get_logger() -> logging.Logger:
|
||||||
|
|||||||
@@ -36,4 +36,3 @@ class LatexToOmmlResponse(BaseModel):
|
|||||||
"""Response body for LaTeX to OMML conversion endpoint."""
|
"""Response body for LaTeX to OMML conversion endpoint."""
|
||||||
|
|
||||||
omml: str = Field("", description="OMML (Office Math Markup Language) representation")
|
omml: str = Field("", description="OMML (Office Math Markup Language) representation")
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ class LayoutRegion(BaseModel):
|
|||||||
"""A detected layout region in the document."""
|
"""A detected layout region in the document."""
|
||||||
|
|
||||||
type: str = Field(..., description="Region type: text, formula, table, figure")
|
type: str = Field(..., description="Region type: text, formula, table, figure")
|
||||||
native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)")
|
native_label: str = Field(
|
||||||
|
"", description="Raw label before type mapping (e.g. doc_title, formula_number)"
|
||||||
|
)
|
||||||
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
|
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
|
||||||
confidence: float = Field(..., description="Detection confidence score")
|
confidence: float = Field(..., description="Detection confidence score")
|
||||||
score: float = Field(..., description="Detection score")
|
score: float = Field(..., description="Detection score")
|
||||||
@@ -41,10 +43,15 @@ class ImageOCRRequest(BaseModel):
|
|||||||
class ImageOCRResponse(BaseModel):
|
class ImageOCRResponse(BaseModel):
|
||||||
"""Response body for image OCR endpoint."""
|
"""Response body for image OCR endpoint."""
|
||||||
|
|
||||||
latex: str = Field("", description="LaTeX representation of the content (empty if mixed content)")
|
latex: str = Field(
|
||||||
|
"", description="LaTeX representation of the content (empty if mixed content)"
|
||||||
|
)
|
||||||
markdown: str = Field("", description="Markdown representation of the content")
|
markdown: str = Field("", description="Markdown representation of the content")
|
||||||
mathml: str = Field("", description="Standard MathML representation (empty if mixed content)")
|
mathml: str = Field("", description="Standard MathML representation (empty if mixed content)")
|
||||||
mml: str = Field("", description="XML MathML with mml: namespace prefix (empty if mixed content)")
|
mml: str = Field(
|
||||||
|
"", description="XML MathML with mml: namespace prefix (empty if mixed content)"
|
||||||
|
)
|
||||||
layout_info: LayoutInfo = Field(default_factory=LayoutInfo)
|
layout_info: LayoutInfo = Field(default_factory=LayoutInfo)
|
||||||
recognition_mode: str = Field("", description="Recognition mode used: mixed_recognition or formula_recognition")
|
recognition_mode: str = Field(
|
||||||
|
"", description="Recognition mode used: mixed_recognition or formula_recognition"
|
||||||
|
)
|
||||||
|
|||||||
@@ -112,14 +112,18 @@ class Converter:
|
|||||||
# Pre-compiled regex patterns for preprocessing
|
# Pre-compiled regex patterns for preprocessing
|
||||||
_RE_VSPACE = re.compile(r"\\\[1mm\]")
|
_RE_VSPACE = re.compile(r"\\\[1mm\]")
|
||||||
_RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL)
|
_RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL)
|
||||||
_RE_BLOCK_FORMULA_LINE = re.compile(r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL)
|
_RE_BLOCK_FORMULA_LINE = re.compile(
|
||||||
|
r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL
|
||||||
|
)
|
||||||
_RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>')
|
_RE_ARITHMATEX = re.compile(r'<span class="arithmatex">(.*?)</span>')
|
||||||
_RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)")
|
_RE_INLINE_SPACE = re.compile(r"(?<!\$)\$ +(.+?) +\$(?!\$)")
|
||||||
_RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}")
|
_RE_ARRAY_SPECIFIER = re.compile(r"\\begin\{array\}\{([^}]+)\}")
|
||||||
_RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+")
|
_RE_LEFT_BRACE = re.compile(r"\\left\\\{\s+")
|
||||||
_RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}")
|
_RE_RIGHT_BRACE = re.compile(r"\s+\\right\\\}")
|
||||||
_RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL)
|
_RE_CASES = re.compile(r"\\begin\{cases\}(.*?)\\end\{cases\}", re.DOTALL)
|
||||||
_RE_ALIGNED_BRACE = re.compile(r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL)
|
_RE_ALIGNED_BRACE = re.compile(
|
||||||
|
r"\\left\\\{\\begin\{aligned\}(.*?)\\end\{aligned\}\\right\.", re.DOTALL
|
||||||
|
)
|
||||||
_RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL)
|
_RE_ALIGNED = re.compile(r"\\begin\{aligned\}(.*?)\\end\{aligned\}", re.DOTALL)
|
||||||
_RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL)
|
_RE_TAG = re.compile(r"\$\$(.*?)\\tag\s*\{([^}]+)\}\s*\$\$", re.DOTALL)
|
||||||
_RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL)
|
_RE_VMATRIX = re.compile(r"\\begin\{vmatrix\}(.*?)\\end\{vmatrix\}", re.DOTALL)
|
||||||
@@ -368,7 +372,9 @@ class Converter:
|
|||||||
mathml = latex_to_mathml(latex_formula)
|
mathml = latex_to_mathml(latex_formula)
|
||||||
return Converter._postprocess_mathml_for_word(mathml)
|
return Converter._postprocess_mathml_for_word(mathml)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}") from e
|
raise RuntimeError(
|
||||||
|
f"MathML conversion failed: {pandoc_error}. latex2mathml fallback also failed: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _postprocess_mathml_for_word(mathml: str) -> str:
|
def _postprocess_mathml_for_word(mathml: str) -> str:
|
||||||
@@ -583,7 +589,6 @@ class Converter:
|
|||||||
"⇓": "⇓", # Downarrow
|
"⇓": "⇓", # Downarrow
|
||||||
"↕": "↕", # updownarrow
|
"↕": "↕", # updownarrow
|
||||||
"⇕": "⇕", # Updownarrow
|
"⇕": "⇕", # Updownarrow
|
||||||
"≠": "≠", # ne
|
|
||||||
"≪": "≪", # ll
|
"≪": "≪", # ll
|
||||||
"≫": "≫", # gg
|
"≫": "≫", # gg
|
||||||
"⩽": "⩽", # leqslant
|
"⩽": "⩽", # leqslant
|
||||||
@@ -962,7 +967,7 @@ class Converter:
|
|||||||
"""Export to DOCX format using pypandoc."""
|
"""Export to DOCX format using pypandoc."""
|
||||||
extra_args = [
|
extra_args = [
|
||||||
"--highlight-style=pygments",
|
"--highlight-style=pygments",
|
||||||
f"--reference-doc=app/pkg/reference.docx",
|
"--reference-doc=app/pkg/reference.docx",
|
||||||
]
|
]
|
||||||
pypandoc.convert_file(
|
pypandoc.convert_file(
|
||||||
input_path,
|
input_path,
|
||||||
|
|||||||
@@ -1,26 +1,10 @@
|
|||||||
"""GLM-OCR postprocessing logic adapted for this project.
|
|
||||||
|
|
||||||
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
|
|
||||||
glm-ocr/glmocr/utils/result_postprocess_utils.py.
|
|
||||||
|
|
||||||
Covers:
|
|
||||||
- Repeated-content / hallucination detection
|
|
||||||
- Per-region content cleaning and formatting (titles, bullets, formulas)
|
|
||||||
- formula_number merging (→ \\tag{})
|
|
||||||
- Hyphenated text-block merging (via wordfreq)
|
|
||||||
- Missing bullet-point detection
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import json
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from wordfreq import zipf_frequency
|
from wordfreq import zipf_frequency
|
||||||
@@ -29,13 +13,14 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_WORDFREQ_AVAILABLE = False
|
_WORDFREQ_AVAILABLE = False
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# result_postprocess_utils (ported)
|
# result_postprocess_utils (ported)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
|
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None:
|
||||||
"""Detect and truncate a consecutively-repeated pattern.
|
"""Detect and truncate a consecutively-repeated pattern.
|
||||||
|
|
||||||
Returns the string with the repeat removed, or None if not found.
|
Returns the string with the repeat removed, or None if not found.
|
||||||
@@ -49,7 +34,13 @@ def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 1
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}",
|
r"(.{"
|
||||||
|
+ str(min_unit_len)
|
||||||
|
+ ","
|
||||||
|
+ str(max_unit_len)
|
||||||
|
+ r"}?)\1{"
|
||||||
|
+ str(min_repeats - 1)
|
||||||
|
+ ",}",
|
||||||
re.DOTALL,
|
re.DOTALL,
|
||||||
)
|
)
|
||||||
match = pattern.search(s)
|
match = pattern.search(s)
|
||||||
@@ -83,7 +74,9 @@ def clean_repeated_content(
|
|||||||
if count >= line_threshold and (count / total_lines) >= 0.8:
|
if count >= line_threshold and (count / total_lines) >= 0.8:
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
if line == common:
|
if line == common:
|
||||||
consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common)
|
consecutive = sum(
|
||||||
|
1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common
|
||||||
|
)
|
||||||
if consecutive >= 3:
|
if consecutive >= 3:
|
||||||
original_lines = content.split("\n")
|
original_lines = content.split("\n")
|
||||||
non_empty_count = 0
|
non_empty_count = 0
|
||||||
@@ -121,7 +114,7 @@ def clean_formula_number(number_content: str) -> str:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
||||||
_LABEL_TO_CATEGORY: Dict[str, str] = {
|
_LABEL_TO_CATEGORY: dict[str, str] = {
|
||||||
# text
|
# text
|
||||||
"abstract": "text",
|
"abstract": "text",
|
||||||
"algorithm": "text",
|
"algorithm": "text",
|
||||||
@@ -157,7 +150,7 @@ class GLMResultFormatter:
|
|||||||
# Public entry-point
|
# Public entry-point
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
|
|
||||||
def process(self, regions: List[Dict[str, Any]]) -> str:
|
def process(self, regions: list[dict[str, Any]]) -> str:
|
||||||
"""Run the full postprocessing pipeline and return Markdown.
|
"""Run the full postprocessing pipeline and return Markdown.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -175,7 +168,7 @@ class GLMResultFormatter:
|
|||||||
items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
|
items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
|
||||||
|
|
||||||
# Per-region cleaning + formatting
|
# Per-region cleaning + formatting
|
||||||
processed: List[Dict] = []
|
processed: list[dict] = []
|
||||||
for item in items:
|
for item in items:
|
||||||
item["native_label"] = item.get("native_label", item.get("label", "text"))
|
item["native_label"] = item.get("native_label", item.get("label", "text"))
|
||||||
item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
|
item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
|
||||||
@@ -199,7 +192,7 @@ class GLMResultFormatter:
|
|||||||
processed = self._format_bullet_points(processed)
|
processed = self._format_bullet_points(processed)
|
||||||
|
|
||||||
# Assemble Markdown
|
# Assemble Markdown
|
||||||
parts: List[str] = []
|
parts: list[str] = []
|
||||||
for item in processed:
|
for item in processed:
|
||||||
content = item.get("content") or ""
|
content = item.get("content") or ""
|
||||||
if item["label"] == "image":
|
if item["label"] == "image":
|
||||||
@@ -263,11 +256,15 @@ class GLMResultFormatter:
|
|||||||
if label == "formula":
|
if label == "formula":
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
|
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
|
||||||
if content.startswith(s) and content.endswith(e):
|
if content.startswith(s):
|
||||||
content = content[len(s) : -len(e)].strip()
|
content = content[len(s) :].strip()
|
||||||
|
if content.endswith(e):
|
||||||
|
content = content[: -len(e)].strip()
|
||||||
break
|
break
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning("Skipping formula region with empty content after stripping delimiters")
|
logger.warning(
|
||||||
|
"Skipping formula region with empty content after stripping delimiters"
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
content = "$$\n" + content + "\n$$"
|
content = "$$\n" + content + "\n$$"
|
||||||
|
|
||||||
@@ -296,12 +293,12 @@ class GLMResultFormatter:
|
|||||||
# Structural merges
|
# Structural merges
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
|
|
||||||
def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
|
def _merge_formula_numbers(self, items: list[dict]) -> list[dict]:
|
||||||
"""Merge formula_number region into adjacent formula with \\tag{}."""
|
"""Merge formula_number region into adjacent formula with \\tag{}."""
|
||||||
if not items:
|
if not items:
|
||||||
return items
|
return items
|
||||||
|
|
||||||
merged: List[Dict] = []
|
merged: list[dict] = []
|
||||||
skip: set = set()
|
skip: set = set()
|
||||||
|
|
||||||
for i, block in enumerate(items):
|
for i, block in enumerate(items):
|
||||||
@@ -317,7 +314,9 @@ class GLMResultFormatter:
|
|||||||
formula_content = items[i + 1].get("content", "")
|
formula_content = items[i + 1].get("content", "")
|
||||||
merged_block = deepcopy(items[i + 1])
|
merged_block = deepcopy(items[i + 1])
|
||||||
if formula_content.endswith("\n$$"):
|
if formula_content.endswith("\n$$"):
|
||||||
merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
merged_block["content"] = (
|
||||||
|
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||||
|
)
|
||||||
merged.append(merged_block)
|
merged.append(merged_block)
|
||||||
skip.add(i + 1)
|
skip.add(i + 1)
|
||||||
continue # always skip the formula_number block itself
|
continue # always skip the formula_number block itself
|
||||||
@@ -329,7 +328,9 @@ class GLMResultFormatter:
|
|||||||
formula_content = block.get("content", "")
|
formula_content = block.get("content", "")
|
||||||
merged_block = deepcopy(block)
|
merged_block = deepcopy(block)
|
||||||
if formula_content.endswith("\n$$"):
|
if formula_content.endswith("\n$$"):
|
||||||
merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
merged_block["content"] = (
|
||||||
|
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||||
|
)
|
||||||
merged.append(merged_block)
|
merged.append(merged_block)
|
||||||
skip.add(i + 1)
|
skip.add(i + 1)
|
||||||
continue
|
continue
|
||||||
@@ -340,12 +341,12 @@ class GLMResultFormatter:
|
|||||||
block["index"] = i
|
block["index"] = i
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
|
def _merge_text_blocks(self, items: list[dict]) -> list[dict]:
|
||||||
"""Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
|
"""Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
|
||||||
if not items or not _WORDFREQ_AVAILABLE:
|
if not items or not _WORDFREQ_AVAILABLE:
|
||||||
return items
|
return items
|
||||||
|
|
||||||
merged: List[Dict] = []
|
merged: list[dict] = []
|
||||||
skip: set = set()
|
skip: set = set()
|
||||||
|
|
||||||
for i, block in enumerate(items):
|
for i, block in enumerate(items):
|
||||||
@@ -389,7 +390,9 @@ class GLMResultFormatter:
|
|||||||
block["index"] = i
|
block["index"] = i
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
|
def _format_bullet_points(
|
||||||
|
self, items: list[dict], left_align_threshold: float = 10.0
|
||||||
|
) -> list[dict]:
|
||||||
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
|
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
|
||||||
if len(items) < 3:
|
if len(items) < 3:
|
||||||
return items
|
return items
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
"""PP-DocLayoutV3 wrapper for document layout detection."""
|
"""PP-DocLayoutV3 wrapper for document layout detection."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
|
||||||
from app.core.config import get_settings
|
|
||||||
from app.services.layout_postprocess import apply_layout_postprocess
|
|
||||||
from paddleocr import LayoutDetection
|
from paddleocr import LayoutDetection
|
||||||
from typing import Optional
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||||
|
from app.services.layout_postprocess import apply_layout_postprocess
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
@@ -14,7 +13,7 @@ settings = get_settings()
|
|||||||
class LayoutDetector:
|
class LayoutDetector:
|
||||||
"""Layout detector for PP-DocLayoutV2."""
|
"""Layout detector for PP-DocLayoutV2."""
|
||||||
|
|
||||||
_layout_detector: Optional[LayoutDetection] = None
|
_layout_detector: LayoutDetection | None = None
|
||||||
|
|
||||||
# PP-DocLayoutV2 class ID to label mapping
|
# PP-DocLayoutV2 class ID to label mapping
|
||||||
CLS_ID_TO_LABEL: dict[int, str] = {
|
CLS_ID_TO_LABEL: dict[int, str] = {
|
||||||
@@ -156,10 +155,11 @@ class LayoutDetector:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.services.image_processor import ImageProcessor
|
|
||||||
from app.services.converter import Converter
|
from app.services.converter import Converter
|
||||||
from app.services.ocr_service import OCRService
|
from app.services.image_processor import ImageProcessor
|
||||||
|
from app.services.ocr_service import GLMOCREndToEndService
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
@@ -169,15 +169,15 @@ if __name__ == "__main__":
|
|||||||
converter = Converter()
|
converter = Converter()
|
||||||
|
|
||||||
# Initialize OCR service
|
# Initialize OCR service
|
||||||
ocr_service = OCRService(
|
ocr_service = GLMOCREndToEndService(
|
||||||
vl_server_url=settings.paddleocr_vl_url,
|
vl_server_url=settings.glm_ocr_url,
|
||||||
layout_detector=layout_detector,
|
layout_detector=layout_detector,
|
||||||
image_processor=image_processor,
|
image_processor=image_processor,
|
||||||
converter=converter,
|
converter=converter,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load test image
|
# Load test image
|
||||||
image_path = "test/timeout.jpg"
|
image_path = "test/image2.png"
|
||||||
image = cv2.imread(image_path)
|
image = cv2.imread(image_path)
|
||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
|
|||||||
@@ -15,16 +15,14 @@ the quality of the GLM-OCR SDK's layout pipeline.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Primitive geometry helpers
|
# Primitive geometry helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def iou(box1: List[float], box2: List[float]) -> float:
|
|
||||||
|
def iou(box1: list[float], box2: list[float]) -> float:
|
||||||
"""Compute IoU of two bounding boxes [x1, y1, x2, y2]."""
|
"""Compute IoU of two bounding boxes [x1, y1, x2, y2]."""
|
||||||
x1, y1, x2, y2 = box1
|
x1, y1, x2, y2 = box1
|
||||||
x1_p, y1_p, x2_p, y2_p = box2
|
x1_p, y1_p, x2_p, y2_p = box2
|
||||||
@@ -41,7 +39,7 @@ def iou(box1: List[float], box2: List[float]) -> float:
|
|||||||
return inter_area / float(box1_area + box2_area - inter_area)
|
return inter_area / float(box1_area + box2_area - inter_area)
|
||||||
|
|
||||||
|
|
||||||
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
|
def is_contained(box1: list[float], box2: list[float], overlap_threshold: float = 0.8) -> bool:
|
||||||
"""Return True if box1 is contained within box2 (overlap ratio >= threshold).
|
"""Return True if box1 is contained within box2 (overlap ratio >= threshold).
|
||||||
|
|
||||||
box format: [cls_id, score, x1, y1, x2, y2]
|
box format: [cls_id, score, x1, y1, x2, y2]
|
||||||
@@ -66,11 +64,12 @@ def is_contained(box1: List[float], box2: List[float], overlap_threshold: float
|
|||||||
# NMS
|
# NMS
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def nms(
|
def nms(
|
||||||
boxes: np.ndarray,
|
boxes: np.ndarray,
|
||||||
iou_same: float = 0.6,
|
iou_same: float = 0.6,
|
||||||
iou_diff: float = 0.98,
|
iou_diff: float = 0.98,
|
||||||
) -> List[int]:
|
) -> list[int]:
|
||||||
"""NMS with separate IoU thresholds for same-class and cross-class overlaps.
|
"""NMS with separate IoU thresholds for same-class and cross-class overlaps.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -83,7 +82,7 @@ def nms(
|
|||||||
"""
|
"""
|
||||||
scores = boxes[:, 1]
|
scores = boxes[:, 1]
|
||||||
indices = np.argsort(scores)[::-1].tolist()
|
indices = np.argsort(scores)[::-1].tolist()
|
||||||
selected: List[int] = []
|
selected: list[int] = []
|
||||||
|
|
||||||
while indices:
|
while indices:
|
||||||
current = indices[0]
|
current = indices[0]
|
||||||
@@ -114,10 +113,10 @@ _PRESERVE_LABELS = {"image", "seal", "chart"}
|
|||||||
|
|
||||||
def check_containment(
|
def check_containment(
|
||||||
boxes: np.ndarray,
|
boxes: np.ndarray,
|
||||||
preserve_cls_ids: Optional[set] = None,
|
preserve_cls_ids: set | None = None,
|
||||||
category_index: Optional[int] = None,
|
category_index: int | None = None,
|
||||||
mode: Optional[str] = None,
|
mode: str | None = None,
|
||||||
) -> Tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""Compute containment flags for each box.
|
"""Compute containment flags for each box.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -160,9 +159,10 @@ def check_containment(
|
|||||||
# Box expansion (unclip)
|
# Box expansion (unclip)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def unclip_boxes(
|
def unclip_boxes(
|
||||||
boxes: np.ndarray,
|
boxes: np.ndarray,
|
||||||
unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
|
unclip_ratio: float | tuple[float, float] | dict | list | None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Expand bounding boxes by the given ratio.
|
"""Expand bounding boxes by the given ratio.
|
||||||
|
|
||||||
@@ -215,13 +215,14 @@ def unclip_boxes(
|
|||||||
# Main entry-point
|
# Main entry-point
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def apply_layout_postprocess(
|
def apply_layout_postprocess(
|
||||||
boxes: List[Dict],
|
boxes: list[dict],
|
||||||
img_size: Tuple[int, int],
|
img_size: tuple[int, int],
|
||||||
layout_nms: bool = True,
|
layout_nms: bool = True,
|
||||||
layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
|
layout_unclip_ratio: float | tuple | dict | None = None,
|
||||||
layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
|
layout_merge_bboxes_mode: str | dict | None = "large",
|
||||||
) -> List[Dict]:
|
) -> list[dict]:
|
||||||
"""Apply GLM-OCR layout post-processing to PaddleOCR detection results.
|
"""Apply GLM-OCR layout post-processing to PaddleOCR detection results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -250,7 +251,7 @@ def apply_layout_postprocess(
|
|||||||
arr_rows.append([cls_id, score, x1, y1, x2, y2])
|
arr_rows.append([cls_id, score, x1, y1, x2, y2])
|
||||||
boxes_array = np.array(arr_rows, dtype=float)
|
boxes_array = np.array(arr_rows, dtype=float)
|
||||||
|
|
||||||
all_labels: List[str] = [b.get("label", "") for b in boxes]
|
all_labels: list[str] = [b.get("label", "") for b in boxes]
|
||||||
|
|
||||||
# 1. NMS ---------------------------------------------------------------- #
|
# 1. NMS ---------------------------------------------------------------- #
|
||||||
if layout_nms and len(boxes_array) > 1:
|
if layout_nms and len(boxes_array) > 1:
|
||||||
@@ -262,17 +263,14 @@ def apply_layout_postprocess(
|
|||||||
if len(boxes_array) > 1:
|
if len(boxes_array) > 1:
|
||||||
img_area = img_width * img_height
|
img_area = img_width * img_height
|
||||||
area_thres = 0.82 if img_width > img_height else 0.93
|
area_thres = 0.82 if img_width > img_height else 0.93
|
||||||
image_cls_ids = {
|
|
||||||
int(boxes_array[i, 0])
|
|
||||||
for i, lbl in enumerate(all_labels)
|
|
||||||
if lbl == "image"
|
|
||||||
}
|
|
||||||
keep_mask = np.ones(len(boxes_array), dtype=bool)
|
keep_mask = np.ones(len(boxes_array), dtype=bool)
|
||||||
for i, lbl in enumerate(all_labels):
|
for i, lbl in enumerate(all_labels):
|
||||||
if lbl == "image":
|
if lbl == "image":
|
||||||
x1, y1, x2, y2 = boxes_array[i, 2:6]
|
x1, y1, x2, y2 = boxes_array[i, 2:6]
|
||||||
x1 = max(0.0, x1); y1 = max(0.0, y1)
|
x1 = max(0.0, x1)
|
||||||
x2 = min(float(img_width), x2); y2 = min(float(img_height), y2)
|
y1 = max(0.0, y1)
|
||||||
|
x2 = min(float(img_width), x2)
|
||||||
|
y2 = min(float(img_height), y2)
|
||||||
if (x2 - x1) * (y2 - y1) > area_thres * img_area:
|
if (x2 - x1) * (y2 - y1) > area_thres * img_area:
|
||||||
keep_mask[i] = False
|
keep_mask[i] = False
|
||||||
boxes_array = boxes_array[keep_mask]
|
boxes_array = boxes_array[keep_mask]
|
||||||
@@ -281,9 +279,7 @@ def apply_layout_postprocess(
|
|||||||
# 3. Containment analysis (merge_bboxes_mode) -------------------------- #
|
# 3. Containment analysis (merge_bboxes_mode) -------------------------- #
|
||||||
if layout_merge_bboxes_mode and len(boxes_array) > 1:
|
if layout_merge_bboxes_mode and len(boxes_array) > 1:
|
||||||
preserve_cls_ids = {
|
preserve_cls_ids = {
|
||||||
int(boxes_array[i, 0])
|
int(boxes_array[i, 0]) for i, lbl in enumerate(all_labels) if lbl in _PRESERVE_LABELS
|
||||||
for i, lbl in enumerate(all_labels)
|
|
||||||
if lbl in _PRESERVE_LABELS
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(layout_merge_bboxes_mode, str):
|
if isinstance(layout_merge_bboxes_mode, str):
|
||||||
@@ -321,7 +317,7 @@ def apply_layout_postprocess(
|
|||||||
boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)
|
boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)
|
||||||
|
|
||||||
# 5. Clamp to image boundaries + skip invalid -------------------------- #
|
# 5. Clamp to image boundaries + skip invalid -------------------------- #
|
||||||
result: List[Dict] = []
|
result: list[dict] = []
|
||||||
for i, row in enumerate(boxes_array):
|
for i, row in enumerate(boxes_array):
|
||||||
cls_id = int(row[0])
|
cls_id = int(row[0])
|
||||||
score = float(row[1])
|
score = float(row[1])
|
||||||
@@ -333,11 +329,13 @@ def apply_layout_postprocess(
|
|||||||
if x1 >= x2 or y1 >= y2:
|
if x1 >= x2 or y1 >= y2:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
result.append({
|
result.append(
|
||||||
|
{
|
||||||
"cls_id": cls_id,
|
"cls_id": cls_id,
|
||||||
"label": all_labels[i],
|
"label": all_labels[i],
|
||||||
"score": score,
|
"score": score,
|
||||||
"coordinate": [int(x1), int(y1), int(x2), int(y2)],
|
"coordinate": [int(x1), int(y1), int(x2), int(y2)],
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -878,12 +878,9 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
|
Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
|
||||||
"""
|
"""
|
||||||
# 1. Padding
|
# 1. Layout detection
|
||||||
padded = self.image_processor.add_padding(image)
|
img_h, img_w = image.shape[:2]
|
||||||
img_h, img_w = padded.shape[:2]
|
layout_info = self.layout_detector.detect(image)
|
||||||
|
|
||||||
# 2. Layout detection
|
|
||||||
layout_info = self.layout_detector.detect(padded)
|
|
||||||
|
|
||||||
# Sort regions in reading order: top-to-bottom, left-to-right
|
# Sort regions in reading order: top-to-bottom, left-to-right
|
||||||
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
|
layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0]))
|
||||||
@@ -892,7 +889,7 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
if not layout_info.regions:
|
if not layout_info.regions:
|
||||||
# No layout detected → assume it's a formula, use formula recognition
|
# No layout detected → assume it's a formula, use formula recognition
|
||||||
logger.info("No layout regions detected, treating image as formula")
|
logger.info("No layout regions detected, treating image as formula")
|
||||||
raw_content = self._call_vllm(padded, _TASK_PROMPTS["formula"])
|
raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"])
|
||||||
# Format as display formula markdown
|
# Format as display formula markdown
|
||||||
formatted_content = raw_content.strip()
|
formatted_content = raw_content.strip()
|
||||||
if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")):
|
if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")):
|
||||||
@@ -905,7 +902,7 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
if region.type == "figure":
|
if region.type == "figure":
|
||||||
continue
|
continue
|
||||||
x1, y1, x2, y2 = (int(c) for c in region.bbox)
|
x1, y1, x2, y2 = (int(c) for c in region.bbox)
|
||||||
cropped = padded[y1:y2, x1:x2]
|
cropped = image[y1:y2, x1:x2]
|
||||||
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
|
if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping region idx=%d (label=%s): crop too small %s",
|
"Skipping region idx=%d (label=%s): crop too small %s",
|
||||||
@@ -918,7 +915,7 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
tasks.append((idx, region, cropped, prompt))
|
tasks.append((idx, region, cropped, prompt))
|
||||||
|
|
||||||
if not tasks:
|
if not tasks:
|
||||||
raw_content = self._call_vllm(padded, _DEFAULT_PROMPT)
|
raw_content = self._call_vllm(image, _DEFAULT_PROMPT)
|
||||||
markdown_content = self._formatter._clean_content(raw_content)
|
markdown_content = self._formatter._clean_content(raw_content)
|
||||||
else:
|
else:
|
||||||
# Parallel OCR calls
|
# Parallel OCR calls
|
||||||
@@ -965,17 +962,3 @@ class GLMOCREndToEndService(OCRServiceBase):
|
|||||||
logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e)
|
logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e)
|
||||||
|
|
||||||
return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
|
return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
mineru_service = MineruOCRService()
|
|
||||||
image = cv2.imread("test/formula2.jpg")
|
|
||||||
image_numpy = np.array(image)
|
|
||||||
# Encode image to bytes (as done in API layer)
|
|
||||||
success, encoded_image = cv2.imencode(".png", image_numpy)
|
|
||||||
if not success:
|
|
||||||
raise RuntimeError("Failed to encode image")
|
|
||||||
image_bytes = BytesIO(encoded_image.tobytes())
|
|
||||||
image_bytes.seek(0)
|
|
||||||
ocr_result = mineru_service.recognize(image_bytes)
|
|
||||||
print(ocr_result)
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
@@ -35,7 +34,9 @@ def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64():
|
|||||||
client = _build_client()
|
client = _build_client()
|
||||||
|
|
||||||
missing = client.post("/ocr", json={})
|
missing = client.post("/ocr", json={})
|
||||||
both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"})
|
both = client.post(
|
||||||
|
"/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}
|
||||||
|
)
|
||||||
|
|
||||||
assert missing.status_code == 422
|
assert missing.status_code == 422
|
||||||
assert both.status_code == 422
|
assert both.status_code == 422
|
||||||
|
|||||||
@@ -57,12 +57,22 @@ def test_merge_formula_numbers_merges_before_and_after_formula():
|
|||||||
before = formatter._merge_formula_numbers(
|
before = formatter._merge_formula_numbers(
|
||||||
[
|
[
|
||||||
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
|
{"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"},
|
||||||
{"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
|
{
|
||||||
|
"index": 1,
|
||||||
|
"label": "formula",
|
||||||
|
"native_label": "display_formula",
|
||||||
|
"content": "$$\nx+y\n$$",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
after = formatter._merge_formula_numbers(
|
after = formatter._merge_formula_numbers(
|
||||||
[
|
[
|
||||||
{"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"},
|
{
|
||||||
|
"index": 0,
|
||||||
|
"label": "formula",
|
||||||
|
"native_label": "display_formula",
|
||||||
|
"content": "$$\nx+y\n$$",
|
||||||
|
},
|
||||||
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
|
{"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,7 +23,9 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
|
|||||||
|
|
||||||
calls = {}
|
calls = {}
|
||||||
|
|
||||||
def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode):
|
def fake_apply_layout_postprocess(
|
||||||
|
boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode
|
||||||
|
):
|
||||||
calls["args"] = {
|
calls["args"] = {
|
||||||
"boxes": boxes,
|
"boxes": boxes,
|
||||||
"img_size": img_size,
|
"img_size": img_size,
|
||||||
@@ -33,7 +35,9 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch):
|
|||||||
}
|
}
|
||||||
return [boxes[0], boxes[2]]
|
return [boxes[0], boxes[2]]
|
||||||
|
|
||||||
monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess)
|
monkeypatch.setattr(
|
||||||
|
"app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess
|
||||||
|
)
|
||||||
|
|
||||||
image = np.zeros((200, 100, 3), dtype=np.uint8)
|
image = np.zeros((200, 100, 3), dtype=np.uint8)
|
||||||
info = detector.detect(image)
|
info = detector.detect(image)
|
||||||
|
|||||||
@@ -146,6 +146,4 @@ def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image()
|
|||||||
layout_merge_bboxes_mode=None,
|
layout_merge_bboxes_mode=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == [
|
assert result == [{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}]
|
||||||
{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ def test_encode_region_returns_decodable_base64_jpeg():
|
|||||||
image[:, :] = [0, 128, 255]
|
image[:, :] = [0, 128, 255]
|
||||||
|
|
||||||
encoded = service._encode_region(image)
|
encoded = service._encode_region(image)
|
||||||
decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR)
|
decoded = cv2.imdecode(
|
||||||
|
np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR
|
||||||
|
)
|
||||||
|
|
||||||
assert decoded.shape[:2] == image.shape[:2]
|
assert decoded.shape[:2] == image.shape[:2]
|
||||||
|
|
||||||
@@ -71,7 +73,9 @@ def test_call_vllm_builds_messages_and_returns_content():
|
|||||||
assert captured["model"] == "glm-ocr"
|
assert captured["model"] == "glm-ocr"
|
||||||
assert captured["max_tokens"] == 1024
|
assert captured["max_tokens"] == 1024
|
||||||
assert captured["messages"][0]["content"][0]["type"] == "image_url"
|
assert captured["messages"][0]["content"][0]["type"] == "image_url"
|
||||||
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,")
|
assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith(
|
||||||
|
"data:image/jpeg;base64,"
|
||||||
|
)
|
||||||
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
|
assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"}
|
||||||
|
|
||||||
|
|
||||||
@@ -98,9 +102,19 @@ def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch):
|
|||||||
|
|
||||||
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
|
def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch):
|
||||||
regions = [
|
regions = [
|
||||||
LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9),
|
LayoutRegion(
|
||||||
LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8),
|
type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9
|
||||||
LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95),
|
),
|
||||||
|
LayoutRegion(
|
||||||
|
type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8
|
||||||
|
),
|
||||||
|
LayoutRegion(
|
||||||
|
type="formula",
|
||||||
|
native_label="display_formula",
|
||||||
|
bbox=[20, 20, 40, 40],
|
||||||
|
confidence=0.95,
|
||||||
|
score=0.95,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
service = _build_service(regions=regions)
|
service = _build_service(regions=regions)
|
||||||
image = np.zeros((40, 40, 3), dtype=np.uint8)
|
image = np.zeros((40, 40, 3), dtype=np.uint8)
|
||||||
|
|||||||
Reference in New Issue
Block a user