Compare commits

...

3 Commits

Author SHA1 Message Date
df2b664af4 fix: add image padding for mineru 2026-01-05 21:37:51 +08:00
6ea37c9380 feat: add mineru model 2026-01-05 17:30:54 +08:00
3870c108b2 fix: image alpha error 2026-01-01 23:38:52 +08:00
8 changed files with 296 additions and 36 deletions

View File

@@ -2,11 +2,11 @@
from fastapi import APIRouter, Depends, HTTPException
from app.core.dependencies import get_image_processor, get_layout_detector, get_ocr_service
from app.core.dependencies import get_image_processor, get_layout_detector, get_ocr_service, get_mineru_ocr_service
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import OCRService
from app.services.ocr_service import OCRService, MineruOCRService
router = APIRouter()
@@ -16,7 +16,8 @@ async def process_image_ocr(
request: ImageOCRRequest,
image_processor: ImageProcessor = Depends(get_image_processor),
layout_detector: LayoutDetector = Depends(get_layout_detector),
ocr_service: OCRService = Depends(get_ocr_service),
mineru_service: MineruOCRService = Depends(get_mineru_ocr_service),
paddle_service: OCRService = Depends(get_ocr_service),
) -> ImageOCRResponse:
"""Process an image and extract content as LaTeX, Markdown, and MathML.
@@ -35,12 +36,15 @@ async def process_image_ocr(
)
try:
# 3. Perform OCR based on layout
ocr_result = ocr_service.recognize(image)
if request.model_name == "mineru":
ocr_result = mineru_service.recognize(image)
elif request.model_name == "paddle":
ocr_result = paddle_service.recognize(image)
else:
raise HTTPException(status_code=400, detail="Invalid model name")
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
# 4. Return response
return ImageOCRResponse(
latex=ocr_result.get("latex", ""),
markdown=ocr_result.get("markdown", ""),

View File

@@ -23,6 +23,9 @@ class Settings(BaseSettings):
# PaddleOCR-VL Settings
paddleocr_vl_url: str = "http://127.0.0.1:8000/v1"
# MinerOCR Settings
miner_ocr_api_url: str = "http://127.0.0.1:8000/file_parse"
# Model Paths
pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2"

View File

@@ -2,7 +2,7 @@
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
from app.services.ocr_service import OCRService
from app.services.ocr_service import OCRService, MineruOCRService
from app.services.converter import Converter
from app.core.config import get_settings
@@ -45,3 +45,14 @@ def get_converter() -> Converter:
"""Get a DOCX converter instance."""
return Converter()
def get_mineru_ocr_service() -> MineruOCRService:
"""Get a MinerOCR service instance."""
settings = get_settings()
api_url = getattr(settings, 'miner_ocr_api_url', 'http://127.0.0.1:8000/file_parse')
return MineruOCRService(
api_url=api_url,
converter=get_converter(),
image_processor=get_image_processor(),
)

View File

@@ -33,6 +33,7 @@ app = FastAPI(
app.include_router(api_router, prefix=settings.api_prefix)
@app.get("/health")
async def health_check():
"""Health check endpoint."""

View File

@@ -25,6 +25,7 @@ class ImageOCRRequest(BaseModel):
image_url: str | None = Field(None, description="URL to fetch the image from")
image_base64: str | None = Field(None, description="Base64-encoded image data")
model_name: str = Field("mineru", description="Name of the model to use for OCR")
@model_validator(mode="after")
def validate_input(self):

View File

@@ -25,6 +25,38 @@ class ImageProcessor:
"""
self.padding_ratio = padding_ratio or settings.image_padding_ratio
def _convert_to_bgr(self, pil_image: Image.Image) -> np.ndarray:
"""Convert PIL Image to BGR numpy array, handling alpha channel.
Args:
pil_image: PIL Image object.
Returns:
Image as numpy array in BGR format.
"""
# Handle RGBA images (PNG with transparency)
if pil_image.mode == "RGBA":
# Create white background and paste image on top
background = Image.new("RGB", pil_image.size, (255, 255, 255))
background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha as mask
pil_image = background
elif pil_image.mode == "LA":
# Grayscale with alpha
background = Image.new("L", pil_image.size, 255)
background.paste(pil_image, mask=pil_image.split()[1])
pil_image = background.convert("RGB")
elif pil_image.mode == "P":
# Palette mode, may have transparency
pil_image = pil_image.convert("RGBA")
background = Image.new("RGB", pil_image.size, (255, 255, 255))
background.paste(pil_image, mask=pil_image.split()[3])
pil_image = background
elif pil_image.mode != "RGB":
# Convert other modes to RGB
pil_image = pil_image.convert("RGB")
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
def load_image_from_url(self, url: str) -> np.ndarray:
"""Load image from URL.
@@ -40,8 +72,8 @@ class ImageProcessor:
try:
with urlopen(url, timeout=30) as response:
image_data = response.read()
image = Image.open(io.BytesIO(image_data))
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
pil_image = Image.open(io.BytesIO(image_data))
return self._convert_to_bgr(pil_image)
except Exception as e:
raise ValueError(f"Failed to load image from URL: {e}") from e
@@ -63,8 +95,8 @@ class ImageProcessor:
base64_str = base64_str.split(",", 1)[1]
image_data = base64.b64decode(base64_str)
image = Image.open(io.BytesIO(image_data))
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
pil_image = Image.open(io.BytesIO(image_data))
return self._convert_to_bgr(pil_image)
except Exception as e:
raise ValueError(f"Failed to decode base64 image: {e}") from e

View File

@@ -140,18 +140,39 @@ class LayoutDetector:
if __name__ == "__main__":
import cv2
from app.core.config import get_settings
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
from app.services.ocr_service import OCRService
settings = get_settings()
# Initialize dependencies
layout_detector = LayoutDetector()
image_path = "test/timeout.png"
image_processor = ImageProcessor(padding_ratio=settings.image_padding_ratio)
converter = Converter()
# Initialize OCR service
ocr_service = OCRService(
vl_server_url=settings.paddleocr_vl_url,
layout_detector=layout_detector,
image_processor=image_processor,
converter=converter,
)
# Load test image
image_path = "test/complex_formula.png"
image = cv2.imread(image_path)
image_processor = ImageProcessor(padding_ratio=0.15)
image = image_processor.add_padding(image)
# Save the padded image for debugging
cv2.imwrite("debug_padded_image.png", image)
layout_info = layout_detector.detect(image)
print(layout_info)
if image is None:
print(f"Failed to load image: {image_path}")
else:
print(f"Image loaded: {image.shape}")
# Run OCR recognition
result = ocr_service.recognize(image)
print("\n=== OCR Result ===")
print(f"Markdown:\n{result['markdown']}")
print(f"\nLaTeX:\n{result['latex']}")
print(f"\nMathML:\n{result['mathml']}")

View File

@@ -1,17 +1,103 @@
"""PaddleOCR-VL client service for text and formula recognition."""
import re
import numpy as np
import cv2
import requests
from io import BytesIO
from app.core.config import get_settings
from paddleocr import PaddleOCRVL
from typing import Optional
from app.services.layout_detector import LayoutDetector
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
from abc import ABC, abstractmethod
settings = get_settings()
_COMMANDS_NEED_SPACE = {
# operators / calculus
"cdot", "times", "div", "pm", "mp",
"int", "iint", "iiint", "oint", "sum", "prod", "lim",
# common functions
"sin", "cos", "tan", "cot", "sec", "csc",
"log", "ln", "exp",
# misc
"partial", "nabla",
}
class OCRService:
_MATH_SEGMENT_PATTERN = re.compile(r"\$\$.*?\$\$|\$.*?\$", re.DOTALL)
_COMMAND_TOKEN_PATTERN = re.compile(r"\\[a-zA-Z]+")
# stage2: differentials inside math segments
_DIFFERENTIAL_UPPER_PATTERN = re.compile(r"(?<!\\)d([A-Z])")
_DIFFERENTIAL_LOWER_PATTERN = re.compile(r"(?<!\\)d([a-z])")
def _split_glued_command_token(token: str) -> str:
"""Split OCR-glued LaTeX command token by whitelist longest-prefix.
Examples:
- \\cdotdS -> \\cdot dS
- \\intdx -> \\int dx
"""
if not token.startswith("\\"):
return token
body = token[1:]
if len(body) < 2:
return token
best = None
# longest prefix that is in whitelist
for i in range(1, len(body)):
prefix = body[:i]
if prefix in _COMMANDS_NEED_SPACE:
best = prefix
if not best:
return token
suffix = body[len(best):]
if not suffix:
return token
return f"\\{best} {suffix}"
def _postprocess_math(expr: str) -> str:
"""Postprocess a *math* expression (already inside $...$ or $$...$$)."""
# stage1: split glued command tokens (e.g. \cdotdS)
expr = _COMMAND_TOKEN_PATTERN.sub(lambda m: _split_glued_command_token(m.group(0)), expr)
# stage2: normalize differentials (keep conservative)
expr = _DIFFERENTIAL_UPPER_PATTERN.sub(r"\\mathrm{d} \1", expr)
expr = _DIFFERENTIAL_LOWER_PATTERN.sub(r"d \1", expr)
return expr
def _postprocess_markdown(markdown_content: str) -> str:
"""Apply LaTeX postprocessing only within $...$ / $$...$$ segments."""
if not markdown_content:
return markdown_content
def _fix_segment(m: re.Match) -> str:
seg = m.group(0)
if seg.startswith("$$") and seg.endswith("$$"):
return f"$${_postprocess_math(seg[2:-2])}$$"
if seg.startswith("$") and seg.endswith("$"):
return f"${_postprocess_math(seg[1:-1])}$"
return seg
return _MATH_SEGMENT_PATTERN.sub(_fix_segment, markdown_content)
class OCRServiceBase(ABC):
@abstractmethod
def recognize(self, image: np.ndarray) -> dict:
pass
class OCRService(OCRServiceBase):
"""Service for OCR using PaddleOCR-VL."""
_pipeline: Optional[PaddleOCRVL] = None
@@ -35,6 +121,7 @@ class OCRService:
self.layout_detector = layout_detector
self.image_processor = image_processor
self.converter = converter
def _get_pipeline(self):
"""Get or create PaddleOCR-VL pipeline.
@@ -49,7 +136,7 @@ class OCRService:
)
return OCRService._pipeline
def recognize_mixed(self, image: np.ndarray) -> dict:
def _recognize_mixed(self, image: np.ndarray) -> dict:
"""Recognize mixed content (text + formulas) using PP-DocLayoutV2.
This mode uses PaddleOCR-VL with PP-DocLayoutV2 for document-aware
@@ -71,6 +158,7 @@ class OCRService:
for res in output:
markdown_content += res.markdown.get("markdown_texts", "")
markdown_content = _postprocess_markdown(markdown_content)
convert_result = self.converter.convert_to_formats(markdown_content)
return {
@@ -81,7 +169,7 @@ class OCRService:
except Exception as e:
raise RuntimeError(f"Mixed recognition failed: {e}") from e
def recognize_formula(self, image: np.ndarray) -> dict:
def _recognize_formula(self, image: np.ndarray) -> dict:
"""Recognize formula/math content using PaddleOCR-VL with prompt.
This mode uses PaddleOCR-VL directly with a formula recognition prompt.
@@ -102,6 +190,7 @@ class OCRService:
for res in output:
markdown_content += res.markdown.get("markdown_texts", "")
markdown_content = _postprocess_markdown(markdown_content)
convert_result = self.converter.convert_to_formats(markdown_content)
return {
@@ -124,18 +213,116 @@ class OCRService:
padded_image = self.image_processor.add_padding(image)
layout_info = self.layout_detector.detect(padded_image)
if layout_info.MixedRecognition:
return self.recognize_mixed(image)
return self._recognize_mixed(image)
else:
return self.recognize_formula(image)
return self._recognize_formula(image)
class MineruOCRService(OCRServiceBase):
"""Service for OCR using local file_parse API."""
def __init__(
self,
api_url: str = "http://127.0.0.1:8000/file_parse",
image_processor: Optional[ImageProcessor] = None,
converter: Optional[Converter] = None,
):
"""Initialize Local API service.
Args:
api_url: URL of the local file_parse API endpoint.
converter: Optional converter instance for format conversion.
"""
self.api_url = api_url
self.image_processor = image_processor
self.converter = converter
def recognize(self, image: np.ndarray) -> dict:
"""Recognize content using local file_parse API.
Args:
image: Input image as numpy array in BGR format.
Returns:
Dict with 'markdown', 'latex', 'mathml' keys.
"""
try:
if self.image_processor:
image = self.image_processor.add_padding(image)
# Convert numpy array to image bytes
success, encoded_image = cv2.imencode('.png', image)
if not success:
raise RuntimeError("Failed to encode image")
image_bytes = BytesIO(encoded_image.tobytes())
# Prepare multipart form data
files = {
'files': ('image.png', image_bytes, 'image/png')
}
data = {
'return_middle_json': 'false',
'return_model_output': 'false',
'return_md': 'true',
'return_images': 'false',
'end_page_id': '99999',
'start_page_id': '0',
'lang_list': 'en',
'server_url': 'string',
'return_content_list': 'false',
'backend': 'hybrid-auto-engine',
'table_enable': 'true',
'response_format_zip': 'false',
'formula_enable': 'true',
'parse_method': 'ocr'
}
# Make API request
response = requests.post(
self.api_url,
files=files,
data=data,
headers={'accept': 'application/json'},
timeout=30
)
response.raise_for_status()
result = response.json()
# Extract markdown content from response
markdown_content = ""
if 'results' in result and 'image' in result['results']:
markdown_content = result['results']['image'].get('md_content', '')
# markdown_content = _postprocess_markdown(markdown_content)
# Convert to other formats if converter is available
latex = ""
mathml = ""
if self.converter and markdown_content:
convert_result = self.converter.convert_to_formats(markdown_content)
latex = convert_result.latex
mathml = convert_result.mathml
return {
"markdown": markdown_content,
"latex": latex,
"mathml": mathml,
}
except requests.RequestException as e:
raise RuntimeError(f"Local API request failed: {e}") from e
except Exception as e:
raise RuntimeError(f"Recognition failed: {e}") from e
if __name__ == "__main__":
import cv2
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
image_processor = ImageProcessor(padding_ratio=0.15)
layout_detector = LayoutDetector()
ocr_service = OCRService(image_processor=image_processor, layout_detector=layout_detector)
image = cv2.imread("test/image.png")
ocr_result = ocr_service.recognize(image)
mineru_service = MineruOCRService()
image = cv2.imread("test/complex_formula.png")
image_numpy = np.array(image)
ocr_result = mineru_service.recognize(image_numpy)
print(ocr_result)