feat add glm-ocr core
This commit is contained in:
@@ -1,19 +1,23 @@
|
||||
"""PaddleOCR-VL client service for text and formula recognition."""
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import requests
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from app.core.config import get_settings
|
||||
from paddleocr import PaddleOCRVL
|
||||
from typing import Optional
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.converter import Converter
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from paddleocr import PaddleOCRVL
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.converter import Converter
|
||||
from app.services.glm_postprocess import GLMResultFormatter
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -144,7 +148,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
|
||||
# Strategy: remove spaces before \ and between non-command chars,
|
||||
# but preserve the space after \command when followed by a non-\ char
|
||||
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
|
||||
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
|
||||
cleaned = re.sub(
|
||||
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
|
||||
) # remove space after non-letter non-\
|
||||
return f"{operator}{{{cleaned}}}"
|
||||
|
||||
# Match _{ ... } or ^{ ... }
|
||||
@@ -383,8 +389,8 @@ class OCRServiceBase(ABC):
|
||||
class OCRService(OCRServiceBase):
|
||||
"""Service for OCR using PaddleOCR-VL."""
|
||||
|
||||
_pipeline: Optional[PaddleOCRVL] = None
|
||||
_layout_detector: Optional[LayoutDetector] = None
|
||||
_pipeline: PaddleOCRVL | None = None
|
||||
_layout_detector: LayoutDetector | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -549,7 +555,15 @@ class GLMOCRService(OCRServiceBase):
|
||||
|
||||
# Call OpenAI-compatible API with formula recognition prompt
|
||||
prompt = "Formula Recognition:"
|
||||
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Don't catch exceptions here - let them propagate for fallback handling
|
||||
response = self.openai_client.chat.completions.create(
|
||||
@@ -596,10 +610,10 @@ class MineruOCRService(OCRServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
api_url: str = "http://127.0.0.1:8000/file_parse",
|
||||
image_processor: Optional[ImageProcessor] = None,
|
||||
converter: Optional[Converter] = None,
|
||||
image_processor: ImageProcessor | None = None,
|
||||
converter: Converter | None = None,
|
||||
glm_ocr_url: str = "http://localhost:8002/v1",
|
||||
layout_detector: Optional[LayoutDetector] = None,
|
||||
layout_detector: LayoutDetector | None = None,
|
||||
):
|
||||
"""Initialize Local API service.
|
||||
|
||||
@@ -614,7 +628,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
self.glm_ocr_url = glm_ocr_url
|
||||
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
|
||||
|
||||
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
|
||||
def _recognize_formula_with_paddleocr_vl(
|
||||
self, image: np.ndarray, prompt: str = "Formula Recognition:"
|
||||
) -> str:
|
||||
"""Recognize formula using PaddleOCR-VL API.
|
||||
|
||||
Args:
|
||||
@@ -634,7 +650,15 @@ class MineruOCRService(OCRServiceBase):
|
||||
image_url = f"data:image/png;base64,{image_base64}"
|
||||
|
||||
# Call OpenAI-compatible API
|
||||
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="glm-ocr",
|
||||
@@ -647,7 +671,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
|
||||
|
||||
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
|
||||
def _extract_and_recognize_formulas(
|
||||
self, markdown_content: str, original_image: np.ndarray
|
||||
) -> str:
|
||||
"""Extract image references from markdown and recognize formulas.
|
||||
|
||||
Args:
|
||||
@@ -712,7 +738,13 @@ class MineruOCRService(OCRServiceBase):
|
||||
}
|
||||
|
||||
# Make API request
|
||||
response = requests.post(self.api_url, files=files, data=data, headers={"accept": "application/json"}, timeout=30)
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
files=files,
|
||||
data=data,
|
||||
headers={"accept": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
@@ -723,7 +755,9 @@ class MineruOCRService(OCRServiceBase):
|
||||
markdown_content = result["results"]["image"].get("md_content", "")
|
||||
|
||||
if "
|
||||
markdown_content = self._extract_and_recognize_formulas(
|
||||
markdown_content, original_image
|
||||
)
|
||||
|
||||
# Apply postprocessing to fix OCR errors
|
||||
markdown_content = _postprocess_markdown(markdown_content)
|
||||
@@ -751,6 +785,167 @@ class MineruOCRService(OCRServiceBase):
|
||||
raise RuntimeError(f"Recognition failed: {e}") from e
|
||||
|
||||
|
||||
# Task-specific prompts (from GLM-OCR SDK config.yaml)
|
||||
_TASK_PROMPTS: dict[str, str] = {
|
||||
"text": "Text Recognition:",
|
||||
"formula": "Formula Recognition:",
|
||||
"table": "Table Recognition:",
|
||||
}
|
||||
_DEFAULT_PROMPT = (
|
||||
"Recognize the text in the image and output in Markdown format. "
|
||||
"Preserve the original layout (headings/paragraphs/tables/formulas). "
|
||||
"Do not fabricate content that does not exist in the image."
|
||||
)
|
||||
|
||||
|
||||
class GLMOCREndToEndService(OCRServiceBase):
    """End-to-end OCR built on the GLM-OCR pipeline.

    Steps:
        1. Pad the input image (ImageProcessor).
        2. Detect layout regions (LayoutDetector -> PP-DocLayoutV3).
        3. Crop each non-figure region and query vLLM with a task-specific
           prompt, fanned out over a thread pool.
        4. GLMResultFormatter: clean, format titles/bullets/formulas, merge tags.
        5. _postprocess_markdown: LaTeX math error correction.
        6. Converter: markdown -> latex/mathml/mml.

    This replaces both GLMOCRService (formula-only) and MineruOCRService (mixed).
    """

    def __init__(
        self,
        vl_server_url: str,
        image_processor: ImageProcessor,
        converter: Converter,
        layout_detector: LayoutDetector,
        max_workers: int = 8,
    ):
        # Fall back to the configured URL when an empty/None URL is passed.
        self.vl_server_url = vl_server_url or settings.glm_ocr_url
        self.image_processor = image_processor
        self.converter = converter
        self.layout_detector = layout_detector
        self.max_workers = max_workers
        self.openai_client = OpenAI(
            api_key="EMPTY", base_url=self.vl_server_url, timeout=3600
        )
        self._formatter = GLMResultFormatter()

    def _encode_region(self, image: np.ndarray) -> str:
        """Encode a BGR numpy image as a base64 JPEG string."""
        as_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        jpeg_buffer = BytesIO()
        PILImage.fromarray(as_rgb).save(jpeg_buffer, format="JPEG")
        return base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")

    def _call_vllm(self, image: np.ndarray, prompt: str) -> str:
        """Send one image + prompt to the vLLM server and return its raw reply."""
        data_url = f"data:image/jpeg;base64,{self._encode_region(image)}"
        payload = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        reply = self.openai_client.chat.completions.create(
            model="glm-ocr",
            messages=payload,
            temperature=0.01,
            max_tokens=settings.max_tokens,
        )
        return reply.choices[0].message.content.strip()

    def _normalize_bbox(self, bbox: list[float], img_w: int, img_h: int) -> list[int]:
        """Scale a pixel bbox [x1, y1, x2, y2] onto a 0-1000 grid."""
        x1, y1, x2, y2 = bbox
        return [
            int(coord / extent * 1000)
            for coord, extent in zip((x1, y1, x2, y2), (img_w, img_h, img_w, img_h))
        ]

    def _ocr_whole_page(self, padded: np.ndarray) -> str:
        """Fallback path: OCR the entire padded page with the generic prompt."""
        return self._formatter._clean_content(self._call_vllm(padded, _DEFAULT_PROMPT))

    def recognize(self, image: np.ndarray) -> dict:
        """Run the full pipeline: padding -> layout -> per-region OCR -> markdown.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
        """
        # 1. Padding
        padded = self.image_processor.add_padding(image)
        page_h, page_w = padded.shape[:2]

        # 2. Layout detection
        layout_info = self.layout_detector.detect(padded)

        # 3. Collect one OCR job per usable non-figure region.
        jobs = []
        for pos, region in enumerate(layout_info.regions):
            if region.type == "figure":
                continue
            left, top, right, bottom = (int(c) for c in region.bbox)
            crop = padded[top:bottom, left:right]
            if crop.size == 0:
                continue
            jobs.append((pos, region, crop, _TASK_PROMPTS.get(region.type, _DEFAULT_PROMPT)))

        if not jobs:
            # No regions at all, or only figures / degenerate crops: OCR the
            # whole page in one shot.
            markdown_content = self._ocr_whole_page(padded)
        else:
            # Fan region OCR calls out over a bounded thread pool.
            outputs: dict[int, str] = {}
            pool_size = min(self.max_workers, len(jobs))
            with ThreadPoolExecutor(max_workers=pool_size) as pool:
                pending = {
                    pool.submit(self._call_vllm, crop, prompt): pos
                    for pos, _region, crop, prompt in jobs
                }
                for done in as_completed(pending):
                    pos = pending[done]
                    try:
                        outputs[pos] = done.result()
                    except Exception:
                        # A failed region degrades to empty content instead of
                        # aborting the whole page.
                        outputs[pos] = ""

            # 4. Structured region dicts for GLMResultFormatter.
            region_dicts = [
                {
                    "index": pos,
                    "label": region.type,
                    "native_label": region.native_label,
                    "content": outputs.get(pos, ""),
                    "bbox_2d": self._normalize_bbox(region.bbox, page_w, page_h),
                }
                for pos, region, _crop, _prompt in jobs
            ]
            markdown_content = self._formatter.process(region_dicts)

        # 5. LaTeX math error correction (shared postprocessing pipeline).
        markdown_content = _postprocess_markdown(markdown_content)

        # 6. Format conversion
        latex = mathml = mml = ""
        if markdown_content and self.converter:
            converted = self.converter.convert_to_formats(markdown_content)
            latex, mathml, mml = converted.latex, converted.mathml, converted.mml

        return {
            "markdown": markdown_content,
            "latex": latex,
            "mathml": mathml,
            "mml": mml,
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mineru_service = MineruOCRService()
|
||||
image = cv2.imread("test/formula2.jpg")
|
||||
|
||||
Reference in New Issue
Block a user