Files
doc_processer/app/services/ocr_service.py

328 lines
10 KiB
Python
Raw Normal View History

2025-12-29 17:34:58 +08:00
"""PaddleOCR-VL client service for text and formula recognition."""
2026-01-05 21:37:51 +08:00
import re
2025-12-29 17:34:58 +08:00
import numpy as np
2026-01-05 17:30:54 +08:00
import cv2
import requests
from io import BytesIO
2025-12-29 17:34:58 +08:00
from app.core.config import get_settings
2025-12-31 17:38:32 +08:00
from paddleocr import PaddleOCRVL
from typing import Optional
from app.services.layout_detector import LayoutDetector
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
2026-01-05 17:30:54 +08:00
from abc import ABC, abstractmethod
2025-12-29 17:34:58 +08:00
# Application settings, fetched once when this module is imported.
# OCRService.__init__ reads settings.paddleocr_vl_url as the fallback
# PaddleOCR-VL server URL.
settings = get_settings()
2026-01-05 21:37:51 +08:00
_COMMANDS_NEED_SPACE = {
# operators / calculus
"cdot", "times", "div", "pm", "mp",
"int", "iint", "iiint", "oint", "sum", "prod", "lim",
# common functions
"sin", "cos", "tan", "cot", "sec", "csc",
"log", "ln", "exp",
# misc
"partial", "nabla",
}
_MATH_SEGMENT_PATTERN = re.compile(r"\$\$.*?\$\$|\$.*?\$", re.DOTALL)
_COMMAND_TOKEN_PATTERN = re.compile(r"\\[a-zA-Z]+")
# stage2: differentials inside math segments
_DIFFERENTIAL_UPPER_PATTERN = re.compile(r"(?<!\\)d([A-Z])")
_DIFFERENTIAL_LOWER_PATTERN = re.compile(r"(?<!\\)d([a-z])")
def _split_glued_command_token(token: str) -> str:
"""Split OCR-glued LaTeX command token by whitelist longest-prefix.
Examples:
- \\cdotdS -> \\cdot dS
- \\intdx -> \\int dx
"""
if not token.startswith("\\"):
return token
body = token[1:]
if len(body) < 2:
return token
best = None
# longest prefix that is in whitelist
for i in range(1, len(body)):
prefix = body[:i]
if prefix in _COMMANDS_NEED_SPACE:
best = prefix
if not best:
return token
suffix = body[len(best):]
if not suffix:
return token
return f"\\{best} {suffix}"
def _postprocess_math(expr: str) -> str:
"""Postprocess a *math* expression (already inside $...$ or $$...$$)."""
# stage1: split glued command tokens (e.g. \cdotdS)
expr = _COMMAND_TOKEN_PATTERN.sub(lambda m: _split_glued_command_token(m.group(0)), expr)
# stage2: normalize differentials (keep conservative)
expr = _DIFFERENTIAL_UPPER_PATTERN.sub(r"\\mathrm{d} \1", expr)
expr = _DIFFERENTIAL_LOWER_PATTERN.sub(r"d \1", expr)
return expr
def _postprocess_markdown(markdown_content: str) -> str:
"""Apply LaTeX postprocessing only within $...$ / $$...$$ segments."""
if not markdown_content:
return markdown_content
def _fix_segment(m: re.Match) -> str:
seg = m.group(0)
if seg.startswith("$$") and seg.endswith("$$"):
return f"$${_postprocess_math(seg[2:-2])}$$"
if seg.startswith("$") and seg.endswith("$"):
return f"${_postprocess_math(seg[1:-1])}$"
return seg
return _MATH_SEGMENT_PATTERN.sub(_fix_segment, markdown_content)
2026-01-05 17:30:54 +08:00
class OCRServiceBase(ABC):
    """Abstract interface implemented by the OCR backends in this module.

    Concrete services override :meth:`recognize`, which turns one image into
    a result dictionary.
    """

    @abstractmethod
    def recognize(self, image: np.ndarray) -> dict:
        """Recognize the content of *image* and return it as a dict."""
        ...
2025-12-29 17:34:58 +08:00
2026-01-05 17:30:54 +08:00
class OCRService(OCRServiceBase):
    """Service for OCR using PaddleOCR-VL.

    Each image is routed either to document-aware mixed recognition (layout
    detection, text and formulas) or to plain formula recognition, depending
    on ``layout_detector.detect`` run on a padded copy of the image.
    """

    # Shared by every instance: created lazily by the first _get_pipeline()
    # call, using that instance's vl_server_url.
    _pipeline: Optional[PaddleOCRVL] = None
    # NOTE(review): not read anywhere in this class (instances use
    # self.layout_detector); kept for compatibility.
    _layout_detector: Optional[LayoutDetector] = None

    def __init__(
        self,
        vl_server_url: Optional[str],
        layout_detector: LayoutDetector,
        image_processor: ImageProcessor,
        converter: Converter,
    ):
        """Initialize OCR service.

        Args:
            vl_server_url: URL of the vLLM server for PaddleOCR-VL. Falls back
                to ``settings.paddleocr_vl_url`` when empty or None.
            layout_detector: Layout detector instance; its result selects
                mixed vs. formula recognition.
            image_processor: Image processor instance; used to pad images
                before layout detection.
            converter: Converter that turns the recognized markdown into
                LaTeX and MathML.
        """
        self.vl_server_url = vl_server_url or settings.paddleocr_vl_url
        self.layout_detector = layout_detector
        self.image_processor = image_processor
        self.converter = converter

    def _get_pipeline(self):
        """Get or lazily create the shared PaddleOCR-VL pipeline.

        Returns:
            PaddleOCRVL pipeline instance.
        """
        # NOTE: cached on the class, so all instances share one pipeline and
        # the first caller's vl_server_url wins. Creation is not locked.
        if OCRService._pipeline is None:
            OCRService._pipeline = PaddleOCRVL(
                vl_rec_backend="vllm-server",
                vl_rec_server_url=self.vl_server_url,
                layout_detection_model_name="PP-DocLayoutV2",
            )
        return OCRService._pipeline

    def _run_pipeline(self, image: np.ndarray, **predict_kwargs):
        """Run PaddleOCR-VL on *image* and convert its markdown output.

        Args:
            image: Input image as numpy array in BGR format.
            **predict_kwargs: Forwarded unchanged to ``pipeline.predict``.

        Returns:
            Tuple ``(markdown, convert_result)``. ``markdown`` is the
            post-processed concatenation of every result's markdown text, and
            ``convert_result`` is the return value of
            ``converter.convert_to_formats``.
        """
        output = self._get_pipeline().predict(image, **predict_kwargs)
        markdown_content = "".join(res.markdown.get("markdown_texts", "") for res in output)
        markdown_content = _postprocess_markdown(markdown_content)
        return markdown_content, self.converter.convert_to_formats(markdown_content)

    def _recognize_mixed(self, image: np.ndarray) -> dict:
        """Recognize mixed content (text + formulas) using PP-DocLayoutV2.

        This mode uses PaddleOCR-VL with layout detection enabled for
        document-aware recognition of mixed content.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml' keys.

        Raises:
            RuntimeError: Wrapping any failure in recognition or conversion.
        """
        try:
            markdown_content, convert_result = self._run_pipeline(image, use_layout_detection=True)
            return {
                "markdown": markdown_content,
                "latex": convert_result.latex,
                "mathml": convert_result.mathml,
            }
        except Exception as e:
            raise RuntimeError(f"Mixed recognition failed: {e}") from e

    def _recognize_formula(self, image: np.ndarray) -> dict:
        """Recognize formula/math content using PaddleOCR-VL with prompt.

        This mode skips layout detection and uses the "formula" prompt label.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'latex', 'mathml', 'markdown' keys.

        Raises:
            RuntimeError: Wrapping any failure in recognition or conversion.
        """
        try:
            markdown_content, convert_result = self._run_pipeline(
                image, use_layout_detection=False, prompt_label="formula"
            )
            return {
                "latex": convert_result.latex,
                "mathml": convert_result.mathml,
                "markdown": markdown_content,
            }
        except Exception as e:
            raise RuntimeError(f"Formula recognition failed: {e}") from e

    def recognize(self, image: np.ndarray) -> dict:
        """Recognize content using PaddleOCR-VL.

        Layout detection runs on a padded copy of the image. The original,
        unpadded image is then passed to mixed recognition when
        ``layout_info.MixedRecognition`` is truthy, and to formula
        recognition otherwise.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'latex', 'markdown', 'mathml' keys.

        Raises:
            RuntimeError: If the selected recognition mode fails.
        """
        padded_image = self.image_processor.add_padding(image)
        layout_info = self.layout_detector.detect(padded_image)
        if layout_info.MixedRecognition:
            return self._recognize_mixed(image)
        return self._recognize_formula(image)
class MineruOCRService(OCRServiceBase):
    """Service for OCR using local file_parse API."""

    def __init__(
        self,
        api_url: str = "http://127.0.0.1:8000/file_parse",
        image_processor: Optional[ImageProcessor] = None,
        converter: Optional[Converter] = None,
        timeout: float = 30,
    ):
        """Initialize Local API service.

        Args:
            api_url: URL of the local file_parse API endpoint.
            image_processor: Optional processor. When set, images are padded
                before upload.
            converter: Optional converter instance for format conversion.
            timeout: Request timeout in seconds, passed to ``requests.post``.
        """
        self.api_url = api_url
        self.image_processor = image_processor
        self.converter = converter
        self.timeout = timeout

    def recognize(self, image: np.ndarray) -> dict:
        """Recognize content using local file_parse API.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml' keys. 'latex' and
            'mathml' are empty strings when no converter is configured or
            the API returned no markdown.

        Raises:
            RuntimeError: If the image is empty, encoding fails, the HTTP
                request fails, or conversion fails.
        """
        try:
            # Fail with a clear message instead of an opaque cv2 error
            # (e.g. when cv2.imread could not read a file and returned None).
            if image is None or image.size == 0:
                raise ValueError("Input image is empty")

            if self.image_processor:
                image = self.image_processor.add_padding(image)

            # Encode the numpy array as PNG bytes for the multipart upload.
            success, encoded_image = cv2.imencode('.png', image)
            if not success:
                raise RuntimeError("Failed to encode image")

            # NOTE(review): the response is read from result['results']['image'],
            # presumably keyed by the stem of the uploaded filename 'image.png';
            # keep the two in sync (confirm against the API).
            files = {
                'files': ('image.png', encoded_image.tobytes(), 'image/png')
            }
            data = {
                'return_middle_json': 'false',
                'return_model_output': 'false',
                'return_md': 'true',
                'return_images': 'false',
                'end_page_id': '99999',
                'start_page_id': '0',
                'lang_list': 'en',
                # NOTE(review): 'string' looks like an API-docs placeholder
                # rather than a real URL; confirm whether this backend reads it.
                'server_url': 'string',
                'return_content_list': 'false',
                'backend': 'hybrid-auto-engine',
                'table_enable': 'true',
                'response_format_zip': 'false',
                'formula_enable': 'true',
                'parse_method': 'ocr',
            }

            response = requests.post(
                self.api_url,
                files=files,
                data=data,
                headers={'accept': 'application/json'},
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()

            # Extract markdown content from response.
            markdown_content = ""
            if 'results' in result and 'image' in result['results']:
                markdown_content = result['results']['image'].get('md_content', '')

            # NOTE: unlike OCRService, _postprocess_markdown() is not applied
            # to this backend's output.

            # Convert to other formats if converter is available.
            latex = ""
            mathml = ""
            if self.converter and markdown_content:
                convert_result = self.converter.convert_to_formats(markdown_content)
                latex = convert_result.latex
                mathml = convert_result.mathml

            return {
                "markdown": markdown_content,
                "latex": latex,
                "mathml": mathml,
            }
        except requests.RequestException as e:
            raise RuntimeError(f"Local API request failed: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Recognition failed: {e}") from e
if __name__ == "__main__":
    # Manual smoke test. It needs the file_parse API running at the default
    # api_url and the sample image below, relative to the working directory.
    mineru_service = MineruOCRService()
    image = cv2.imread("test/complex_formula.png")
    # cv2.imread signals an unreadable/missing file by returning None.
    if image is None:
        raise SystemExit("Could not read test/complex_formula.png")
    # imread already yields a numpy array, so no np.array() copy is needed.
    ocr_result = mineru_service.recognize(image)
    print(ocr_result)