fix: refactor logic
This commit is contained in:
@@ -1,122 +1,157 @@
|
||||
"""DocLayout-YOLO wrapper for document layout detection."""
|
||||
"""PP-DocLayoutV2 wrapper for document layout detection."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.schemas.image import LayoutInfo, LayoutRegion
|
||||
from app.core.config import get_settings
|
||||
from paddleocr import LayoutDetection
|
||||
from typing import Optional
|
||||
|
||||
# Application settings, resolved once at import time via get_settings().
settings = get_settings()
|
||||
|
||||
|
||||
class LayoutDetector:
    """Layout detector for PP-DocLayoutV2.

    The underlying ``LayoutDetection`` predictor is created on first use and
    cached on the class, so all ``LayoutDetector`` instances share it.
    """

    # A region of type "text" must score above this value for detect() to
    # set the MixedRecognition flag.
    MIXED_RECOGNITION_SCORE_THRESHOLD: float = 0.85

    # Class-level predictor cache; filled in by _get_layout_detector().
    _layout_detector: Optional[LayoutDetection] = None

    # PP-DocLayoutV2 class ID to label mapping
    CLS_ID_TO_LABEL: dict[int, str] = {
        0: "abstract",
        1: "algorithm",
        2: "aside_text",
        3: "chart",
        4: "content",
        5: "display_formula",
        6: "doc_title",
        7: "figure_title",
        8: "footer",
        9: "footer_image",
        10: "footnote",
        11: "formula_number",
        12: "header",
        13: "header_image",
        14: "image",
        15: "inline_formula",
        16: "number",
        17: "paragraph_title",
        18: "reference",
        19: "reference_content",
        20: "seal",
        21: "table",
        22: "text",
        23: "vertical_text",
        24: "vision_footnote",
    }

    # Mapping from raw labels to normalized region types
    LABEL_TO_TYPE: dict[str, str] = {
        # Text types
        "abstract": "text",
        "algorithm": "text",
        "aside_text": "text",
        "content": "text",
        "doc_title": "text",
        "footer": "text",
        "footnote": "text",
        "header": "text",
        "number": "text",
        "paragraph_title": "text",
        "reference": "text",
        "reference_content": "text",
        "text": "text",
        "vertical_text": "text",
        "vision_footnote": "text",
        # Formula types
        "display_formula": "formula",
        "inline_formula": "formula",
        "formula_number": "formula",
        # Table types
        "table": "table",
        # Figure types
        "chart": "figure",
        "figure_title": "figure",
        "footer_image": "figure",
        "header_image": "figure",
        "image": "figure",
        "seal": "figure",
    }

    def __init__(self):
        """Initialize layout detector.

        Creates the shared predictor immediately so it is already cached
        before the first call to detect().
        """
        self._get_layout_detector()

    def _get_layout_detector(self) -> LayoutDetection:
        """Get or create LayoutDetection instance."""
        if LayoutDetector._layout_detector is None:
            LayoutDetector._layout_detector = LayoutDetection(model_name="PP-DocLayoutV2")
        return LayoutDetector._layout_detector

    @staticmethod
    def _extract_boxes(result) -> list:
        """Return the box entries of the first prediction, or [] if the shape is unexpected."""
        # Expected format: [{'input_path': ..., 'page_index': None, 'boxes': [...]}]
        if not isinstance(result, list) or not result:
            return []
        first_result = result[0]
        if isinstance(first_result, dict) and "boxes" in first_result:
            return first_result["boxes"]
        return []

    def detect(self, image: np.ndarray) -> LayoutInfo:
        """Detect layout of the image using PP-DocLayoutV2.

        Args:
            image: Input image as numpy array.

        Returns:
            LayoutInfo with one LayoutRegion per detected box and the
            ``MixedRecognition`` flag. The flag is True when at least one
            region of type "text" scores above
            ``MIXED_RECOGNITION_SCORE_THRESHOLD``.
        """
        result = self._get_layout_detector().predict(image)

        regions: list[LayoutRegion] = []
        for box in self._extract_boxes(result):
            cls_id = box.get("cls_id")
            # Prefer the label carried by the box; fall back to the ID table.
            label = box.get("label") or self.CLS_ID_TO_LABEL.get(cls_id, "other")
            score = box.get("score", 0.0)
            coordinate = box.get("coordinate", [0, 0, 0, 0])

            # Labels missing from LABEL_TO_TYPE (including "other") count as text.
            region_type = self.LABEL_TO_TYPE.get(label, "text")

            regions.append(
                LayoutRegion(
                    type=region_type,
                    bbox=coordinate,
                    confidence=score,
                    score=score,
                )
            )

        mixed_recognition = any(
            region.type == "text" and region.score > self.MIXED_RECOGNITION_SCORE_THRESHOLD
            for region in regions
        )

        return LayoutInfo(regions=regions, MixedRecognition=mixed_recognition)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run layout detection on one padded sample image.
    import cv2

    from app.services.image_processor import ImageProcessor

    layout_detector = LayoutDetector()
    image_path = "test/timeout.png"

    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    image_processor = ImageProcessor(padding_ratio=0.15)
    image = image_processor.add_padding(image)

    # Save the padded image for debugging
    cv2.imwrite("debug_padded_image.png", image)

    layout_info = layout_detector.detect(image)
    print(layout_info)
|
||||
Reference in New Issue
Block a user