fix: add image padding for mineru
This commit is contained in:
@@ -53,5 +53,6 @@ def get_mineru_ocr_service() -> MineruOCRService:
|
|||||||
return MineruOCRService(
|
return MineruOCRService(
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
converter=get_converter(),
|
converter=get_converter(),
|
||||||
|
image_processor=get_image_processor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""PaddleOCR-VL client service for text and formula recognition."""
|
"""PaddleOCR-VL client service for text and formula recognition."""
|
||||||
|
|
||||||
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
import requests
|
import requests
|
||||||
@@ -14,6 +15,82 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
|
_COMMANDS_NEED_SPACE = {
|
||||||
|
# operators / calculus
|
||||||
|
"cdot", "times", "div", "pm", "mp",
|
||||||
|
"int", "iint", "iiint", "oint", "sum", "prod", "lim",
|
||||||
|
# common functions
|
||||||
|
"sin", "cos", "tan", "cot", "sec", "csc",
|
||||||
|
"log", "ln", "exp",
|
||||||
|
# misc
|
||||||
|
"partial", "nabla",
|
||||||
|
}
|
||||||
|
|
||||||
|
_MATH_SEGMENT_PATTERN = re.compile(r"\$\$.*?\$\$|\$.*?\$", re.DOTALL)
|
||||||
|
_COMMAND_TOKEN_PATTERN = re.compile(r"\\[a-zA-Z]+")
|
||||||
|
|
||||||
|
# stage2: differentials inside math segments
|
||||||
|
_DIFFERENTIAL_UPPER_PATTERN = re.compile(r"(?<!\\)d([A-Z])")
|
||||||
|
_DIFFERENTIAL_LOWER_PATTERN = re.compile(r"(?<!\\)d([a-z])")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_glued_command_token(token: str) -> str:
|
||||||
|
"""Split OCR-glued LaTeX command token by whitelist longest-prefix.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- \\cdotdS -> \\cdot dS
|
||||||
|
- \\intdx -> \\int dx
|
||||||
|
"""
|
||||||
|
if not token.startswith("\\"):
|
||||||
|
return token
|
||||||
|
|
||||||
|
body = token[1:]
|
||||||
|
if len(body) < 2:
|
||||||
|
return token
|
||||||
|
|
||||||
|
best = None
|
||||||
|
# longest prefix that is in whitelist
|
||||||
|
for i in range(1, len(body)):
|
||||||
|
prefix = body[:i]
|
||||||
|
if prefix in _COMMANDS_NEED_SPACE:
|
||||||
|
best = prefix
|
||||||
|
|
||||||
|
if not best:
|
||||||
|
return token
|
||||||
|
|
||||||
|
suffix = body[len(best):]
|
||||||
|
if not suffix:
|
||||||
|
return token
|
||||||
|
|
||||||
|
return f"\\{best} {suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_math(expr: str) -> str:
|
||||||
|
"""Postprocess a *math* expression (already inside $...$ or $$...$$)."""
|
||||||
|
# stage1: split glued command tokens (e.g. \cdotdS)
|
||||||
|
expr = _COMMAND_TOKEN_PATTERN.sub(lambda m: _split_glued_command_token(m.group(0)), expr)
|
||||||
|
# stage2: normalize differentials (keep conservative)
|
||||||
|
expr = _DIFFERENTIAL_UPPER_PATTERN.sub(r"\\mathrm{d} \1", expr)
|
||||||
|
expr = _DIFFERENTIAL_LOWER_PATTERN.sub(r"d \1", expr)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_markdown(markdown_content: str) -> str:
|
||||||
|
"""Apply LaTeX postprocessing only within $...$ / $$...$$ segments."""
|
||||||
|
if not markdown_content:
|
||||||
|
return markdown_content
|
||||||
|
|
||||||
|
def _fix_segment(m: re.Match) -> str:
|
||||||
|
seg = m.group(0)
|
||||||
|
if seg.startswith("$$") and seg.endswith("$$"):
|
||||||
|
return f"$${_postprocess_math(seg[2:-2])}$$"
|
||||||
|
if seg.startswith("$") and seg.endswith("$"):
|
||||||
|
return f"${_postprocess_math(seg[1:-1])}$"
|
||||||
|
return seg
|
||||||
|
|
||||||
|
return _MATH_SEGMENT_PATTERN.sub(_fix_segment, markdown_content)
|
||||||
|
|
||||||
|
|
||||||
class OCRServiceBase(ABC):
|
class OCRServiceBase(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def recognize(self, image: np.ndarray) -> dict:
|
def recognize(self, image: np.ndarray) -> dict:
|
||||||
@@ -81,6 +158,7 @@ class OCRService(OCRServiceBase):
|
|||||||
for res in output:
|
for res in output:
|
||||||
markdown_content += res.markdown.get("markdown_texts", "")
|
markdown_content += res.markdown.get("markdown_texts", "")
|
||||||
|
|
||||||
|
markdown_content = _postprocess_markdown(markdown_content)
|
||||||
convert_result = self.converter.convert_to_formats(markdown_content)
|
convert_result = self.converter.convert_to_formats(markdown_content)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -112,6 +190,7 @@ class OCRService(OCRServiceBase):
|
|||||||
for res in output:
|
for res in output:
|
||||||
markdown_content += res.markdown.get("markdown_texts", "")
|
markdown_content += res.markdown.get("markdown_texts", "")
|
||||||
|
|
||||||
|
markdown_content = _postprocess_markdown(markdown_content)
|
||||||
convert_result = self.converter.convert_to_formats(markdown_content)
|
convert_result = self.converter.convert_to_formats(markdown_content)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -145,6 +224,7 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_url: str = "http://127.0.0.1:8000/file_parse",
|
api_url: str = "http://127.0.0.1:8000/file_parse",
|
||||||
|
image_processor: Optional[ImageProcessor] = None,
|
||||||
converter: Optional[Converter] = None,
|
converter: Optional[Converter] = None,
|
||||||
):
|
):
|
||||||
"""Initialize Local API service.
|
"""Initialize Local API service.
|
||||||
@@ -154,6 +234,7 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
converter: Optional converter instance for format conversion.
|
converter: Optional converter instance for format conversion.
|
||||||
"""
|
"""
|
||||||
self.api_url = api_url
|
self.api_url = api_url
|
||||||
|
self.image_processor = image_processor
|
||||||
self.converter = converter
|
self.converter = converter
|
||||||
|
|
||||||
def recognize(self, image: np.ndarray) -> dict:
|
def recognize(self, image: np.ndarray) -> dict:
|
||||||
@@ -166,6 +247,9 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
Dict with 'markdown', 'latex', 'mathml' keys.
|
Dict with 'markdown', 'latex', 'mathml' keys.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if self.image_processor:
|
||||||
|
image = self.image_processor.add_padding(image)
|
||||||
|
|
||||||
# Convert numpy array to image bytes
|
# Convert numpy array to image bytes
|
||||||
success, encoded_image = cv2.imencode('.png', image)
|
success, encoded_image = cv2.imencode('.png', image)
|
||||||
if not success:
|
if not success:
|
||||||
@@ -184,7 +268,6 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
'return_md': 'true',
|
'return_md': 'true',
|
||||||
'return_images': 'false',
|
'return_images': 'false',
|
||||||
'end_page_id': '99999',
|
'end_page_id': '99999',
|
||||||
'parse_method': 'auto',
|
|
||||||
'start_page_id': '0',
|
'start_page_id': '0',
|
||||||
'lang_list': 'en',
|
'lang_list': 'en',
|
||||||
'server_url': 'string',
|
'server_url': 'string',
|
||||||
@@ -193,6 +276,7 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
'table_enable': 'true',
|
'table_enable': 'true',
|
||||||
'response_format_zip': 'false',
|
'response_format_zip': 'false',
|
||||||
'formula_enable': 'true',
|
'formula_enable': 'true',
|
||||||
|
'parse_method': 'ocr'
|
||||||
}
|
}
|
||||||
|
|
||||||
# Make API request
|
# Make API request
|
||||||
@@ -211,6 +295,8 @@ class MineruOCRService(OCRServiceBase):
|
|||||||
markdown_content = ""
|
markdown_content = ""
|
||||||
if 'results' in result and 'image' in result['results']:
|
if 'results' in result and 'image' in result['results']:
|
||||||
markdown_content = result['results']['image'].get('md_content', '')
|
markdown_content = result['results']['image'].get('md_content', '')
|
||||||
|
|
||||||
|
# markdown_content = _postprocess_markdown(markdown_content)
|
||||||
|
|
||||||
# Convert to other formats if converter is available
|
# Convert to other formats if converter is available
|
||||||
latex = ""
|
latex = ""
|
||||||
|
|||||||
Reference in New Issue
Block a user