Files
doc_processer/app/services/layout_detector.py

120 lines
3.7 KiB
Python
Raw Normal View History

2025-12-29 17:34:58 +08:00
"""DocLayout-YOLO wrapper for document layout detection."""
import numpy as np
from app.schemas.image import LayoutInfo, LayoutRegion
class LayoutDetector:
    """Wrapper for the DocLayout-YOLO document layout detection model.

    Construct the detector, call :meth:`load_model` once, then call
    :meth:`detect` for each image. Raw model classes are collapsed into a
    simplified set of region types (see :meth:`_region_type`).
    """

    # Class names from DocLayout-YOLO, keyed by the class id the model emits.
    CLASS_NAMES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }

    # Classes considered as plain text (reported with region type "text").
    PLAIN_TEXT_CLASSES = {"title", "plain_text", "figure_caption", "table_caption", "table_footnote"}

    # Classes considered as formula (reported with region type "formula").
    FORMULA_CLASSES = {"isolate_formula", "formula_caption"}

    def __init__(
        self,
        model_path: str,
        confidence_threshold: float = 0.2,
        device: str = "cuda:0",
    ):
        """Initialize the layout detector.

        The model itself is not loaded here; call :meth:`load_model` before
        running :meth:`detect`.

        Args:
            model_path: Path to the DocLayout-YOLO model weights.
            confidence_threshold: Minimum confidence for detections.
            device: Device string forwarded to the model's ``predict`` call.
                Defaults to ``"cuda:0"``, the value that was previously
                hard-coded; pass e.g. ``"cpu"`` on hosts without a GPU.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.model = None

    def load_model(self) -> None:
        """Load the DocLayout-YOLO model.

        Raises:
            RuntimeError: If model cannot be loaded.
        """
        try:
            # Imported lazily so that a missing package is reported through
            # the same RuntimeError as a failure to construct the model.
            from doclayout_yolo import YOLOv10

            self.model = YOLOv10(self.model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load DocLayout-YOLO model: {e}") from e

    @classmethod
    def _region_type(cls, class_name: str) -> str:
        """Map a raw model class name to its simplified region type.

        Plain-text classes become ``"text"`` and formula classes become
        ``"formula"``. Every other name (``"figure"``, ``"table"``,
        ``"abandon"``, ``"unknown_<id>"``) is returned unchanged.
        """
        if class_name in cls.PLAIN_TEXT_CLASSES:
            return "text"
        if class_name in cls.FORMULA_CLASSES:
            return "formula"
        return class_name

    def detect(self, image: np.ndarray, image_size: int = 1024) -> LayoutInfo:
        """Detect document layout regions.

        Args:
            image: Input image as numpy array in BGR format.
            image_size: Image size for prediction.

        Returns:
            LayoutInfo with detected regions, plus flags telling whether any
            plain-text or formula region was found.

        Raises:
            RuntimeError: If model not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Run prediction
        results = self.model.predict(
            image,
            imgsz=image_size,
            conf=self.confidence_threshold,
            device=self.device,
        )

        regions: list[LayoutRegion] = []
        seen_types = set()

        # Only the first result entry is inspected.
        boxes = results[0].boxes if results else None
        if boxes is not None:
            for box in boxes:
                cls_id = int(box.cls[0].item())
                confidence = float(box.conf[0].item())
                bbox = box.xyxy[0].tolist()

                # Ids outside CLASS_NAMES are kept, labelled "unknown_<id>".
                class_name = self.CLASS_NAMES.get(cls_id, f"unknown_{cls_id}")
                region_type = self._region_type(class_name)
                seen_types.add(region_type)

                regions.append(
                    LayoutRegion(
                        type=region_type,
                        bbox=bbox,
                        confidence=confidence,
                    )
                )

        # "text" / "formula" can only come from the two class sets above:
        # no raw class name or "unknown_<id>" label collides with them.
        return LayoutInfo(
            regions=regions,
            has_plain_text="text" in seen_types,
            has_formula="formula" in seen_types,
        )