"""Layout post-processing utilities ported from GLM-OCR. Source: glm-ocr/glmocr/utils/layout_postprocess_utils.py Algorithms applied after PaddleOCR LayoutDetection.predict(): 1. NMS with dual IoU thresholds (same-class vs cross-class) 2. Large-image-region filtering (remove image boxes that fill most of the page) 3. Containment analysis (merge_bboxes_mode: keep large parent, remove contained child) 4. Unclip ratio (optional bbox expansion) 5. Invalid bbox skipping These steps run on top of PaddleOCR's built-in detection to replicate the quality of the GLM-OCR SDK's layout pipeline. """ from __future__ import annotations import numpy as np # --------------------------------------------------------------------------- # Primitive geometry helpers # --------------------------------------------------------------------------- def iou(box1: list[float], box2: list[float]) -> float: """Compute IoU of two bounding boxes [x1, y1, x2, y2].""" x1, y1, x2, y2 = box1 x1_p, y1_p, x2_p, y2_p = box2 x1_i = max(x1, x1_p) y1_i = max(y1, y1_p) x2_i = min(x2, x2_p) y2_i = min(y2, y2_p) inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1) box1_area = (x2 - x1 + 1) * (y2 - y1 + 1) box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1) return inter_area / float(box1_area + box2_area - inter_area) def is_contained(box1: list[float], box2: list[float], overlap_threshold: float = 0.8) -> bool: """Return True if box1 is contained within box2 (overlap ratio >= threshold). box format: [cls_id, score, x1, y1, x2, y2] """ _, _, x1, y1, x2, y2 = box1 _, _, x1_p, y1_p, x2_p, y2_p = box2 box1_area = (x2 - x1) * (y2 - y1) if box1_area <= 0: return False xi1 = max(x1, x1_p) yi1 = max(y1, y1_p) xi2 = min(x2, x2_p) yi2 = min(y2, y2_p) inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1) return (inter_area / box1_area) >= overlap_threshold # --------------------------------------------------------------------------- # NMS # --------------------------------------------------------------------------- def nms( boxes: np.ndarray, iou_same: float = 0.6, iou_diff: float = 0.98, ) -> list[int]: """NMS with separate IoU thresholds for same-class and cross-class overlaps. Args: boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...]. iou_same: Suppression threshold for boxes of the same class. iou_diff: Suppression threshold for boxes of different classes. Returns: List of kept row indices. """ scores = boxes[:, 1] indices = np.argsort(scores)[::-1].tolist() selected: list[int] = [] while indices: current = indices[0] selected.append(current) current_class = int(boxes[current, 0]) current_coords = boxes[current, 2:6].tolist() indices = indices[1:] kept = [] for i in indices: box_class = int(boxes[i, 0]) box_coords = boxes[i, 2:6].tolist() threshold = iou_same if current_class == box_class else iou_diff if iou(current_coords, box_coords) < threshold: kept.append(i) indices = kept return selected # --------------------------------------------------------------------------- # Containment analysis # --------------------------------------------------------------------------- # Labels whose regions should never be removed even when contained in another box _PRESERVE_LABELS = {"image", "seal", "chart"} def check_containment( boxes: np.ndarray, preserve_cls_ids: set | None = None, category_index: int | None = None, mode: str | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Compute containment flags for each box. Args: boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...]. preserve_cls_ids: Class IDs that must never be marked as contained. category_index: If set, apply mode only relative to this class. mode: 'large' or 'small' (only used with category_index). Returns: (contains_other, contained_by_other): boolean arrays of length N. """ n = len(boxes) contains_other = np.zeros(n, dtype=int) contained_by_other = np.zeros(n, dtype=int) for i in range(n): for j in range(n): if i == j: continue if preserve_cls_ids and int(boxes[i, 0]) in preserve_cls_ids: continue if category_index is not None and mode is not None: if mode == "large" and int(boxes[j, 0]) == category_index: if is_contained(boxes[i].tolist(), boxes[j].tolist()): contained_by_other[i] = 1 contains_other[j] = 1 elif mode == "small" and int(boxes[i, 0]) == category_index: if is_contained(boxes[i].tolist(), boxes[j].tolist()): contained_by_other[i] = 1 contains_other[j] = 1 else: if is_contained(boxes[i].tolist(), boxes[j].tolist()): contained_by_other[i] = 1 contains_other[j] = 1 return contains_other, contained_by_other # --------------------------------------------------------------------------- # Box expansion (unclip) # --------------------------------------------------------------------------- def unclip_boxes( boxes: np.ndarray, unclip_ratio: float | tuple[float, float] | dict | list | None, ) -> np.ndarray: """Expand bounding boxes by the given ratio. Args: boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...]. unclip_ratio: Scalar, (w_ratio, h_ratio) tuple, or dict mapping cls_id to ratio. Returns: Expanded boxes array. """ if unclip_ratio is None: return boxes if isinstance(unclip_ratio, dict): expanded = [] for box in boxes: cls_id = int(box[0]) if cls_id in unclip_ratio: w_ratio, h_ratio = unclip_ratio[cls_id] x1, y1, x2, y2 = box[2], box[3], box[4], box[5] cx, cy = (x1 + x2) / 2, (y1 + y2) / 2 nw, nh = (x2 - x1) * w_ratio, (y2 - y1) * h_ratio new_box = list(box) new_box[2], new_box[3] = cx - nw / 2, cy - nh / 2 new_box[4], new_box[5] = cx + nw / 2, cy + nh / 2 expanded.append(new_box) else: expanded.append(list(box)) return np.array(expanded) # Scalar or tuple if isinstance(unclip_ratio, (int, float)): unclip_ratio = (float(unclip_ratio), float(unclip_ratio)) w_ratio, h_ratio = unclip_ratio[0], unclip_ratio[1] widths = boxes[:, 4] - boxes[:, 2] heights = boxes[:, 5] - boxes[:, 3] cx = boxes[:, 2] + widths / 2 cy = boxes[:, 3] + heights / 2 nw, nh = widths * w_ratio, heights * h_ratio expanded = boxes.copy().astype(float) expanded[:, 2] = cx - nw / 2 expanded[:, 3] = cy - nh / 2 expanded[:, 4] = cx + nw / 2 expanded[:, 5] = cy + nh / 2 return expanded # --------------------------------------------------------------------------- # Main entry-point # --------------------------------------------------------------------------- def apply_layout_postprocess( boxes: list[dict], img_size: tuple[int, int], layout_nms: bool = True, layout_unclip_ratio: float | tuple | dict | None = None, layout_merge_bboxes_mode: str | dict | None = "large", ) -> list[dict]: """Apply GLM-OCR layout post-processing to PaddleOCR detection results. Args: boxes: PaddleOCR output — list of dicts with keys: cls_id, label, score, coordinate ([x1, y1, x2, y2]). img_size: (width, height) of the image. layout_nms: Apply dual-threshold NMS. layout_unclip_ratio: Optional bbox expansion ratio. layout_merge_bboxes_mode: Containment mode — 'large' (default), 'small', 'union', or per-class dict. Returns: Filtered and ordered list of box dicts in the same PaddleOCR format. """ if not boxes: return boxes img_width, img_height = img_size # --- Build working array [cls_id, score, x1, y1, x2, y2] -------------- # arr_rows = [] for b in boxes: cls_id = b.get("cls_id", 0) score = b.get("score", 0.0) x1, y1, x2, y2 = b.get("coordinate", [0, 0, 0, 0]) arr_rows.append([cls_id, score, x1, y1, x2, y2]) boxes_array = np.array(arr_rows, dtype=float) all_labels: list[str] = [b.get("label", "") for b in boxes] # 1. NMS ---------------------------------------------------------------- # if layout_nms and len(boxes_array) > 1: kept = nms(boxes_array, iou_same=0.6, iou_diff=0.98) boxes_array = boxes_array[kept] all_labels = [all_labels[k] for k in kept] # 2. Filter large image regions ---------------------------------------- # if len(boxes_array) > 1: img_area = img_width * img_height area_thres = 0.82 if img_width > img_height else 0.93 keep_mask = np.ones(len(boxes_array), dtype=bool) for i, lbl in enumerate(all_labels): if lbl == "image": x1, y1, x2, y2 = boxes_array[i, 2:6] x1 = max(0.0, x1) y1 = max(0.0, y1) x2 = min(float(img_width), x2) y2 = min(float(img_height), y2) if (x2 - x1) * (y2 - y1) > area_thres * img_area: keep_mask[i] = False boxes_array = boxes_array[keep_mask] all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k] # 3. Containment analysis (merge_bboxes_mode) -------------------------- # if layout_merge_bboxes_mode and len(boxes_array) > 1: preserve_cls_ids = { int(boxes_array[i, 0]) for i, lbl in enumerate(all_labels) if lbl in _PRESERVE_LABELS } if isinstance(layout_merge_bboxes_mode, str): mode = layout_merge_bboxes_mode if mode in ("large", "small"): contains_other, contained_by_other = check_containment( boxes_array, preserve_cls_ids ) if mode == "large": keep_mask = contained_by_other == 0 else: keep_mask = (contains_other == 0) | (contained_by_other == 1) boxes_array = boxes_array[keep_mask] all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k] elif isinstance(layout_merge_bboxes_mode, dict): keep_mask = np.ones(len(boxes_array), dtype=bool) for category_index, mode in layout_merge_bboxes_mode.items(): if mode in ("large", "small"): contains_other, contained_by_other = check_containment( boxes_array, preserve_cls_ids, int(category_index), mode ) if mode == "large": keep_mask &= contained_by_other == 0 else: keep_mask &= (contains_other == 0) | (contained_by_other == 1) boxes_array = boxes_array[keep_mask] all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k] if len(boxes_array) == 0: return [] # 4. Unclip (bbox expansion) ------------------------------------------- # if layout_unclip_ratio is not None: boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio) # 5. Clamp to image boundaries + skip invalid -------------------------- # result: list[dict] = [] for i, row in enumerate(boxes_array): cls_id = int(row[0]) score = float(row[1]) x1 = max(0.0, min(float(row[2]), img_width)) y1 = max(0.0, min(float(row[3]), img_height)) x2 = max(0.0, min(float(row[4]), img_width)) y2 = max(0.0, min(float(row[5]), img_height)) if x1 >= x2 or y1 >= y2: continue result.append( { "cls_id": cls_id, "label": all_labels[i], "score": score, "coordinate": [int(x1), int(y1), int(x2), int(y2)], } ) return result