"""Layout post-processing utilities ported from GLM-OCR.

Source: glm-ocr/glmocr/utils/layout_postprocess_utils.py

Algorithms applied after PaddleOCR LayoutDetection.predict():

1. NMS with dual IoU thresholds (same-class vs cross-class)
2. Large-image-region filtering (remove image boxes that fill most of the page)
3. Containment analysis (merge_bboxes_mode: keep large parent, remove contained child)
4. Unclip ratio (optional bbox expansion)
5. Invalid bbox skipping

These steps run on top of PaddleOCR's built-in detection to replicate
the quality of the GLM-OCR SDK's layout pipeline.
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Dict, List, Optional, Tuple, Union
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Primitive geometry helpers
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def iou(box1: List[float], box2: List[float]) -> float:
    """Compute the Intersection-over-Union of two axis-aligned boxes.

    Uses the inclusive-pixel convention (the +1 on widths/heights),
    matching the original GLM-OCR implementation.

    Args:
        box1: [x1, y1, x2, y2].
        box2: [x1, y1, x2, y2].

    Returns:
        IoU of the two boxes. Returns 0.0 when the union area is
        non-positive (both boxes degenerate) instead of raising
        ZeroDivisionError.
    """
    x1, y1, x2, y2 = box1
    x1_p, y1_p, x2_p, y2_p = box2

    # Intersection rectangle (inclusive-pixel convention).
    x1_i = max(x1, x1_p)
    y1_i = max(y1, y1_p)
    x2_i = min(x2, x2_p)
    y2_i = min(y2, y2_p)

    inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1)
    box1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1)

    union_area = float(box1_area + box2_area - inter_area)
    if union_area <= 0:
        # Degenerate input (zero-area boxes): avoid division by zero.
        return 0.0
    return inter_area / union_area
|
||
|
|
|
||
|
|
|
||
|
|
def is_contained(box1: List[float], box2: List[float], overlap_threshold: float = 0.8) -> bool:
    """Check whether box1 lies (mostly) inside box2.

    box1 counts as contained when the fraction of its own area that
    overlaps box2 is at least ``overlap_threshold``.

    box format: [cls_id, score, x1, y1, x2, y2]
    """
    _, _, ax1, ay1, ax2, ay2 = box1
    _, _, bx1, by1, bx2, by2 = box2

    area = (ax2 - ax1) * (ay2 - ay1)
    # A degenerate box can never be "contained".
    if area <= 0:
        return False

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))

    return (overlap_w * overlap_h) / area >= overlap_threshold
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# NMS
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def nms(
    boxes: np.ndarray,
    iou_same: float = 0.6,
    iou_diff: float = 0.98,
) -> List[int]:
    """Greedy NMS with class-aware suppression thresholds.

    Boxes of the same class suppress each other at ``iou_same``; boxes of
    different classes only suppress each other at the much stricter
    ``iou_diff`` threshold.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        iou_same: Suppression threshold for boxes of the same class.
        iou_diff: Suppression threshold for boxes of different classes.

    Returns:
        Row indices of the surviving boxes, highest score first.
    """
    order = np.argsort(boxes[:, 1])[::-1].tolist()
    keep: List[int] = []

    while order:
        best, rest = order[0], order[1:]
        keep.append(best)
        best_cls = int(boxes[best, 0])
        best_coords = boxes[best, 2:6].tolist()

        # Retain only candidates whose overlap with the winner stays below
        # the threshold appropriate for their class relationship.
        order = [
            idx
            for idx in rest
            if iou(best_coords, boxes[idx, 2:6].tolist())
            < (iou_same if int(boxes[idx, 0]) == best_cls else iou_diff)
        ]

    return keep
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Containment analysis
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
# Labels whose regions should never be removed even when contained in another box
# (fed to check_containment as preserve_cls_ids by apply_layout_postprocess).
_PRESERVE_LABELS = {"image", "seal", "chart"}
|
||
|
|
|
||
|
|
|
||
|
|
def check_containment(
    boxes: np.ndarray,
    preserve_cls_ids: Optional[set] = None,
    category_index: Optional[int] = None,
    mode: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Flag pairwise containment relationships between boxes.

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        preserve_cls_ids: Class IDs that must never be marked as contained.
        category_index: If set, apply mode only relative to this class.
        mode: 'large' or 'small' (only used with category_index).

    Returns:
        (contains_other, contained_by_other): 0/1 int arrays of length N.
    """
    n = len(boxes)
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)

    for child in range(n):
        # Preserved classes are never treated as contained children.
        if preserve_cls_ids and int(boxes[child, 0]) in preserve_cls_ids:
            continue
        for parent in range(n):
            if child == parent:
                continue
            if category_index is not None and mode is not None:
                if mode == "large":
                    # Only parents of the targeted class are relevant.
                    relevant = int(boxes[parent, 0]) == category_index
                elif mode == "small":
                    # Only children of the targeted class are relevant.
                    relevant = int(boxes[child, 0]) == category_index
                else:
                    relevant = False
                if not relevant:
                    continue
            if is_contained(boxes[child].tolist(), boxes[parent].tolist()):
                contained_by_other[child] = 1
                contains_other[parent] = 1

    return contains_other, contained_by_other
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Box expansion (unclip)
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def unclip_boxes(
    boxes: np.ndarray,
    unclip_ratio: Union[float, Tuple[float, float], Dict, List, None],
) -> np.ndarray:
    """Expand bounding boxes about their centers by the given ratio(s).

    Args:
        boxes: Array of shape (N, 6+) — [cls_id, score, x1, y1, x2, y2, ...].
        unclip_ratio: One of:
            * None — no expansion (boxes returned unchanged);
            * scalar — same ratio for width and height of every box;
            * (w_ratio, h_ratio) — per-axis ratios for every box;
            * dict mapping cls_id to a scalar or (w_ratio, h_ratio) pair —
              per-class ratios; classes absent from the dict are untouched.

    Returns:
        Expanded boxes array.
    """
    if unclip_ratio is None:
        return boxes

    if isinstance(unclip_ratio, dict):
        expanded = []
        for box in boxes:
            cls_id = int(box[0])
            if cls_id in unclip_ratio:
                ratio = unclip_ratio[cls_id]
                # Accept a scalar per class as well as a (w, h) pair.
                if isinstance(ratio, (int, float)):
                    w_ratio = h_ratio = float(ratio)
                else:
                    w_ratio, h_ratio = ratio
                x1, y1, x2, y2 = box[2], box[3], box[4], box[5]
                cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                nw, nh = (x2 - x1) * w_ratio, (y2 - y1) * h_ratio
                new_box = list(box)
                new_box[2], new_box[3] = cx - nw / 2, cy - nh / 2
                new_box[4], new_box[5] = cx + nw / 2, cy + nh / 2
                expanded.append(new_box)
            else:
                expanded.append(list(box))
        return np.array(expanded)

    # Scalar applies uniformly to both axes.
    if isinstance(unclip_ratio, (int, float)):
        unclip_ratio = (float(unclip_ratio), float(unclip_ratio))

    w_ratio, h_ratio = unclip_ratio[0], unclip_ratio[1]
    widths = boxes[:, 4] - boxes[:, 2]
    heights = boxes[:, 5] - boxes[:, 3]
    cx = boxes[:, 2] + widths / 2
    cy = boxes[:, 3] + heights / 2
    nw, nh = widths * w_ratio, heights * h_ratio
    expanded = boxes.copy().astype(float)
    expanded[:, 2] = cx - nw / 2
    expanded[:, 3] = cy - nh / 2
    expanded[:, 4] = cx + nw / 2
    expanded[:, 5] = cy + nh / 2
    return expanded
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Main entry-point
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def apply_layout_postprocess(
    boxes: List[Dict],
    img_size: Tuple[int, int],
    layout_nms: bool = True,
    layout_unclip_ratio: Union[float, Tuple, Dict, None] = None,
    layout_merge_bboxes_mode: Union[str, Dict, None] = "large",
) -> List[Dict]:
    """Apply GLM-OCR layout post-processing to PaddleOCR detection results.

    Args:
        boxes: PaddleOCR output — list of dicts with keys:
            cls_id, label, score, coordinate ([x1, y1, x2, y2]).
        img_size: (width, height) of the image.
        layout_nms: Apply dual-threshold NMS.
        layout_unclip_ratio: Optional bbox expansion ratio.
        layout_merge_bboxes_mode: Containment mode — 'large' (default),
            'small', or a per-class dict. Any other string is accepted but
            makes the containment step a no-op.

    Returns:
        Filtered and ordered list of box dicts in the same PaddleOCR format,
        with coordinates clamped to the image and cast to int.
    """
    if not boxes:
        return boxes

    img_width, img_height = img_size

    # --- Build working array [cls_id, score, x1, y1, x2, y2] -------------- #
    arr_rows = []
    for b in boxes:
        cls_id = b.get("cls_id", 0)
        score = b.get("score", 0.0)
        x1, y1, x2, y2 = b.get("coordinate", [0, 0, 0, 0])
        arr_rows.append([cls_id, score, x1, y1, x2, y2])
    boxes_array = np.array(arr_rows, dtype=float)

    # Labels are carried alongside boxes_array and filtered with the same masks.
    all_labels: List[str] = [b.get("label", "") for b in boxes]

    # 1. NMS ---------------------------------------------------------------- #
    if layout_nms and len(boxes_array) > 1:
        kept = nms(boxes_array, iou_same=0.6, iou_diff=0.98)
        boxes_array = boxes_array[kept]
        all_labels = [all_labels[k] for k in kept]

    # 2. Filter large image regions ---------------------------------------- #
    # Drop "image" boxes covering most of the page: these are typically
    # full-page background detections, not real figures.
    if len(boxes_array) > 1:
        img_area = img_width * img_height
        # Landscape pages use a slightly lower coverage threshold than portrait.
        area_thres = 0.82 if img_width > img_height else 0.93
        keep_mask = np.ones(len(boxes_array), dtype=bool)
        for i, lbl in enumerate(all_labels):
            if lbl == "image":
                x1, y1, x2, y2 = boxes_array[i, 2:6]
                # Clamp to the page before measuring coverage.
                x1 = max(0.0, x1); y1 = max(0.0, y1)
                x2 = min(float(img_width), x2); y2 = min(float(img_height), y2)
                if (x2 - x1) * (y2 - y1) > area_thres * img_area:
                    keep_mask[i] = False
        boxes_array = boxes_array[keep_mask]
        all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    # 3. Containment analysis (merge_bboxes_mode) -------------------------- #
    if layout_merge_bboxes_mode and len(boxes_array) > 1:
        # Classes whose boxes must survive even when contained in another box.
        preserve_cls_ids = {
            int(boxes_array[i, 0])
            for i, lbl in enumerate(all_labels)
            if lbl in _PRESERVE_LABELS
        }

        if isinstance(layout_merge_bboxes_mode, str):
            mode = layout_merge_bboxes_mode
            if mode in ("large", "small"):
                contains_other, contained_by_other = check_containment(
                    boxes_array, preserve_cls_ids
                )
                if mode == "large":
                    # Keep parents; drop boxes contained in another box.
                    keep_mask = contained_by_other == 0
                else:
                    # Keep children; drop boxes that merely contain others.
                    keep_mask = (contains_other == 0) | (contained_by_other == 1)
                boxes_array = boxes_array[keep_mask]
                all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

        elif isinstance(layout_merge_bboxes_mode, dict):
            keep_mask = np.ones(len(boxes_array), dtype=bool)
            for category_index, mode in layout_merge_bboxes_mode.items():
                if mode in ("large", "small"):
                    contains_other, contained_by_other = check_containment(
                        boxes_array, preserve_cls_ids, int(category_index), mode
                    )
                    if mode == "large":
                        keep_mask &= contained_by_other == 0
                    else:
                        keep_mask &= (contains_other == 0) | (contained_by_other == 1)
            boxes_array = boxes_array[keep_mask]
            all_labels = [lbl for lbl, k in zip(all_labels, keep_mask) if k]

    if len(boxes_array) == 0:
        return []

    # 4. Unclip (bbox expansion) ------------------------------------------- #
    if layout_unclip_ratio is not None:
        boxes_array = unclip_boxes(boxes_array, layout_unclip_ratio)

    # 5. Clamp to image boundaries + skip invalid -------------------------- #
    result: List[Dict] = []
    for i, row in enumerate(boxes_array):
        cls_id = int(row[0])
        score = float(row[1])
        x1 = max(0.0, min(float(row[2]), img_width))
        y1 = max(0.0, min(float(row[3]), img_height))
        x2 = max(0.0, min(float(row[4]), img_width))
        y2 = max(0.0, min(float(row[5]), img_height))

        # Skip boxes that collapsed to zero/negative extent after clamping.
        if x1 >= x2 or y1 >= y2:
            continue

        result.append({
            "cls_id": cls_id,
            "label": all_labels[i],
            "score": score,
            "coordinate": [int(x1), int(y1), int(x2), int(y2)],
        })

    return result
|