diff --git a/app/api/v1/endpoints/convert.py b/app/api/v1/endpoints/convert.py index e3575ad..90a572d 100644 --- a/app/api/v1/endpoints/convert.py +++ b/app/api/v1/endpoints/convert.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import Response from app.core.dependencies import get_converter -from app.schemas.convert import MarkdownToDocxRequest, LatexToOmmlRequest, LatexToOmmlResponse +from app.schemas.convert import LatexToOmmlRequest, LatexToOmmlResponse, MarkdownToDocxRequest from app.services.converter import Converter router = APIRouter() diff --git a/app/api/v1/endpoints/image.py b/app/api/v1/endpoints/image.py index 1c55fd6..b526098 100644 --- a/app/api/v1/endpoints/image.py +++ b/app/api/v1/endpoints/image.py @@ -6,10 +6,10 @@ import uuid from fastapi import APIRouter, Depends, HTTPException, Request, Response from app.core.dependencies import ( - get_image_processor, get_glmocr_endtoend_service, + get_image_processor, ) -from app.core.logging_config import get_logger, RequestIDAdapter +from app.core.logging_config import RequestIDAdapter, get_logger from app.schemas.image import ImageOCRRequest, ImageOCRResponse from app.services.image_processor import ImageProcessor from app.services.ocr_service import GLMOCREndToEndService diff --git a/app/core/dependencies.py b/app/core/dependencies.py index 50ddbdb..5fe6935 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -1,10 +1,10 @@ """Application dependencies.""" +from app.core.config import get_settings +from app.services.converter import Converter from app.services.image_processor import ImageProcessor from app.services.layout_detector import LayoutDetector from app.services.ocr_service import GLMOCREndToEndService -from app.services.converter import Converter -from app.core.config import get_settings # Global instances (initialized on startup) _layout_detector: LayoutDetector | None = None diff --git a/app/core/logging_config.py b/app/core/logging_config.py index 1914801..36069dd 100644 --- a/app/core/logging_config.py +++ b/app/core/logging_config.py @@ -3,7 +3,7 @@ import logging import logging.handlers from pathlib import Path -from typing import Any, Optional +from typing import Any from app.core.config import get_settings @@ -18,10 +18,10 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler) interval: int = 1, backupCount: int = 30, maxBytes: int = 100 * 1024 * 1024, # 100MB - encoding: Optional[str] = None, + encoding: str | None = None, delay: bool = False, utc: bool = False, - atTime: Optional[Any] = None, + atTime: Any | None = None, ): """Initialize handler with both time and size rotation. @@ -58,14 +58,14 @@ class TimedRotatingAndSizeFileHandler(logging.handlers.TimedRotatingFileHandler) if self.stream is None: self.stream = self._open() if self.maxBytes > 0: - msg = "%s\n" % self.format(record) + msg = f"{self.format(record)}\n" self.stream.seek(0, 2) # Seek to end if self.stream.tell() + len(msg) >= self.maxBytes: return True return False -def setup_logging(log_dir: Optional[str] = None) -> logging.Logger: +def setup_logging(log_dir: str | None = None) -> logging.Logger: """Setup application logging with rotation by day and size. Args: @@ -134,7 +134,7 @@ def setup_logging(log_dir: Optional[str] = None) -> logging.Logger: # Global logger instance -_logger: Optional[logging.Logger] = None +_logger: logging.Logger | None = None def get_logger() -> logging.Logger: diff --git a/app/schemas/convert.py b/app/schemas/convert.py index 068ceaa..a122b37 100644 --- a/app/schemas/convert.py +++ b/app/schemas/convert.py @@ -36,4 +36,3 @@ class LatexToOmmlResponse(BaseModel): """Response body for LaTeX to OMML conversion endpoint.""" omml: str = Field("", description="OMML (Office Math Markup Language) representation") - diff --git a/app/schemas/image.py b/app/schemas/image.py index f0d8f37..d1398cc 100644 --- a/app/schemas/image.py +++ b/app/schemas/image.py @@ -7,7 +7,9 @@ class LayoutRegion(BaseModel): """A detected layout region in the document.""" type: str = Field(..., description="Region type: text, formula, table, figure") - native_label: str = Field("", description="Raw label before type mapping (e.g. doc_title, formula_number)") + native_label: str = Field( + "", description="Raw label before type mapping (e.g. doc_title, formula_number)" + ) bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]") confidence: float = Field(..., description="Detection confidence score") score: float = Field(..., description="Detection score") @@ -41,10 +43,15 @@ class ImageOCRRequest(BaseModel): class ImageOCRResponse(BaseModel): """Response body for image OCR endpoint.""" - latex: str = Field("", description="LaTeX representation of the content (empty if mixed content)") + latex: str = Field( + "", description="LaTeX representation of the content (empty if mixed content)" + ) markdown: str = Field("", description="Markdown representation of the content") mathml: str = Field("", description="Standard MathML representation (empty if mixed content)") - mml: str = Field("", description="XML MathML with mml: namespace prefix (empty if mixed content)") + mml: str = Field( + "", description="XML MathML with mml: namespace prefix (empty if mixed content)" + ) layout_info: LayoutInfo = Field(default_factory=LayoutInfo) - recognition_mode: str = Field("", description="Recognition mode used: mixed_recognition or formula_recognition") - + recognition_mode: str = Field( + "", description="Recognition mode used: mixed_recognition or formula_recognition" + ) diff --git a/app/services/converter.py b/app/services/converter.py index 792fac4..c175a9f 100644 --- a/app/services/converter.py +++ b/app/services/converter.py @@ -112,14 +112,18 @@ class Converter: # Pre-compiled regex patterns for preprocessing _RE_VSPACE = re.compile(r"\\\[1mm\]") _RE_BLOCK_FORMULA_INLINE = re.compile(r"([^\n])(\s*)\\\[(.*?)\\\]([^\n])", re.DOTALL) - _RE_BLOCK_FORMULA_LINE = re.compile(r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL) + _RE_BLOCK_FORMULA_LINE = re.compile( + r"^(\s*)\\\[(.*?)\\\](\s*)(?=\n|$)", re.MULTILINE | re.DOTALL + ) _RE_ARITHMATEX = re.compile(r'(.*?)') _RE_INLINE_SPACE = re.compile(r"(? str: @@ -583,7 +589,6 @@ class Converter: "⇓": "⇓", # Downarrow "↕": "↕", # updownarrow "⇕": "⇕", # Updownarrow - "≠": "≠", # ne "≪": "≪", # ll "≫": "≫", # gg "⩽": "⩽", # leqslant @@ -962,7 +967,7 @@ class Converter: """Export to DOCX format using pypandoc.""" extra_args = [ "--highlight-style=pygments", - f"--reference-doc=app/pkg/reference.docx", + "--reference-doc=app/pkg/reference.docx", ] pypandoc.convert_file( input_path, diff --git a/app/services/glm_postprocess.py b/app/services/glm_postprocess.py index 36256fd..a893d04 100644 --- a/app/services/glm_postprocess.py +++ b/app/services/glm_postprocess.py @@ -1,26 +1,10 @@ -"""GLM-OCR postprocessing logic adapted for this project. - -Ported from glm-ocr/glmocr/postprocess/result_formatter.py and -glm-ocr/glmocr/utils/result_postprocess_utils.py. - -Covers: -- Repeated-content / hallucination detection -- Per-region content cleaning and formatting (titles, bullets, formulas) -- formula_number merging (→ \\tag{}) -- Hyphenated text-block merging (via wordfreq) -- Missing bullet-point detection -""" - from __future__ import annotations import logging import re -import json - -logger = logging.getLogger(__name__) from collections import Counter from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple +from typing import Any try: from wordfreq import zipf_frequency @@ -29,13 +13,14 @@ try: except ImportError: _WORDFREQ_AVAILABLE = False +logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # result_postprocess_utils (ported) # --------------------------------------------------------------------------- -def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]: +def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None: """Detect and truncate a consecutively-repeated pattern. Returns the string with the repeat removed, or None if not found. @@ -49,7 +34,13 @@ def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 1 return None pattern = re.compile( - r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}", + r"(.{" + + str(min_unit_len) + + "," + + str(max_unit_len) + + r"}?)\1{" + + str(min_repeats - 1) + + ",}", re.DOTALL, ) match = pattern.search(s) @@ -83,7 +74,9 @@ def clean_repeated_content( if count >= line_threshold and (count / total_lines) >= 0.8: for i, line in enumerate(lines): if line == common: - consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common) + consecutive = sum( + 1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common + ) if consecutive >= 3: original_lines = content.split("\n") non_empty_count = 0 @@ -106,7 +99,7 @@ def clean_formula_number(number_content: str) -> str: # Strip display math delimiters for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]: if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end): - s = s[len(start):-len(end)].strip() + s = s[len(start) : -len(end)].strip() break # Strip CJK/ASCII parentheses if s.startswith("(") and s.endswith(")"): @@ -121,7 +114,7 @@ def clean_formula_number(number_content: str) -> str: # --------------------------------------------------------------------------- # Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping) -_LABEL_TO_CATEGORY: Dict[str, str] = { +_LABEL_TO_CATEGORY: dict[str, str] = { # text "abstract": "text", "algorithm": "text", @@ -157,7 +150,7 @@ class GLMResultFormatter: # Public entry-point # ------------------------------------------------------------------ # - def process(self, regions: List[Dict[str, Any]]) -> str: + def process(self, regions: list[dict[str, Any]]) -> str: """Run the full postprocessing pipeline and return Markdown. Args: @@ -175,7 +168,7 @@ class GLMResultFormatter: items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0)) # Per-region cleaning + formatting - processed: List[Dict] = [] + processed: list[dict] = [] for item in items: item["native_label"] = item.get("native_label", item.get("label", "text")) item["label"] = self._map_label(item.get("label", "text"), item["native_label"]) @@ -199,7 +192,7 @@ class GLMResultFormatter: processed = self._format_bullet_points(processed) # Assemble Markdown - parts: List[str] = [] + parts: list[str] = [] for item in processed: content = item.get("content") or "" if item["label"] == "image": @@ -263,11 +256,15 @@ class GLMResultFormatter: if label == "formula": content = content.strip() for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]: - if content.startswith(s) and content.endswith(e): - content = content[len(s) : -len(e)].strip() + if content.startswith(s): + content = content[len(s) :].strip() + if content.endswith(e): + content = content[: -len(e)].strip() break if not content: - logger.warning("Skipping formula region with empty content after stripping delimiters") + logger.warning( + "Skipping formula region with empty content after stripping delimiters" + ) return "" content = "$$\n" + content + "\n$$" @@ -296,12 +293,12 @@ class GLMResultFormatter: # Structural merges # ------------------------------------------------------------------ # - def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]: + def _merge_formula_numbers(self, items: list[dict]) -> list[dict]: """Merge formula_number region into adjacent formula with \\tag{}.""" if not items: return items - merged: List[Dict] = [] + merged: list[dict] = [] skip: set = set() for i, block in enumerate(items): @@ -317,7 +314,9 @@ class GLMResultFormatter: formula_content = items[i + 1].get("content", "") merged_block = deepcopy(items[i + 1]) if formula_content.endswith("\n$$"): - merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" + merged_block["content"] = ( + formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" + ) merged.append(merged_block) skip.add(i + 1) continue # always skip the formula_number block itself @@ -329,7 +328,9 @@ class GLMResultFormatter: formula_content = block.get("content", "") merged_block = deepcopy(block) if formula_content.endswith("\n$$"): - merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" + merged_block["content"] = ( + formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" + ) merged.append(merged_block) skip.add(i + 1) continue @@ -340,12 +341,12 @@ class GLMResultFormatter: block["index"] = i return merged - def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]: + def _merge_text_blocks(self, items: list[dict]) -> list[dict]: """Merge hyphenated text blocks when the combined word is valid (wordfreq).""" if not items or not _WORDFREQ_AVAILABLE: return items - merged: List[Dict] = [] + merged: list[dict] = [] skip: set = set() for i, block in enumerate(items): @@ -389,7 +390,9 @@ class GLMResultFormatter: block["index"] = i return merged - def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]: + def _format_bullet_points( + self, items: list[dict], left_align_threshold: float = 10.0 + ) -> list[dict]: """Add missing bullet prefix when a text block is sandwiched between two bullet items.""" if len(items) < 3: return items diff --git a/app/services/layout_detector.py b/app/services/layout_detector.py index 36eb1b9..10a366a 100644 --- a/app/services/layout_detector.py +++ b/app/services/layout_detector.py @@ -1,12 +1,11 @@ """PP-DocLayoutV3 wrapper for document layout detection.""" import numpy as np - -from app.schemas.image import LayoutInfo, LayoutRegion -from app.core.config import get_settings -from app.services.layout_postprocess import apply_layout_postprocess from paddleocr import LayoutDetection -from typing import Optional + +from app.core.config import get_settings +from app.schemas.image import LayoutInfo, LayoutRegion +from app.services.layout_postprocess import apply_layout_postprocess settings = get_settings() @@ -14,7 +13,7 @@ settings = get_settings() class LayoutDetector: """Layout detector for PP-DocLayoutV2.""" - _layout_detector: Optional[LayoutDetection] = None + _layout_detector: LayoutDetection | None = None # PP-DocLayoutV2 class ID to label mapping CLS_ID_TO_LABEL: dict[int, str] = { @@ -156,10 +155,11 @@ class LayoutDetector: if __name__ == "__main__": import cv2 + from app.core.config import get_settings - from app.services.image_processor import ImageProcessor from app.services.converter import Converter - from app.services.ocr_service import OCRService + from app.services.image_processor import ImageProcessor + from app.services.ocr_service import GLMOCREndToEndService settings = get_settings() @@ -169,15 +169,15 @@ if __name__ == "__main__": converter = Converter() # Initialize OCR service - ocr_service = OCRService( - vl_server_url=settings.paddleocr_vl_url, + ocr_service = GLMOCREndToEndService( + vl_server_url=settings.glm_ocr_url, layout_detector=layout_detector, image_processor=image_processor, converter=converter, ) # Load test image - image_path = "test/timeout.jpg" + image_path = "test/image2.png" image = cv2.imread(image_path) if image is None: diff --git a/app/services/layout_postprocess.py b/app/services/layout_postprocess.py index 2eb9a7a..6088f60 100644 --- a/app/services/layout_postprocess.py +++ b/app/services/layout_postprocess.py @@ -15,16 +15,14 @@ the quality of the GLM-OCR SDK's layout pipeline. from __future__ import annotations -from typing import Dict, List, Optional, Tuple, Union - import numpy as np - # --------------------------------------------------------------------------- # Primitive geometry helpers # --------------------------------------------------------------------------- -def iou(box1: List[float], box2: List[float]) -> float: + +def iou(box1: list[float], box2: list[float]) -> float: """Compute IoU of two bounding boxes [x1, y1, x2, y2].""" x1, y1, x2, y2 = box1 x1_p, y1_p, x2_p, y2_p = box2 @@ -41,7 +39,7 @@ def iou(box1: List[float], box2: List[float]) -> float: return inter_area / float(box1_area + box2_area - inter_area) -def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool: +def is_contained(box1: list[float], box2: list[float], overlap_threshold: float = 0.8) -> bool: """Return True if box1 is contained within box2 (overlap ratio >= threshold). box format: [cls_id, score, x1, y1, x2, y2] @@ -66,11 +64,12 @@ def is_contained(box1: List[float], box2: List[float], overlap_threshold: float # NMS # --------------------------------------------------------------------------- + def nms( boxes: np.ndarray, iou_same: float = 0.6, iou_diff: float = 0.98, -) -> List[int]: +) -> list[int]: """NMS with separate IoU thresholds for same-class and cross-class overlaps. Args: @@ -83,7 +82,7 @@ def nms( """ scores = boxes[:, 1] indices = np.argsort(scores)[::-1].tolist() - selected: List[int] = [] + selected: list[int] = [] while indices: current = indices[0] @@ -114,10 +113,10 @@ _PRESERVE_LABELS = {"image", "seal", "chart"} def check_containment( boxes: np.ndarray, - preserve_cls_ids: Optional[set] = None, - category_index: Optional[int] = None, - mode: Optional[str] = None, -) -> Tuple[np.ndarray, np.ndarray]: + preserve_cls_ids: set | None = None, + category_index: int | None = None, + mode: str | None = None, +) -> tuple[np.ndarray, np.ndarray]: """Compute containment flags for each box. Args: @@ -160,9 +159,10 @@ def check_containment( # Box expansion (unclip) # --------------------------------------------------------------------------- + def unclip_boxes( boxes: np.ndarray, - unclip_ratio: Union[float, Tuple[float, float], Dict, List, None], + unclip_ratio: float | tuple[float, float] | dict | list | None, ) -> np.ndarray: """Expand bounding boxes by the given ratio. @@ -215,13 +215,14 @@ def unclip_boxes( # Main entry-point # --------------------------------------------------------------------------- + def apply_layout_postprocess( - boxes: List[Dict], - img_size: Tuple[int, int], + boxes: list[dict], + img_size: tuple[int, int], layout_nms: bool = True, - layout_unclip_ratio: Union[float, Tuple, Dict, None] = None, - layout_merge_bboxes_mode: Union[str, Dict, None] = "large", -) -> List[Dict]: + layout_unclip_ratio: float | tuple | dict | None = None, + layout_merge_bboxes_mode: str | dict | None = "large", +) -> list[dict]: """Apply GLM-OCR layout post-processing to PaddleOCR detection results. Args: @@ -250,7 +251,7 @@ def apply_layout_postprocess( arr_rows.append([cls_id, score, x1, y1, x2, y2]) boxes_array = np.array(arr_rows, dtype=float) - all_labels: List[str] = [b.get("label", "") for b in boxes] + all_labels: list[str] = [b.get("label", "") for b in boxes] # 1. NMS ---------------------------------------------------------------- # if layout_nms and len(boxes_array) > 1: @@ -262,17 +263,14 @@ def apply_layout_postprocess( if len(boxes_array) > 1: img_area = img_width * img_height area_thres = 0.82 if img_width > img_height else 0.93 - image_cls_ids = { - int(boxes_array[i, 0]) - for i, lbl in enumerate(all_labels) - if lbl == "image" - } keep_mask = np.ones(len(boxes_array), dtype=bool) for i, lbl in enumerate(all_labels): if lbl == "image": x1, y1, x2, y2 = boxes_array[i, 2:6] - x1 = max(0.0, x1); y1 = max(0.0, y1) - x2 = min(float(img_width), x2); y2 = min(float(img_height), y2) + x1 = max(0.0, x1) + y1 = max(0.0, y1) + x2 = min(float(img_width), x2) + y2 = min(float(img_height), y2) if (x2 - x1) * (y2 - y1) > area_thres * img_area: keep_mask[i] = False boxes_array = boxes_array[keep_mask] @@ -281,9 +279,7 @@ def apply_layout_postprocess( # 3. Containment analysis (merge_bboxes_mode) -------------------------- # if layout_merge_bboxes_mode and len(boxes_array) > 1: preserve_cls_ids = { - int(boxes_array[i, 0]) - for i, lbl in enumerate(all_labels) - if lbl in _PRESERVE_LABELS + int(boxes_array[i, 0]) for i, lbl in enumerate(all_labels) if lbl in _PRESERVE_LABELS } if isinstance(layout_merge_bboxes_mode, str): @@ -321,7 +317,7 @@ def apply_layout_postprocess( boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio) # 5. Clamp to image boundaries + skip invalid -------------------------- # - result: List[Dict] = [] + result: list[dict] = [] for i, row in enumerate(boxes_array): cls_id = int(row[0]) score = float(row[1]) @@ -333,11 +329,13 @@ def apply_layout_postprocess( if x1 >= x2 or y1 >= y2: continue - result.append({ - "cls_id": cls_id, - "label": all_labels[i], - "score": score, - "coordinate": [int(x1), int(y1), int(x2), int(y2)], - }) + result.append( + { + "cls_id": cls_id, + "label": all_labels[i], + "score": score, + "coordinate": [int(x1), int(y1), int(x2), int(y2)], + } + ) return result diff --git a/app/services/ocr_service.py b/app/services/ocr_service.py index 3ccfb53..4a69a03 100644 --- a/app/services/ocr_service.py +++ b/app/services/ocr_service.py @@ -878,12 +878,9 @@ class GLMOCREndToEndService(OCRServiceBase): Returns: Dict with 'markdown', 'latex', 'mathml', 'mml' keys. """ - # 1. Padding - padded = self.image_processor.add_padding(image) - img_h, img_w = padded.shape[:2] - - # 2. Layout detection - layout_info = self.layout_detector.detect(padded) + # 1. Layout detection + img_h, img_w = image.shape[:2] + layout_info = self.layout_detector.detect(image) # Sort regions in reading order: top-to-bottom, left-to-right layout_info.regions.sort(key=lambda r: (r.bbox[1], r.bbox[0])) @@ -892,7 +889,7 @@ class GLMOCREndToEndService(OCRServiceBase): if not layout_info.regions: # No layout detected → assume it's a formula, use formula recognition logger.info("No layout regions detected, treating image as formula") - raw_content = self._call_vllm(padded, _TASK_PROMPTS["formula"]) + raw_content = self._call_vllm(image, _TASK_PROMPTS["formula"]) # Format as display formula markdown formatted_content = raw_content.strip() if not (formatted_content.startswith("$$") and formatted_content.endswith("$$")): @@ -905,7 +902,7 @@ class GLMOCREndToEndService(OCRServiceBase): if region.type == "figure": continue x1, y1, x2, y2 = (int(c) for c in region.bbox) - cropped = padded[y1:y2, x1:x2] + cropped = image[y1:y2, x1:x2] if cropped.size == 0 or cropped.shape[0] < 10 or cropped.shape[1] < 10: logger.warning( "Skipping region idx=%d (label=%s): crop too small %s", @@ -918,7 +915,7 @@ class GLMOCREndToEndService(OCRServiceBase): tasks.append((idx, region, cropped, prompt)) if not tasks: - raw_content = self._call_vllm(padded, _DEFAULT_PROMPT) + raw_content = self._call_vllm(image, _DEFAULT_PROMPT) markdown_content = self._formatter._clean_content(raw_content) else: # Parallel OCR calls @@ -965,17 +962,3 @@ class GLMOCREndToEndService(OCRServiceBase): logger.warning("Format conversion failed, returning empty latex/mathml/mml: %s", e) return {"markdown": markdown_content, "latex": latex, "mathml": mathml, "mml": mml} - - -if __name__ == "__main__": - mineru_service = MineruOCRService() - image = cv2.imread("test/formula2.jpg") - image_numpy = np.array(image) - # Encode image to bytes (as done in API layer) - success, encoded_image = cv2.imencode(".png", image_numpy) - if not success: - raise RuntimeError("Failed to encode image") - image_bytes = BytesIO(encoded_image.tobytes()) - image_bytes.seek(0) - ocr_result = mineru_service.recognize(image_bytes) - print(ocr_result) diff --git a/tests/api/v1/endpoints/test_image_endpoint.py b/tests/api/v1/endpoints/test_image_endpoint.py index 5868c05..b9814cb 100644 --- a/tests/api/v1/endpoints/test_image_endpoint.py +++ b/tests/api/v1/endpoints/test_image_endpoint.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from fastapi import FastAPI from fastapi.testclient import TestClient @@ -35,7 +34,9 @@ def test_image_endpoint_requires_exactly_one_of_image_url_or_image_base64(): client = _build_client() missing = client.post("/ocr", json={}) - both = client.post("/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"}) + both = client.post( + "/ocr", json={"image_url": "https://example.com/a.png", "image_base64": "abc"} + ) assert missing.status_code == 422 assert both.status_code == 422 diff --git a/tests/services/test_glm_postprocess.py b/tests/services/test_glm_postprocess.py index 1f241bc..d9bc464 100644 --- a/tests/services/test_glm_postprocess.py +++ b/tests/services/test_glm_postprocess.py @@ -57,12 +57,22 @@ def test_merge_formula_numbers_merges_before_and_after_formula(): before = formatter._merge_formula_numbers( [ {"index": 0, "label": "text", "native_label": "formula_number", "content": "(1)"}, - {"index": 1, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"}, + { + "index": 1, + "label": "formula", + "native_label": "display_formula", + "content": "$$\nx+y\n$$", + }, ] ) after = formatter._merge_formula_numbers( [ - {"index": 0, "label": "formula", "native_label": "display_formula", "content": "$$\nx+y\n$$"}, + { + "index": 0, + "label": "formula", + "native_label": "display_formula", + "content": "$$\nx+y\n$$", + }, {"index": 1, "label": "text", "native_label": "formula_number", "content": "(2)"}, ] ) diff --git a/tests/services/test_layout_detector.py b/tests/services/test_layout_detector.py index db8584a..d62776b 100644 --- a/tests/services/test_layout_detector.py +++ b/tests/services/test_layout_detector.py @@ -23,7 +23,9 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch): calls = {} - def fake_apply_layout_postprocess(boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode): + def fake_apply_layout_postprocess( + boxes, img_size, layout_nms, layout_unclip_ratio, layout_merge_bboxes_mode + ): calls["args"] = { "boxes": boxes, "img_size": img_size, @@ -33,7 +35,9 @@ def test_detect_applies_postprocess_and_keeps_native_label(monkeypatch): } return [boxes[0], boxes[2]] - monkeypatch.setattr("app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess) + monkeypatch.setattr( + "app.services.layout_detector.apply_layout_postprocess", fake_apply_layout_postprocess + ) image = np.zeros((200, 100, 3), dtype=np.uint8) info = detector.detect(image) diff --git a/tests/services/test_layout_postprocess.py b/tests/services/test_layout_postprocess.py index be32f29..3bf3506 100644 --- a/tests/services/test_layout_postprocess.py +++ b/tests/services/test_layout_postprocess.py @@ -146,6 +146,4 @@ def test_apply_layout_postprocess_clamps_skips_invalid_and_filters_large_image() layout_merge_bboxes_mode=None, ) - assert result == [ - {"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]} - ] + assert result == [{"cls_id": 0, "label": "text", "score": 0.95, "coordinate": [0, 0, 40, 50]}] diff --git a/tests/services/test_ocr_service.py b/tests/services/test_ocr_service.py index d57b451..c801ffb 100644 --- a/tests/services/test_ocr_service.py +++ b/tests/services/test_ocr_service.py @@ -46,7 +46,9 @@ def test_encode_region_returns_decodable_base64_jpeg(): image[:, :] = [0, 128, 255] encoded = service._encode_region(image) - decoded = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR) + decoded = cv2.imdecode( + np.frombuffer(base64.b64decode(encoded), dtype=np.uint8), cv2.IMREAD_COLOR + ) assert decoded.shape[:2] == image.shape[:2] @@ -71,7 +73,9 @@ def test_call_vllm_builds_messages_and_returns_content(): assert captured["model"] == "glm-ocr" assert captured["max_tokens"] == 1024 assert captured["messages"][0]["content"][0]["type"] == "image_url" - assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith("data:image/jpeg;base64,") + assert captured["messages"][0]["content"][0]["image_url"]["url"].startswith( + "data:image/jpeg;base64," + ) assert captured["messages"][0]["content"][1] == {"type": "text", "text": "Formula Recognition:"} @@ -98,9 +102,19 @@ def test_recognize_falls_back_to_full_image_when_no_layout_regions(monkeypatch): def test_recognize_skips_figures_keeps_order_and_postprocesses(monkeypatch): regions = [ - LayoutRegion(type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9), - LayoutRegion(type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8), - LayoutRegion(type="formula", native_label="display_formula", bbox=[20, 20, 40, 40], confidence=0.95, score=0.95), + LayoutRegion( + type="text", native_label="doc_title", bbox=[0, 0, 10, 10], confidence=0.9, score=0.9 + ), + LayoutRegion( + type="figure", native_label="image", bbox=[10, 10, 20, 20], confidence=0.8, score=0.8 + ), + LayoutRegion( + type="formula", + native_label="display_formula", + bbox=[20, 20, 40, 40], + confidence=0.95, + score=0.95, + ), ] service = _build_service(regions=regions) image = np.zeros((40, 40, 3), dtype=np.uint8)