123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
"""DocLayout-YOLO wrapper for document layout detection."""
|
|
|
|
import numpy as np
|
|
|
|
from app.schemas.image import LayoutInfo, LayoutRegion
|
|
from app.core.config import get_settings
|
|
|
|
# Application settings, resolved once when this module is imported.
# LayoutDetector.detect() reads `settings.device` from this object to choose
# the device passed to the model's predict() call.
settings = get_settings()
|
|
|
|
|
|
class LayoutDetector:
    """Wrapper for the DocLayout-YOLO document layout detection model.

    Construction only stores configuration; the model weights are loaded by
    an explicit call to ``load_model()``, which must happen before
    ``detect()`` is used.
    """

    # Class names from DocLayout-YOLO, keyed by the class id the model emits.
    CLASS_NAMES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }

    # Classes considered as plain text (reported with region type "text").
    PLAIN_TEXT_CLASSES = {
        "title",
        "plain_text",
        "figure_caption",
        "table_caption",
        "table_footnote",
    }

    # Classes considered as formula (reported with region type "formula").
    FORMULA_CLASSES = {"isolate_formula", "formula_caption"}

    def __init__(self, model_path: str, confidence_threshold: float = 0.2):
        """Initialize the layout detector.

        Args:
            model_path: Path to the DocLayout-YOLO model weights.
            confidence_threshold: Minimum confidence for detections; passed
                to the model as ``conf`` on every prediction.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        # Stays None until load_model() succeeds; detect() checks for this.
        self.model = None

    def load_model(self) -> None:
        """Load the DocLayout-YOLO model from ``self.model_path``.

        Raises:
            RuntimeError: If the model cannot be loaded, including when the
                ``doclayout_yolo`` package itself cannot be imported. The
                original exception is chained as the cause.
        """
        try:
            # Imported here rather than at module level so that a missing
            # package is reported through the same RuntimeError as any other
            # loading failure.
            from doclayout_yolo import YOLOv10

            self.model = YOLOv10(self.model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load DocLayout-YOLO model: {e}") from e

    def detect(self, image: np.ndarray, image_size: int = 1024) -> LayoutInfo:
        """Detect document layout regions.

        Args:
            image: Input image as numpy array in BGR format.
            image_size: Image size for prediction (passed as ``imgsz``).

        Returns:
            LayoutInfo with one LayoutRegion per detected box, plus flags
            telling whether any plain-text or formula region was found.
            Plain-text classes are reported as ``"text"``, formula classes as
            ``"formula"``; every other class keeps its own name, and a class
            id missing from ``CLASS_NAMES`` is reported as ``"unknown_<id>"``.

        Raises:
            RuntimeError: If model not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Run prediction
        results = self.model.predict(
            image,
            imgsz=image_size,
            conf=self.confidence_threshold,
            device=settings.device,
        )

        regions: list[LayoutRegion] = []
        has_plain_text = False
        has_formula = False

        # Only the first result is read.
        # NOTE(review): presumably predict() yields one result per input
        # image, so a single image gives a single result -- confirm.
        boxes = results[0].boxes if results else None
        if boxes is not None:
            for box in boxes:
                cls_id = int(box.cls[0].item())
                confidence = float(box.conf[0].item())
                bbox = box.xyxy[0].tolist()

                class_name = self.CLASS_NAMES.get(cls_id, f"unknown_{cls_id}")

                # Map to simplified type
                if class_name in self.PLAIN_TEXT_CLASSES:
                    region_type = "text"
                    has_plain_text = True
                elif class_name in self.FORMULA_CLASSES:
                    region_type = "formula"
                    has_formula = True
                else:
                    # "figure", "table", "abandon" and unknown ids are passed
                    # through under their own class name.
                    region_type = class_name

                regions.append(
                    LayoutRegion(
                        type=region_type,
                        bbox=bbox,
                        confidence=confidence,
                    )
                )

        return LayoutInfo(
            regions=regions,
            has_plain_text=has_plain_text,
            has_formula=has_formula,
        )
|