From 6ea37c9380d5085f4b6cc922fca240f56ed998ef Mon Sep 17 00:00:00 2001 From: yogeliu Date: Mon, 5 Jan 2026 17:30:54 +0800 Subject: [PATCH] feat: add mineru model --- app/api/v1/endpoints/image.py | 14 ++-- app/core/config.py | 3 + app/core/dependencies.py | 12 +++- app/schemas/image.py | 1 + app/services/ocr_service.py | 122 ++++++++++++++++++++++++++++++++-- 5 files changed, 142 insertions(+), 10 deletions(-) diff --git a/app/api/v1/endpoints/image.py b/app/api/v1/endpoints/image.py index 2c2e141..e2e0c92 100644 --- a/app/api/v1/endpoints/image.py +++ b/app/api/v1/endpoints/image.py @@ -2,11 +2,11 @@ from fastapi import APIRouter, Depends, HTTPException -from app.core.dependencies import get_image_processor, get_layout_detector, get_ocr_service +from app.core.dependencies import get_image_processor, get_layout_detector, get_ocr_service, get_mineru_ocr_service from app.schemas.image import ImageOCRRequest, ImageOCRResponse from app.services.image_processor import ImageProcessor from app.services.layout_detector import LayoutDetector -from app.services.ocr_service import OCRService +from app.services.ocr_service import OCRService, MineruOCRService router = APIRouter() @@ -16,7 +16,8 @@ async def process_image_ocr( request: ImageOCRRequest, image_processor: ImageProcessor = Depends(get_image_processor), layout_detector: LayoutDetector = Depends(get_layout_detector), - ocr_service: OCRService = Depends(get_ocr_service), + mineru_service: MineruOCRService = Depends(get_mineru_ocr_service), + paddle_service: OCRService = Depends(get_ocr_service), ) -> ImageOCRResponse: """Process an image and extract content as LaTeX, Markdown, and MathML. @@ -35,7 +36,12 @@ async def process_image_ocr( ) try: - ocr_result = ocr_service.recognize(image) + if request.model_name == "mineru": + ocr_result = mineru_service.recognize(image) + elif request.model_name == "paddle": + ocr_result = paddle_service.recognize(image) + else: + raise HTTPException(status_code=400, detail="Invalid model name") except RuntimeError as e: raise HTTPException(status_code=503, detail=str(e)) diff --git a/app/core/config.py b/app/core/config.py index c3d81a7..6b33e14 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -23,6 +23,9 @@ class Settings(BaseSettings): # PaddleOCR-VL Settings paddleocr_vl_url: str = "http://127.0.0.1:8000/v1" + + # MinerOCR Settings + miner_ocr_api_url: str = "http://127.0.0.1:8000/file_parse" # Model Paths pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2" diff --git a/app/core/dependencies.py b/app/core/dependencies.py index ea19022..20d5a99 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -2,7 +2,7 @@ from app.services.image_processor import ImageProcessor from app.services.layout_detector import LayoutDetector -from app.services.ocr_service import OCRService +from app.services.ocr_service import OCRService, MineruOCRService from app.services.converter import Converter from app.core.config import get_settings @@ -45,3 +45,13 @@ def get_converter() -> Converter: """Get a DOCX converter instance.""" return Converter() + +def get_mineru_ocr_service() -> MineruOCRService: + """Get a MinerOCR service instance.""" + settings = get_settings() + api_url = getattr(settings, 'miner_ocr_api_url', 'http://127.0.0.1:8000/file_parse') + return MineruOCRService( + api_url=api_url, + converter=get_converter(), + ) + diff --git a/app/schemas/image.py b/app/schemas/image.py index 3378843..23be6d0 100644 --- a/app/schemas/image.py +++ b/app/schemas/image.py @@ -25,6 +25,7 @@ class ImageOCRRequest(BaseModel): image_url: str | None = Field(None, description="URL to fetch the image from") image_base64: str | None = Field(None, description="Base64-encoded image data") + model_name: str = Field("mineru", description="Name of the model to use for OCR") @model_validator(mode="after") def validate_input(self): diff --git a/app/services/ocr_service.py b/app/services/ocr_service.py index 741290b..ebfbf42 100644 --- a/app/services/ocr_service.py +++ b/app/services/ocr_service.py @@ -1,17 +1,26 @@ """PaddleOCR-VL client service for text and formula recognition.""" import numpy as np +import cv2 +import requests +from io import BytesIO from app.core.config import get_settings from paddleocr import PaddleOCRVL from typing import Optional from app.services.layout_detector import LayoutDetector from app.services.image_processor import ImageProcessor from app.services.converter import Converter +from abc import ABC, abstractmethod settings = get_settings() +class OCRServiceBase(ABC): + @abstractmethod + def recognize(self, image: np.ndarray) -> dict: + pass -class OCRService: + +class OCRService(OCRServiceBase): """Service for OCR using PaddleOCR-VL.""" _pipeline: Optional[PaddleOCRVL] = None @@ -50,7 +59,7 @@ class OCRService: ) return OCRService._pipeline - def recognize_mixed(self, image: np.ndarray) -> dict: + def _recognize_mixed(self, image: np.ndarray) -> dict: """Recognize mixed content (text + formulas) using PP-DocLayoutV2. This mode uses PaddleOCR-VL with PP-DocLayoutV2 for document-aware @@ -82,7 +91,7 @@ class OCRService: except Exception as e: raise RuntimeError(f"Mixed recognition failed: {e}") from e - def recognize_formula(self, image: np.ndarray) -> dict: + def _recognize_formula(self, image: np.ndarray) -> dict: """Recognize formula/math content using PaddleOCR-VL with prompt. This mode uses PaddleOCR-VL directly with a formula recognition prompt. @@ -125,6 +134,109 @@ class OCRService: padded_image = self.image_processor.add_padding(image) layout_info = self.layout_detector.detect(padded_image) if layout_info.MixedRecognition: - return self.recognize_mixed(image) + return self._recognize_mixed(image) else: - return self.recognize_formula(image) + return self._recognize_formula(image) + + +class MineruOCRService(OCRServiceBase): + """Service for OCR using local file_parse API.""" + + def __init__( + self, + api_url: str = "http://127.0.0.1:8000/file_parse", + converter: Optional[Converter] = None, + ): + """Initialize Local API service. + + Args: + api_url: URL of the local file_parse API endpoint. + converter: Optional converter instance for format conversion. + """ + self.api_url = api_url + self.converter = converter + + def recognize(self, image: np.ndarray) -> dict: + """Recognize content using local file_parse API. + + Args: + image: Input image as numpy array in BGR format. + + Returns: + Dict with 'markdown', 'latex', 'mathml' keys. + """ + try: + # Convert numpy array to image bytes + success, encoded_image = cv2.imencode('.png', image) + if not success: + raise RuntimeError("Failed to encode image") + + image_bytes = BytesIO(encoded_image.tobytes()) + + # Prepare multipart form data + files = { + 'files': ('image.png', image_bytes, 'image/png') + } + + data = { + 'return_middle_json': 'false', + 'return_model_output': 'false', + 'return_md': 'true', + 'return_images': 'false', + 'end_page_id': '99999', + 'parse_method': 'auto', + 'start_page_id': '0', + 'lang_list': 'en', + 'server_url': 'string', + 'return_content_list': 'false', + 'backend': 'hybrid-auto-engine', + 'table_enable': 'true', + 'response_format_zip': 'false', + 'formula_enable': 'true', + } + + # Make API request + response = requests.post( + self.api_url, + files=files, + data=data, + headers={'accept': 'application/json'}, + timeout=30 + ) + response.raise_for_status() + + result = response.json() + + # Extract markdown content from response + markdown_content = "" + if 'results' in result and 'image' in result['results']: + markdown_content = result['results']['image'].get('md_content', '') + + # Convert to other formats if converter is available + latex = "" + mathml = "" + if self.converter and markdown_content: + convert_result = self.converter.convert_to_formats(markdown_content) + latex = convert_result.latex + mathml = convert_result.mathml + + return { + "markdown": markdown_content, + "latex": latex, + "mathml": mathml, + } + + except requests.RequestException as e: + raise RuntimeError(f"Local API request failed: {e}") from e + except Exception as e: + raise RuntimeError(f"Recognition failed: {e}") from e + + + + +if __name__ == "__main__": + mineru_service = MineruOCRService() + image = cv2.imread("test/complex_formula.png") + image_numpy = np.array(image) + ocr_result = mineru_service.recognize(image_numpy) + print(ocr_result) \ No newline at end of file