"""DocLayout-YOLO wrapper for document layout detection."""

import numpy as np

from app.schemas.image import LayoutInfo, LayoutRegion
from app.core.config import get_settings

settings = get_settings()


class LayoutDetector:
    """Wrapper for the DocLayout-YOLO model.

    Each detection is reported as a ``LayoutRegion`` whose ``type`` is a
    simplified label:

    - class names in ``PLAIN_TEXT_CLASSES`` become ``"text"``;
    - class names in ``FORMULA_CLASSES`` become ``"formula"``;
    - every other name (``"figure"``, ``"table"``, ``"abandon"``, or
      ``"unknown_<id>"`` for ids missing from ``CLASS_NAMES``) is kept as-is.
    """

    # Class-id -> class-name table for DocLayout-YOLO predictions.
    CLASS_NAMES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }

    # Raw class names reported with region type "text".
    PLAIN_TEXT_CLASSES = {"title", "plain_text", "figure_caption", "table_caption", "table_footnote"}

    # Raw class names reported with region type "formula".
    FORMULA_CLASSES = {"isolate_formula", "formula_caption"}

    def __init__(self, model_path: str, confidence_threshold: float = 0.2):
        """Initialize the layout detector without loading any weights.

        Args:
            model_path: Path to the DocLayout-YOLO model weights.
            confidence_threshold: Minimum confidence for detections; it is
                forwarded as ``conf`` to the model's ``predict`` call.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        # Populated by load_model(); detect() refuses to run while this is None.
        self.model = None

    def load_model(self) -> None:
        """Load the DocLayout-YOLO model from ``self.model_path``.

        Raises:
            RuntimeError: If the ``doclayout_yolo`` package cannot be imported
                or the weights cannot be loaded. The original error is chained
                as the cause.
        """
        try:
            # Imported lazily so this module can be imported without the
            # optional model dependency installed.
            from doclayout_yolo import YOLOv10

            self.model = YOLOv10(self.model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load DocLayout-YOLO model: {e}") from e

    @classmethod
    def _region_type(cls, class_name: str) -> str:
        """Collapse a raw DocLayout-YOLO class name into its simplified region type."""
        if class_name in cls.PLAIN_TEXT_CLASSES:
            return "text"
        if class_name in cls.FORMULA_CLASSES:
            return "formula"
        # "figure", "table", "abandon" and "unknown_<id>" pass through unchanged.
        return class_name

    def detect(self, image: np.ndarray, image_size: int = 1024) -> LayoutInfo:
        """Detect document layout regions.

        Args:
            image: Input image as numpy array in BGR format.
            image_size: Image size passed to the model as ``imgsz``.

        Returns:
            LayoutInfo with one region per detection, plus flags telling
            whether any "text" or "formula" region was found.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        results = self.model.predict(
            image,
            imgsz=image_size,
            conf=self.confidence_threshold,
            device=settings.device,
        )

        regions: list[LayoutRegion] = []
        has_plain_text = False
        has_formula = False

        # A single image was passed in, so only the first result is used.
        boxes = results[0].boxes if results else None
        if boxes is not None:
            # Pull each attribute over once for the whole batch instead of
            # issuing several .item()/.tolist() calls per box.
            class_ids = boxes.cls.tolist()
            confidences = boxes.conf.tolist()
            bboxes = boxes.xyxy.tolist()

            for cls_id, confidence, bbox in zip(class_ids, confidences, bboxes):
                cls_id = int(cls_id)
                class_name = self.CLASS_NAMES.get(cls_id, f"unknown_{cls_id}")
                region_type = self._region_type(class_name)

                if region_type == "text":
                    has_plain_text = True
                elif region_type == "formula":
                    has_formula = True

                regions.append(
                    LayoutRegion(
                        type=region_type,
                        bbox=bbox,
                        confidence=float(confidence),
                    )
                )

        return LayoutInfo(
            regions=regions,
            has_plain_text=has_plain_text,
            has_formula=has_formula,
        )