feat: add mineru model
This commit is contained in:
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
from app.core.dependencies import get_image_processor, get_layout_detector, get_ocr_service
|
from app.core.dependencies import get_image_processor, get_layout_detector, get_ocr_service, get_mineru_ocr_service
|
||||||
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
|
from app.schemas.image import ImageOCRRequest, ImageOCRResponse
|
||||||
from app.services.image_processor import ImageProcessor
|
from app.services.image_processor import ImageProcessor
|
||||||
from app.services.layout_detector import LayoutDetector
|
from app.services.layout_detector import LayoutDetector
|
||||||
from app.services.ocr_service import OCRService
|
from app.services.ocr_service import OCRService, MineruOCRService
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -16,7 +16,8 @@ async def process_image_ocr(
|
|||||||
request: ImageOCRRequest,
|
request: ImageOCRRequest,
|
||||||
image_processor: ImageProcessor = Depends(get_image_processor),
|
image_processor: ImageProcessor = Depends(get_image_processor),
|
||||||
layout_detector: LayoutDetector = Depends(get_layout_detector),
|
layout_detector: LayoutDetector = Depends(get_layout_detector),
|
||||||
ocr_service: OCRService = Depends(get_ocr_service),
|
mineru_service: MineruOCRService = Depends(get_mineru_ocr_service),
|
||||||
|
paddle_service: OCRService = Depends(get_ocr_service),
|
||||||
) -> ImageOCRResponse:
|
) -> ImageOCRResponse:
|
||||||
"""Process an image and extract content as LaTeX, Markdown, and MathML.
|
"""Process an image and extract content as LaTeX, Markdown, and MathML.
|
||||||
|
|
||||||
@@ -35,7 +36,12 @@ async def process_image_ocr(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ocr_result = ocr_service.recognize(image)
|
if request.model_name == "mineru":
|
||||||
|
ocr_result = mineru_service.recognize(image)
|
||||||
|
elif request.model_name == "paddle":
|
||||||
|
ocr_result = paddle_service.recognize(image)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid model name")
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=503, detail=str(e))
|
raise HTTPException(status_code=503, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# PaddleOCR-VL Settings
|
# PaddleOCR-VL Settings
|
||||||
paddleocr_vl_url: str = "http://127.0.0.1:8000/v1"
|
paddleocr_vl_url: str = "http://127.0.0.1:8000/v1"
|
||||||
|
|
||||||
|
# MinerOCR Settings
|
||||||
|
miner_ocr_api_url: str = "http://127.0.0.1:8000/file_parse"
|
||||||
|
|
||||||
# Model Paths
|
# Model Paths
|
||||||
pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2"
|
pp_doclayout_model_dir: Optional[str] = "/home/yoge/.cache/modelscope/hub/models/PaddlePaddle/PP-DocLayoutV2"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from app.services.image_processor import ImageProcessor
|
from app.services.image_processor import ImageProcessor
|
||||||
from app.services.layout_detector import LayoutDetector
|
from app.services.layout_detector import LayoutDetector
|
||||||
from app.services.ocr_service import OCRService
|
from app.services.ocr_service import OCRService, MineruOCRService
|
||||||
from app.services.converter import Converter
|
from app.services.converter import Converter
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
@@ -45,3 +45,13 @@ def get_converter() -> Converter:
|
|||||||
"""Get a DOCX converter instance."""
|
"""Get a DOCX converter instance."""
|
||||||
return Converter()
|
return Converter()
|
||||||
|
|
||||||
|
|
||||||
|
def get_mineru_ocr_service() -> MineruOCRService:
    """Build a MineruOCRService wired to the configured file_parse endpoint.

    The endpoint is read from the ``miner_ocr_api_url`` attribute of the
    application settings; if the settings object has no such attribute the
    local default URL below is used instead. A converter obtained from
    ``get_converter()`` is attached to the service.

    Returns:
        A new MineruOCRService instance.
    """
    # Fallback used only when the settings object lacks the attribute.
    fallback_url = 'http://127.0.0.1:8000/file_parse'
    endpoint = getattr(get_settings(), 'miner_ocr_api_url', fallback_url)
    return MineruOCRService(api_url=endpoint, converter=get_converter())
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class ImageOCRRequest(BaseModel):
|
|||||||
|
|
||||||
image_url: str | None = Field(None, description="URL to fetch the image from")
|
image_url: str | None = Field(None, description="URL to fetch the image from")
|
||||||
image_base64: str | None = Field(None, description="Base64-encoded image data")
|
image_base64: str | None = Field(None, description="Base64-encoded image data")
|
||||||
|
model_name: str = Field("mineru", description="Name of the model to use for OCR")
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_input(self):
|
def validate_input(self):
|
||||||
|
|||||||
@@ -1,17 +1,26 @@
|
|||||||
"""PaddleOCR-VL client service for text and formula recognition."""
|
"""PaddleOCR-VL client service for text and formula recognition."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import requests
|
||||||
|
from io import BytesIO
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from paddleocr import PaddleOCRVL
|
from paddleocr import PaddleOCRVL
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from app.services.layout_detector import LayoutDetector
|
from app.services.layout_detector import LayoutDetector
|
||||||
from app.services.image_processor import ImageProcessor
|
from app.services.image_processor import ImageProcessor
|
||||||
from app.services.converter import Converter
|
from app.services.converter import Converter
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
|
class OCRServiceBase(ABC):
    """Abstract interface shared by the OCR service implementations.

    The class cannot be instantiated directly; subclasses must provide
    ``recognize``.
    """

    @abstractmethod
    def recognize(self, image: np.ndarray) -> dict:
        """Recognize the content of ``image`` and return it as a dict.

        Args:
            image: Input image as a numpy array.

        Returns:
            A dict holding the recognition result; its exact keys are
            defined by the concrete subclass.
        """
|
||||||
|
|
||||||
class OCRService:
|
|
||||||
|
class OCRService(OCRServiceBase):
|
||||||
"""Service for OCR using PaddleOCR-VL."""
|
"""Service for OCR using PaddleOCR-VL."""
|
||||||
|
|
||||||
_pipeline: Optional[PaddleOCRVL] = None
|
_pipeline: Optional[PaddleOCRVL] = None
|
||||||
@@ -50,7 +59,7 @@ class OCRService:
|
|||||||
)
|
)
|
||||||
return OCRService._pipeline
|
return OCRService._pipeline
|
||||||
|
|
||||||
def recognize_mixed(self, image: np.ndarray) -> dict:
|
def _recognize_mixed(self, image: np.ndarray) -> dict:
|
||||||
"""Recognize mixed content (text + formulas) using PP-DocLayoutV2.
|
"""Recognize mixed content (text + formulas) using PP-DocLayoutV2.
|
||||||
|
|
||||||
This mode uses PaddleOCR-VL with PP-DocLayoutV2 for document-aware
|
This mode uses PaddleOCR-VL with PP-DocLayoutV2 for document-aware
|
||||||
@@ -82,7 +91,7 @@ class OCRService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Mixed recognition failed: {e}") from e
|
raise RuntimeError(f"Mixed recognition failed: {e}") from e
|
||||||
|
|
||||||
def recognize_formula(self, image: np.ndarray) -> dict:
|
def _recognize_formula(self, image: np.ndarray) -> dict:
|
||||||
"""Recognize formula/math content using PaddleOCR-VL with prompt.
|
"""Recognize formula/math content using PaddleOCR-VL with prompt.
|
||||||
|
|
||||||
This mode uses PaddleOCR-VL directly with a formula recognition prompt.
|
This mode uses PaddleOCR-VL directly with a formula recognition prompt.
|
||||||
@@ -125,6 +134,109 @@ class OCRService:
|
|||||||
padded_image = self.image_processor.add_padding(image)
|
padded_image = self.image_processor.add_padding(image)
|
||||||
layout_info = self.layout_detector.detect(padded_image)
|
layout_info = self.layout_detector.detect(padded_image)
|
||||||
if layout_info.MixedRecognition:
|
if layout_info.MixedRecognition:
|
||||||
return self.recognize_mixed(image)
|
return self._recognize_mixed(image)
|
||||||
else:
|
else:
|
||||||
return self.recognize_formula(image)
|
return self._recognize_formula(image)
|
||||||
|
|
||||||
|
|
||||||
|
class MineruOCRService(OCRServiceBase):
    """Service for OCR using local file_parse API.

    The image is PNG-encoded, uploaded as multipart form data to ``api_url``
    and the ``md_content`` field of the JSON response is returned as
    markdown. When a converter is supplied, the markdown is additionally
    converted to LaTeX and MathML.
    """

    # File name used for the multipart upload, and the key under which the
    # markdown is looked up in the response.
    # NOTE(review): presumably the API keys its results by the uploaded file
    # name without extension, so these two must stay in sync - confirm
    # against the file_parse API.
    _UPLOAD_NAME = 'image.png'
    _RESULT_KEY = 'image'

    # Form fields sent with every request.
    # NOTE(review): 'server_url': 'string' looks like a placeholder copied
    # from the API docs; it is kept unchanged here - confirm whether the
    # selected backend actually reads it.
    _FORM_FIELDS = {
        'return_middle_json': 'false',
        'return_model_output': 'false',
        'return_md': 'true',
        'return_images': 'false',
        'end_page_id': '99999',
        'parse_method': 'auto',
        'start_page_id': '0',
        'lang_list': 'en',
        'server_url': 'string',
        'return_content_list': 'false',
        'backend': 'hybrid-auto-engine',
        'table_enable': 'true',
        'response_format_zip': 'false',
        'formula_enable': 'true',
    }

    def __init__(
        self,
        api_url: str = "http://127.0.0.1:8000/file_parse",
        converter: Optional[Converter] = None,
        timeout: float = 30,
    ):
        """Initialize Local API service.

        Args:
            api_url: URL of the local file_parse API endpoint.
            converter: Optional converter instance for format conversion.
            timeout: Timeout in seconds passed to the HTTP request.
        """
        self.api_url = api_url
        self.converter = converter
        self.timeout = timeout

    @staticmethod
    def _extract_markdown(result, key: str) -> str:
        """Return ``result['results'][key]['md_content']`` or "" if absent."""
        results = result.get('results') if isinstance(result, dict) else None
        entry = results.get(key) if isinstance(results, dict) else None
        if not isinstance(entry, dict):
            return ""
        # A null md_content is treated like a missing one.
        return entry.get('md_content') or ""

    def recognize(self, image: np.ndarray) -> dict:
        """Recognize content using local file_parse API.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml' keys. 'latex' and
            'mathml' are empty strings when no converter is configured or
            no markdown was returned.

        Raises:
            RuntimeError: If the image is missing/empty, cannot be encoded,
                the HTTP request fails, or any later step fails.
        """
        # Fail early with a clear message instead of an opaque encoder error
        # (e.g. when the caller passes the None result of a failed imread).
        if image is None or getattr(image, 'size', None) == 0:
            raise RuntimeError("Recognition failed: image is empty")

        try:
            # Convert numpy array to image bytes
            success, encoded_image = cv2.imencode('.png', image)
            if not success:
                raise RuntimeError("Failed to encode image")

            # Prepare multipart form data
            files = {
                'files': (
                    self._UPLOAD_NAME,
                    BytesIO(encoded_image.tobytes()),
                    'image/png',
                )
            }

            # Make API request
            response = requests.post(
                self.api_url,
                files=files,
                data=dict(self._FORM_FIELDS),
                headers={'accept': 'application/json'},
                timeout=self.timeout,
            )
            response.raise_for_status()

            # Extract markdown content from response
            markdown_content = self._extract_markdown(
                response.json(), self._RESULT_KEY
            )

            # Convert to other formats if converter is available
            latex = ""
            mathml = ""
            if self.converter and markdown_content:
                convert_result = self.converter.convert_to_formats(markdown_content)
                latex = convert_result.latex
                mathml = convert_result.mathml

            return {
                "markdown": markdown_content,
                "latex": latex,
                "mathml": mathml,
            }

        except requests.RequestException as e:
            raise RuntimeError(f"Local API request failed: {e}") from e
        except Exception as e:
            # Boundary catch: every failure is reported as RuntimeError.
            raise RuntimeError(f"Recognition failed: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: send a sample image to the default local endpoint
    # and print the recognition result.
    sample_path = "test/complex_formula.png"
    mineru_service = MineruOCRService()
    image = cv2.imread(sample_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if image is None:
        raise SystemExit(f"Could not read image: {sample_path}")
    ocr_result = mineru_service.recognize(image)
    print(ocr_result)
|
||||||
Reference in New Issue
Block a user