feat: add glm ocr
This commit is contained in:
@@ -87,11 +87,11 @@ class LayoutDetector:
|
||||
def _get_layout_detector(self):
|
||||
"""Get or create LayoutDetection instance."""
|
||||
if LayoutDetector._layout_detector is None:
|
||||
LayoutDetector._layout_detector = LayoutDetection(model_name="PP-DocLayoutV2")
|
||||
LayoutDetector._layout_detector = LayoutDetection(model_name="PP-DocLayoutV3")
|
||||
return LayoutDetector._layout_detector
|
||||
|
||||
def detect(self, image: np.ndarray) -> LayoutInfo:
|
||||
"""Detect layout of the image using PP-DocLayoutV2.
|
||||
"""Detect layout of the image using PP-DocLayoutV3.
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array.
|
||||
@@ -125,13 +125,14 @@ class LayoutDetector:
|
||||
# Normalize label to region type
|
||||
region_type = self.LABEL_TO_TYPE.get(label, "text")
|
||||
|
||||
regions.append(LayoutRegion(
|
||||
type=region_type,
|
||||
bbox=coordinate,
|
||||
confidence=score,
|
||||
score=score,
|
||||
))
|
||||
|
||||
regions.append(
|
||||
LayoutRegion(
|
||||
type=region_type,
|
||||
bbox=coordinate,
|
||||
confidence=score,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
mixed_recognition = any(region.type == "text" and region.score > 0.85 for region in regions)
|
||||
|
||||
@@ -144,14 +145,14 @@ if __name__ == "__main__":
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.converter import Converter
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Initialize dependencies
|
||||
layout_detector = LayoutDetector()
|
||||
image_processor = ImageProcessor(padding_ratio=settings.image_padding_ratio)
|
||||
converter = Converter()
|
||||
|
||||
|
||||
# Initialize OCR service
|
||||
ocr_service = OCRService(
|
||||
vl_server_url=settings.paddleocr_vl_url,
|
||||
@@ -159,20 +160,20 @@ if __name__ == "__main__":
|
||||
image_processor=image_processor,
|
||||
converter=converter,
|
||||
)
|
||||
|
||||
|
||||
# Load test image
|
||||
image_path = "test/complex_formula.png"
|
||||
image_path = "test/timeout.jpg"
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
|
||||
if image is None:
|
||||
print(f"Failed to load image: {image_path}")
|
||||
else:
|
||||
print(f"Image loaded: {image.shape}")
|
||||
|
||||
|
||||
# Run OCR recognition
|
||||
result = ocr_service.recognize(image)
|
||||
|
||||
|
||||
print("\n=== OCR Result ===")
|
||||
print(f"Markdown:\n{result['markdown']}")
|
||||
print(f"\nLaTeX:\n{result['latex']}")
|
||||
print(f"\nMathML:\n{result['mathml']}")
|
||||
print(f"\nMathML:\n{result['mathml']}")
|
||||
|
||||
@@ -481,6 +481,92 @@ class OCRService(OCRServiceBase):
|
||||
return self._recognize_formula(image)
|
||||
|
||||
|
||||
class GLMOCRService(OCRServiceBase):
|
||||
"""Service for OCR using GLM-4V model via vLLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vl_server_url: str,
|
||||
image_processor: ImageProcessor,
|
||||
converter: Converter,
|
||||
):
|
||||
"""Initialize GLM OCR service.
|
||||
|
||||
Args:
|
||||
vl_server_url: URL of the vLLM server for GLM-4V (default: http://127.0.0.1:8002/v1).
|
||||
image_processor: Image processor instance.
|
||||
converter: Converter instance for format conversion.
|
||||
"""
|
||||
self.vl_server_url = vl_server_url or settings.glm_ocr_url
|
||||
self.image_processor = image_processor
|
||||
self.converter = converter
|
||||
self.openai_client = OpenAI(api_key="EMPTY", base_url=self.vl_server_url, timeout=3600)
|
||||
|
||||
def _recognize_formula(self, image: np.ndarray) -> dict:
|
||||
"""Recognize formula/math content using GLM-4V.
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array in BGR format.
|
||||
|
||||
Returns:
|
||||
Dict with 'latex', 'markdown', 'mathml', 'mml' keys.
|
||||
"""
|
||||
try:
|
||||
# Add padding to image
|
||||
padded_image = self.image_processor.add_padding(image)
|
||||
|
||||
# Encode image to base64
|
||||
success, encoded_image = cv2.imencode(".png", padded_image)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to encode image")
|
||||
|
||||
image_base64 = base64.b64encode(encoded_image.tobytes()).decode("utf-8")
|
||||
image_url = f"data:image/png;base64,{image_base64}"
|
||||
|
||||
# Call OpenAI-compatible API with formula recognition prompt
|
||||
prompt = "Formula Recognition:"
|
||||
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": prompt}]}]
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="glm-ocr",
|
||||
messages=messages,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
markdown_content = response.choices[0].message.content
|
||||
|
||||
# Process LaTeX delimiters
|
||||
if markdown_content.startswith(r"\[") or markdown_content.startswith(r"\("):
|
||||
markdown_content = markdown_content.replace(r"\[", "$$").replace(r"\(", "$$")
|
||||
markdown_content = markdown_content.replace(r"\]", "$$").replace(r"\)", "$$")
|
||||
elif not markdown_content.startswith("$$") and not markdown_content.startswith("$"):
|
||||
markdown_content = f"$${markdown_content}$$"
|
||||
|
||||
# Apply postprocessing
|
||||
markdown_content = _postprocess_markdown(markdown_content)
|
||||
convert_result = self.converter.convert_to_formats(markdown_content)
|
||||
|
||||
return {
|
||||
"latex": convert_result.latex,
|
||||
"mathml": convert_result.mathml,
|
||||
"mml": convert_result.mml,
|
||||
"markdown": markdown_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"GLM formula recognition failed: {e}") from e
|
||||
|
||||
def recognize(self, image: np.ndarray) -> dict:
|
||||
"""Recognize content using GLM-4V.
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array in BGR format.
|
||||
|
||||
Returns:
|
||||
Dict with 'latex', 'markdown', 'mathml', 'mml' keys.
|
||||
"""
|
||||
return self._recognize_formula(image)
|
||||
|
||||
|
||||
class MineruOCRService(OCRServiceBase):
|
||||
"""Service for OCR using local file_parse API."""
|
||||
|
||||
@@ -490,6 +576,7 @@ class MineruOCRService(OCRServiceBase):
|
||||
image_processor: Optional[ImageProcessor] = None,
|
||||
converter: Optional[Converter] = None,
|
||||
paddleocr_vl_url: str = "http://localhost:8001/v1",
|
||||
layout_detector: Optional[LayoutDetector] = None,
|
||||
):
|
||||
"""Initialize Local API service.
|
||||
|
||||
@@ -573,7 +660,7 @@ class MineruOCRService(OCRServiceBase):
|
||||
Dict with 'markdown', 'latex', 'mathml' keys.
|
||||
"""
|
||||
try:
|
||||
if self.image_processor:
|
||||
if self.image_processor and get_settings().is_padding:
|
||||
image = self.image_processor.add_padding(image)
|
||||
|
||||
# Convert numpy array to image bytes
|
||||
@@ -647,7 +734,7 @@ class MineruOCRService(OCRServiceBase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
mineru_service = MineruOCRService()
|
||||
image = cv2.imread("test/complex_formula.png")
|
||||
image = cv2.imread("test/formula2.jpg")
|
||||
image_numpy = np.array(image)
|
||||
ocr_result = mineru_service.recognize(image_numpy)
|
||||
print(ocr_result)
|
||||
|
||||
Reference in New Issue
Block a user