157 lines
4.5 KiB
Python
157 lines
4.5 KiB
Python
"""PP-DocLayoutV2 wrapper for document layout detection."""
|
|
|
|
import numpy as np
|
|
|
|
from app.schemas.image import LayoutInfo, LayoutRegion
|
|
from app.core.config import get_settings
|
|
from paddleocr import LayoutDetection
|
|
from typing import Optional
|
|
|
|
# Application settings, loaded once at import time.
# NOTE(review): `settings` is not referenced anywhere else in this module as
# shown — confirm whether it is still needed here or can be removed.
settings = get_settings()
|
|
|
|
|
|
class LayoutDetector:
    """Layout detector backed by the PP-DocLayoutV2 model.

    The underlying ``paddleocr.LayoutDetection`` model is created lazily and
    cached on the class, so all ``LayoutDetector`` instances share one model
    object.
    """

    # Shared model instance, created on first use by ``_get_layout_detector``.
    _layout_detector: Optional[LayoutDetection] = None

    # PP-DocLayoutV2 class ID to raw label mapping. Used as a fallback when a
    # detected box does not carry a ``label`` of its own.
    CLS_ID_TO_LABEL: dict[int, str] = {
        0: "abstract",
        1: "algorithm",
        2: "aside_text",
        3: "chart",
        4: "content",
        5: "display_formula",
        6: "doc_title",
        7: "figure_title",
        8: "footer",
        9: "footer_image",
        10: "footnote",
        11: "formula_number",
        12: "header",
        13: "header_image",
        14: "image",
        15: "inline_formula",
        16: "number",
        17: "paragraph_title",
        18: "reference",
        19: "reference_content",
        20: "seal",
        21: "table",
        22: "text",
        23: "vertical_text",
        24: "vision_footnote",
    }

    # Mapping from raw labels to normalized region types. Labels missing from
    # this table (including the "other" fallback) are treated as "text".
    LABEL_TO_TYPE: dict[str, str] = {
        # Text types
        "abstract": "text",
        "algorithm": "text",
        "aside_text": "text",
        "content": "text",
        "doc_title": "text",
        "footer": "text",
        "footnote": "text",
        "header": "text",
        "number": "text",
        "paragraph_title": "text",
        "reference": "text",
        "reference_content": "text",
        "text": "text",
        "vertical_text": "text",
        "vision_footnote": "text",
        # Formula types
        "display_formula": "formula",
        "inline_formula": "formula",
        "formula_number": "formula",
        # Table types
        "table": "table",
        # Figure types
        "chart": "figure",
        "figure_title": "figure",
        "footer_image": "figure",
        "header_image": "figure",
        "image": "figure",
        "seal": "figure",
    }

    # A "text" region must score strictly above this value for ``detect`` to
    # set the mixed-recognition flag.
    MIXED_RECOGNITION_SCORE_THRESHOLD: float = 0.85

    def __init__(self):
        """Initialize the detector and eagerly load the shared model.

        Loading here means model construction happens in the constructor
        rather than during the first ``detect`` call.
        """
        self._get_layout_detector()

    def _get_layout_detector(self) -> LayoutDetection:
        """Return the shared LayoutDetection instance, creating it on first use."""
        if LayoutDetector._layout_detector is None:
            LayoutDetector._layout_detector = LayoutDetection(model_name="PP-DocLayoutV2")
        return LayoutDetector._layout_detector

    @staticmethod
    def _extract_boxes(result) -> list:
        """Return the raw box dicts from a ``LayoutDetection.predict`` result.

        The expected shape is
        ``[{'input_path': ..., 'page_index': None, 'boxes': [...]}]``; only the
        first entry is used. Any other shape, or a missing/``None`` ``boxes``
        value, yields an empty list.
        """
        if not isinstance(result, list) or not result:
            return []
        first_result = result[0]
        if not isinstance(first_result, dict):
            return []
        boxes = first_result.get("boxes")
        return [] if boxes is None else boxes

    def _to_region(self, box: dict) -> LayoutRegion:
        """Convert one raw detection box dict into a ``LayoutRegion``."""
        cls_id = box.get("cls_id")
        # Prefer the label reported with the box; fall back to the class-ID
        # table, then to "other" when the ID is unknown as well.
        label = box.get("label") or self.CLS_ID_TO_LABEL.get(cls_id, "other")
        # The model output may contain numpy scalars (e.g. float32
        # coordinates); coerce to builtin floats so the schema only ever sees
        # plain Python numbers.
        score = float(box.get("score", 0.0))
        coordinate = [float(value) for value in box.get("coordinate", [0, 0, 0, 0])]
        # Unknown labels default to the "text" region type.
        region_type = self.LABEL_TO_TYPE.get(label, "text")
        return LayoutRegion(
            type=region_type,
            bbox=coordinate,
            confidence=score,
            score=score,
        )

    def detect(self, image: np.ndarray) -> LayoutInfo:
        """Detect the layout of an image using PP-DocLayoutV2.

        Args:
            image: Input image as a numpy array.

        Returns:
            LayoutInfo with the detected regions, plus a mixed-recognition flag
            that is True when at least one "text" region scores above
            ``MIXED_RECOGNITION_SCORE_THRESHOLD``.
        """
        result = self._get_layout_detector().predict(image)
        regions = [self._to_region(box) for box in self._extract_boxes(result)]

        mixed_recognition = any(
            region.type == "text" and region.score > self.MIXED_RECOGNITION_SCORE_THRESHOLD
            for region in regions
        )

        return LayoutInfo(regions=regions, MixedRecognition=mixed_recognition)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: pad a sample image, save the padded copy for
    # inspection, then print the detected layout.
    import sys

    import cv2

    from app.services.image_processor import ImageProcessor

    # An optional first CLI argument overrides the default sample image.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "test/timeout.png"

    # cv2.imread returns None instead of raising when the file is missing or
    # cannot be decoded; check before loading the model so we fail fast.
    image = cv2.imread(image_path)
    if image is None:
        sys.exit(f"Could not read image: {image_path}")

    layout_detector = LayoutDetector()

    image_processor = ImageProcessor(padding_ratio=0.15)
    image = image_processor.add_padding(image)

    # Save the padded image for debugging
    cv2.imwrite("debug_padded_image.png", image)

    layout_info = layout_detector.detect(image)
    print(layout_info)