feat add glm-ocr core

This commit is contained in:
liuyuanchuang
2026-03-09 16:51:06 +08:00
parent d74130914c
commit 6dfaf9668b
17 changed files with 1687 additions and 140 deletions

View File

@@ -1,19 +1,23 @@
"""PaddleOCR-VL client service for text and formula recognition."""
import re
import numpy as np
import cv2
import requests
from io import BytesIO
import base64
from app.core.config import get_settings
from paddleocr import PaddleOCRVL
from typing import Optional
from app.services.layout_detector import LayoutDetector
from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
import re
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
import cv2
import numpy as np
import requests
from openai import OpenAI
from paddleocr import PaddleOCRVL
from PIL import Image as PILImage
from app.core.config import get_settings
from app.services.converter import Converter
from app.services.glm_postprocess import GLMResultFormatter
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
settings = get_settings()
@@ -144,7 +148,9 @@ def _clean_latex_syntax_spaces(expr: str) -> str:
# Strategy: remove spaces before \ and between non-command chars,
# but preserve the space after \command when followed by a non-\ char
cleaned = re.sub(r"\s+(?=\\)", "", content) # remove space before \cmd
cleaned = re.sub(r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned) # remove space after non-letter non-\
cleaned = re.sub(
r"(?<!\\)(?<![a-zA-Z])\s+", "", cleaned
) # remove space after non-letter non-\
return f"{operator}{{{cleaned}}}"
# Match _{ ... } or ^{ ... }
@@ -383,8 +389,8 @@ class OCRServiceBase(ABC):
class OCRService(OCRServiceBase):
"""Service for OCR using PaddleOCR-VL."""
_pipeline: Optional[PaddleOCRVL] = None
_layout_detector: Optional[LayoutDetector] = None
_pipeline: PaddleOCRVL | None = None
_layout_detector: LayoutDetector | None = None
def __init__(
self,
@@ -549,7 +555,15 @@ class GLMOCRService(OCRServiceBase):
# Call OpenAI-compatible API with formula recognition prompt
prompt = "Formula Recognition:"
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": prompt},
],
}
]
# Don't catch exceptions here - let them propagate for fallback handling
response = self.openai_client.chat.completions.create(
@@ -596,10 +610,10 @@ class MineruOCRService(OCRServiceBase):
def __init__(
self,
api_url: str = "http://127.0.0.1:8000/file_parse",
image_processor: Optional[ImageProcessor] = None,
converter: Optional[Converter] = None,
image_processor: ImageProcessor | None = None,
converter: Converter | None = None,
glm_ocr_url: str = "http://localhost:8002/v1",
layout_detector: Optional[LayoutDetector] = None,
layout_detector: LayoutDetector | None = None,
):
"""Initialize Local API service.
@@ -614,7 +628,9 @@ class MineruOCRService(OCRServiceBase):
self.glm_ocr_url = glm_ocr_url
self.openai_client = OpenAI(api_key="EMPTY", base_url=glm_ocr_url, timeout=3600)
def _recognize_formula_with_paddleocr_vl(self, image: np.ndarray, prompt: str = "Formula Recognition:") -> str:
def _recognize_formula_with_paddleocr_vl(
self, image: np.ndarray, prompt: str = "Formula Recognition:"
) -> str:
"""Recognize formula using PaddleOCR-VL API.
Args:
@@ -634,7 +650,15 @@ class MineruOCRService(OCRServiceBase):
image_url = f"data:image/png;base64,{image_base64}"
# Call OpenAI-compatible API
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": prompt},
],
}
]
response = self.openai_client.chat.completions.create(
model="glm-ocr",
@@ -647,7 +671,9 @@ class MineruOCRService(OCRServiceBase):
except Exception as e:
raise RuntimeError(f"PaddleOCR-VL formula recognition failed: {e}") from e
def _extract_and_recognize_formulas(self, markdown_content: str, original_image: np.ndarray) -> str:
def _extract_and_recognize_formulas(
self, markdown_content: str, original_image: np.ndarray
) -> str:
"""Extract image references from markdown and recognize formulas.
Args:
@@ -712,7 +738,13 @@ class MineruOCRService(OCRServiceBase):
}
# Make API request
response = requests.post(self.api_url, files=files, data=data, headers={"accept": "application/json"}, timeout=30)
response = requests.post(
self.api_url,
files=files,
data=data,
headers={"accept": "application/json"},
timeout=30,
)
response.raise_for_status()
result = response.json()
@@ -723,7 +755,9 @@ class MineruOCRService(OCRServiceBase):
markdown_content = result["results"]["image"].get("md_content", "")
if "![](images/" in markdown_content:
markdown_content = self._extract_and_recognize_formulas(markdown_content, original_image)
markdown_content = self._extract_and_recognize_formulas(
markdown_content, original_image
)
# Apply postprocessing to fix OCR errors
markdown_content = _postprocess_markdown(markdown_content)
@@ -751,6 +785,167 @@ class MineruOCRService(OCRServiceBase):
raise RuntimeError(f"Recognition failed: {e}") from e
# Task-specific prompts (from GLM-OCR SDK config.yaml)
_TASK_PROMPTS: dict[str, str] = {
"text": "Text Recognition:",
"formula": "Formula Recognition:",
"table": "Table Recognition:",
}
_DEFAULT_PROMPT = (
"Recognize the text in the image and output in Markdown format. "
"Preserve the original layout (headings/paragraphs/tables/formulas). "
"Do not fabricate content that does not exist in the image."
)
class GLMOCREndToEndService(OCRServiceBase):
    """End-to-end OCR using GLM-OCR pipeline: layout detection → per-region OCR.

    Pipeline:
        1. Add padding (ImageProcessor)
        2. Detect layout regions (LayoutDetector → PP-DocLayoutV3)
        3. Crop each region and call vLLM with a task-specific prompt (parallel)
        4. GLMResultFormatter: clean, format titles/bullets/formulas, merge tags
        5. _postprocess_markdown: LaTeX math error correction
        6. Converter: markdown → latex/mathml/mml

    This replaces both GLMOCRService (formula-only) and MineruOCRService (mixed).
    """

    def __init__(
        self,
        vl_server_url: str,
        image_processor: ImageProcessor,
        converter: Converter,
        layout_detector: LayoutDetector,
        max_workers: int = 8,
    ):
        """Store the collaborators and create the OpenAI-compatible client.

        Args:
            vl_server_url: Base URL of the OpenAI-compatible server. A falsy
                value falls back to ``settings.glm_ocr_url``.
            image_processor: Used to pad the input image before detection.
            converter: Converts the final markdown to latex/mathml/mml.
            layout_detector: Detects layout regions on the padded image.
            max_workers: Upper bound on concurrent per-region OCR requests.
        """
        self.vl_server_url = vl_server_url or settings.glm_ocr_url
        self.image_processor = image_processor
        self.converter = converter
        self.layout_detector = layout_detector
        self.max_workers = max_workers
        self.openai_client = OpenAI(api_key="EMPTY", base_url=self.vl_server_url, timeout=3600)
        self._formatter = GLMResultFormatter()

    def _encode_region(self, image: np.ndarray) -> str:
        """Convert BGR numpy array to base64 JPEG string."""
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_img = PILImage.fromarray(rgb)
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _call_vllm(self, image: np.ndarray, prompt: str) -> str:
        """Send image + prompt to vLLM and return the stripped content string.

        Args:
            image: Image (BGR) to send, encoded as a JPEG data URL.
            prompt: Instruction text sent alongside the image.

        Returns:
            The first choice's message content with surrounding whitespace
            removed, or an empty string when the server returns no content.
        """
        img_b64 = self._encode_region(image)
        data_url = f"data:image/jpeg;base64,{img_b64}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        response = self.openai_client.chat.completions.create(
            model="glm-ocr",
            messages=messages,
            temperature=0.01,
            max_tokens=settings.max_tokens,
        )
        # The content field of a chat completion message may be None; treat
        # that as "nothing recognised" instead of failing on None.strip().
        content = response.choices[0].message.content
        return (content or "").strip()

    def _normalize_bbox(self, bbox: list[float], img_w: int, img_h: int) -> list[int]:
        """Convert pixel bbox [x1,y1,x2,y2] to 0-1000 normalised coords."""
        x1, y1, x2, y2 = bbox
        return [
            int(x1 / img_w * 1000),
            int(y1 / img_h * 1000),
            int(x2 / img_w * 1000),
            int(y2 / img_h * 1000),
        ]

    def _recognize_full_image(self, image: np.ndarray) -> str:
        """OCR the whole image with the default prompt and clean the result."""
        raw_content = self._call_vllm(image, _DEFAULT_PROMPT)
        return self._formatter._clean_content(raw_content)

    def _build_region_tasks(self, padded: np.ndarray, regions) -> list[tuple]:
        """Crop every non-figure region and pair it with its prompt.

        Returns:
            A list of ``(index, region, cropped_image, prompt)`` tuples, where
            ``index`` is the region's position in ``regions``. Figure regions
            and regions whose crop is empty are left out.
        """
        img_h, img_w = padded.shape[:2]
        tasks = []
        for idx, region in enumerate(regions):
            if region.type == "figure":
                continue
            x1, y1, x2, y2 = (int(c) for c in region.bbox)
            # Clamp to the image: a negative coordinate would otherwise be
            # read by numpy as an offset from the far edge and silently crop
            # the wrong area.
            x1, x2 = (min(max(v, 0), img_w) for v in (x1, x2))
            y1, y2 = (min(max(v, 0), img_h) for v in (y1, y2))
            cropped = padded[y1:y2, x1:x2]
            if cropped.size == 0:
                continue
            prompt = _TASK_PROMPTS.get(region.type, _DEFAULT_PROMPT)
            tasks.append((idx, region, cropped, prompt))
        return tasks

    def _run_region_ocr(self, tasks: list[tuple]) -> dict[int, str]:
        """Run the per-region OCR calls in a thread pool.

        Returns:
            Mapping of region index to recognised content. A region whose
            request raised is mapped to an empty string (best effort), and the
            failure is logged.
        """
        import logging

        logger = logging.getLogger(__name__)
        raw_results: dict[int, str] = {}
        # ThreadPoolExecutor rejects max_workers <= 0, so never go below one.
        workers = max(1, min(self.max_workers, len(tasks)))
        with ThreadPoolExecutor(max_workers=workers) as ex:
            future_map = {
                ex.submit(self._call_vllm, cropped, prompt): idx
                for idx, _region, cropped, prompt in tasks
            }
            for future in as_completed(future_map):
                idx = future_map[future]
                try:
                    raw_results[idx] = future.result()
                except Exception:
                    # Keep going with the other regions, but leave a trace.
                    logger.warning(
                        "GLM-OCR request failed for region %d; using empty content",
                        idx,
                        exc_info=True,
                    )
                    raw_results[idx] = ""
        return raw_results

    def recognize(self, image: np.ndarray) -> dict:
        """Full pipeline: padding → layout → per-region OCR → postprocess → markdown.

        Args:
            image: Input image as numpy array in BGR format.

        Returns:
            Dict with 'markdown', 'latex', 'mathml', 'mml' keys.
        """
        # 1. Padding
        padded = self.image_processor.add_padding(image)
        img_h, img_w = padded.shape[:2]

        # 2. Layout detection
        layout_info = self.layout_detector.detect(padded)

        # 3. OCR: per-region (parallel) or full-image fallback. The fallback
        #    covers both "no regions detected" and "no usable regions left".
        tasks = self._build_region_tasks(padded, layout_info.regions or ())
        if not tasks:
            markdown_content = self._recognize_full_image(padded)
        else:
            raw_results = self._run_region_ocr(tasks)

            # Build structured region dicts for GLMResultFormatter
            region_dicts = [
                {
                    "index": idx,
                    "label": region.type,
                    "native_label": region.native_label,
                    "content": raw_results.get(idx, ""),
                    "bbox_2d": self._normalize_bbox(region.bbox, img_w, img_h),
                }
                for idx, region, _cropped, _prompt in tasks
            ]

            # 4. GLM-OCR postprocessing: clean, format, merge, bullets
            markdown_content = self._formatter.process(region_dicts)

        # 5. LaTeX math error correction (our existing pipeline)
        markdown_content = _postprocess_markdown(markdown_content)

        # 6. Format conversion
        latex, mathml, mml = "", "", ""
        if markdown_content and self.converter:
            fmt = self.converter.convert_to_formats(markdown_content)
            latex, mathml, mml = fmt.latex, fmt.mathml, fmt.mml

        return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml}
if __name__ == "__main__":
mineru_service = MineruOCRService()
image = cv2.imread("test/formula2.jpg")