Files
doc_processer/app/services/glm_postprocess.py
2026-03-10 21:45:43 +08:00

431 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import logging
import re
from collections import Counter
from copy import deepcopy
from typing import Any
try:
from wordfreq import zipf_frequency
_WORDFREQ_AVAILABLE = True
except ImportError:
_WORDFREQ_AVAILABLE = False
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# result_postprocess_utils (ported)
# ---------------------------------------------------------------------------
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None:
"""Detect and truncate a consecutively-repeated pattern.
Returns the string with the repeat removed, or None if not found.
"""
n = len(s)
if n < min_unit_len * min_repeats:
return None
max_unit_len = n // min_repeats
if max_unit_len < min_unit_len:
return None
pattern = re.compile(
r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}",
re.DOTALL,
)
match = pattern.search(s)
if match:
return s[: match.start()] + match.group(1)
return None
def clean_repeated_content(
    content: str,
    min_len: int = 10,
    min_repeats: int = 10,
    line_threshold: int = 10,
) -> str:
    """Remove hallucination-style repeated content (consecutive or line-level).

    Two heuristics are tried in order:
      1. a consecutively repeated substring (see ``find_consecutive_repeat``);
      2. one line dominating the block (>= ``line_threshold`` occurrences and
         >= 80% of all non-empty lines, with at least 3 in a row) — the text
         is truncated after the first occurrence of that line.

    Returns the (possibly truncated) content.
    """
    stripped = content.strip()
    if not stripped:
        return content

    # 1. Consecutive repeat (multi-line aware) — only worth trying when the
    # text is long enough to contain min_repeats copies of a min_len unit.
    if len(stripped) > min_len * min_repeats:
        truncated = find_consecutive_repeat(stripped, min_unit_len=min_len, min_repeats=min_repeats)
        if truncated is not None:
            return truncated

    # 2. Line-level repeat
    lines = [ln.strip() for ln in content.split("\n") if ln.strip()]
    total = len(lines)
    if total >= line_threshold and lines:
        common, count = Counter(lines).most_common(1)[0]
        if count >= line_threshold and count / total >= 0.8:
            for i, ln in enumerate(lines):
                if ln != common:
                    continue
                # Require at least 3 consecutive copies starting here.
                if lines[i : i + 3].count(common) >= 3:
                    # Map the i-th non-empty line back onto the raw content
                    # and keep everything up to (and including) that line.
                    raw = content.split("\n")
                    seen = 0
                    for idx, orig in enumerate(raw):
                        if orig.strip():
                            seen += 1
                            if seen == i + 1:
                                return "\n".join(raw[: idx + 1])
                    break
    return content
def clean_formula_number(number_content: str) -> str:
    """Strip delimiters from a formula number string, e.g. '(1)' -> '1'.

    Also strips one layer of math-mode delimiters ($$...$$, $...$, \\[...\\],
    \\(...\\)) that vLLM may add when the region is processed with a formula
    prompt, then removes one surrounding pair of ASCII or CJK fullwidth
    parentheses.
    """
    s = number_content.strip()
    # Strip display/inline math delimiters ("$$" listed before "$" so the
    # longer delimiter wins). The length check avoids stripping a bare "$$".
    for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]:
        if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end):
            s = s[len(start) : -len(end)].strip()
            break
    # Strip ASCII parentheses
    if s.startswith("(") and s.endswith(")"):
        return s[1:-1]
    # Strip CJK fullwidth parentheses (U+FF08 / U+FF09). The explicit glyphs
    # are required here; empty-string startswith/endswith would match every
    # input and wrongly chop the first and last character.
    if s.startswith("（") and s.endswith("）"):
        return s[1:-1]
    return s
# ---------------------------------------------------------------------------
# GLMResultFormatter
# ---------------------------------------------------------------------------
# Matches content that consists *entirely* of a display-math block ($$...$$ or
# \[...\]) and nothing else.
# Used to detect when a text/heading region was actually recognised as a formula by vLLM,
# so we can correct the label before heading prefixes (## …) are applied.
_PURE_DISPLAY_FORMULA_RE = re.compile(r"^\s*(?:\$\$[\s\S]+?\$\$|\\\[[\s\S]+?\\\])\s*$")

# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping).
# The category drives downstream handling in GLMResultFormatter: "text" regions
# get heading/bullet formatting, "formula" regions get $$-wrapping, and "image"
# regions are emitted as bbox placeholders instead of OCR text.
_LABEL_TO_CATEGORY: dict[str, str] = {
    # text
    "abstract": "text",
    "algorithm": "text",
    "content": "text",
    "doc_title": "text",
    "figure_title": "text",
    "paragraph_title": "text",
    "reference_content": "text",
    "text": "text",
    "vertical_text": "text",
    "vision_footnote": "text",
    "seal": "text",
    "formula_number": "text",
    # table
    "table": "table",
    # formula
    "display_formula": "formula",
    "inline_formula": "formula",
    # image (skip OCR)
    "chart": "image",
    "image": "image",
}
class GLMResultFormatter:
    """Port of GLM-OCR's ResultFormatter for use in our pipeline.

    Accepts a list of region dicts (each with label, native_label, content,
    bbox_2d) and returns a final Markdown string. The class is stateless:
    all work happens in :meth:`process`, which cleans each region, corrects
    mislabelled formulas, applies structural merges, and assembles the
    output in reading order.
    """

    # ------------------------------------------------------------------ #
    # Public entry-point
    # ------------------------------------------------------------------ #
def process(self, regions: list[dict[str, Any]]) -> str:
    """Run the full postprocessing pipeline and return Markdown.

    Args:
        regions: List of dicts with keys:
            - index (int): reading order from layout detection
            - label (str): mapped category: text/formula/table/figure
            - native_label (str): raw PP-DocLayout label (e.g. doc_title)
            - content (str): raw OCR output from vLLM
            - bbox_2d (list): [x1, y1, x2, y2] in 0-1000 normalised coords

    Returns:
        Markdown string.
    """
    # Work on copies, in reading order.
    ordered = sorted(deepcopy(regions), key=lambda r: r.get("index", 0))

    # Per-region cleaning + formatting; regions that clean down to nothing
    # are dropped.
    cleaned: list[dict] = []
    for region in ordered:
        region["native_label"] = region.get("native_label", region.get("label", "text"))
        region["label"] = self._map_label(region.get("label", "text"), region["native_label"])

        # Label correction: layout may say "text" (or a heading such as
        # "paragraph_title") while vLLM recognised the content as a formula
        # and returned $$…$$. Without correction the heading prefix (##)
        # would be prepended to the math block, producing broken output.
        raw = (region.get("content") or "").strip()
        if region["label"] == "text" and _PURE_DISPLAY_FORMULA_RE.match(raw):
            logger.debug(
                "Label corrected text (native=%s) → formula: pure display-formula detected",
                region["native_label"],
            )
            region["label"] = "formula"
            region["native_label"] = "display_formula"

        region["content"] = self._format_content(
            region.get("content") or "",
            region["label"],
            region["native_label"],
        )
        if (region.get("content") or "").strip():
            cleaned.append(region)

    # Re-index after dropping empty regions.
    for pos, region in enumerate(cleaned):
        region["index"] = pos

    # Structural merges.
    cleaned = self._merge_formula_numbers(cleaned)
    cleaned = self._merge_text_blocks(cleaned)
    cleaned = self._format_bullet_points(cleaned)

    # Assemble Markdown: images become bbox placeholders, everything else
    # contributes its formatted content.
    chunks: list[str] = []
    for region in cleaned:
        body = region.get("content") or ""
        if region["label"] == "image":
            chunks.append(f"![](bbox={region.get('bbox_2d', [])})")
        elif body.strip():
            chunks.append(body)
    return "\n\n".join(chunks)
# ------------------------------------------------------------------ #
# Label mapping
# ------------------------------------------------------------------ #
def _map_label(self, label: str, native_label: str) -> str:
    """Map a region's labels to a canonical category (text/table/formula/image).

    The fine-grained native label takes precedence over the coarse label;
    anything unknown falls back to "text".
    """
    category = _LABEL_TO_CATEGORY.get(native_label)
    if category is None:
        category = _LABEL_TO_CATEGORY.get(label, "text")
    return category
# ------------------------------------------------------------------ #
# Content cleaning
# ------------------------------------------------------------------ #
def _clean_content(self, content: str) -> str:
"""Remove artefacts: leading/trailing \\t, repeated punctuation, long repeats."""
if content is None:
return ""
content = re.sub(r"^(\\t)+", "", content).lstrip()
content = re.sub(r"(\\t)+$", "", content).rstrip()
content = re.sub(r"(\.)\1{2,}", r"\1\1\1", content)
content = re.sub(r"(·)\1{2,}", r"\1\1\1", content)
content = re.sub(r"(_)\1{2,}", r"\1\1\1", content)
content = re.sub(r"(\\_)\1{2,}", r"\1\1\1", content)
if len(content) >= 2048:
content = clean_repeated_content(content)
return content.strip()
# ------------------------------------------------------------------ #
# Per-region content formatting
# ------------------------------------------------------------------ #
def _format_content(self, content: Any, label: str, native_label: str) -> str:
"""Clean and format a single region's content."""
if content is None:
return ""
content = self._clean_content(str(content))
# Heading formatting
if native_label == "doc_title":
content = re.sub(r"^#+\s*", "", content)
content = "# " + content
elif native_label == "paragraph_title":
if content.startswith("- ") or content.startswith("* "):
content = content[2:].lstrip()
content = re.sub(r"^#+\s*", "", content)
content = "## " + content.lstrip()
# Formula wrapping
if label == "formula":
content = content.strip()
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
if content.startswith(s):
content = content[len(s) :].strip()
if content.endswith(e):
content = content[: -len(e)].strip()
break
if not content:
logger.warning("Skipping formula region with empty content after stripping delimiters")
return ""
content = "$$\n" + content + "\n$$"
# Text formatting
if label == "text":
if content.startswith("·") or content.startswith("") or content.startswith("* "):
content = "- " + content[1:].lstrip()
match = re.match(r"^(\(|\)(\d+|[A-Za-z])(\)|\)(.*)$", content)
if match:
_, symbol, _, rest = match.groups()
content = f"({symbol}) {rest.lstrip()}"
match = re.match(r"^(\d+|[A-Za-z])(\.|\)|\)(.*)$", content)
if match:
symbol, sep, rest = match.groups()
sep = ")" if sep == "" else sep
content = f"{symbol}{sep} {rest.lstrip()}"
# Single newline → double newline
content = re.sub(r"(?<!\n)\n(?!\n)", "\n\n", content)
return content
# ------------------------------------------------------------------ #
# Structural merges
# ------------------------------------------------------------------ #
def _merge_formula_numbers(self, items: list[dict]) -> list[dict]:
    """Merge a formula_number region into the adjacent formula as \\tag{}.

    Handles both orderings (number-then-formula and formula-then-number).
    A formula_number with no adjacent formula is dropped entirely; when the
    formula body does not end with "\\n$$" the number is silently discarded
    and the formula kept unchanged. "index" values are renumbered.
    """
    if not items:
        return items
    result: list[dict] = []
    consumed: set = set()
    for pos, region in enumerate(items):
        if pos in consumed:
            continue
        follower = items[pos + 1] if pos + 1 < len(items) else None
        # Case 1: formula_number then formula
        if region.get("native_label", "") == "formula_number":
            if follower is not None and follower.get("label") == "formula":
                tag = clean_formula_number(region.get("content", "").strip())
                body = follower.get("content", "")
                combined = deepcopy(follower)
                if body.endswith("\n$$"):
                    combined["content"] = body[:-3] + f" \\tag{{{tag}}}\n$$"
                result.append(combined)
                consumed.add(pos + 1)
            # The formula_number block itself is never emitted on its own.
            continue
        # Case 2: formula then formula_number
        if (
            region.get("label") == "formula"
            and follower is not None
            and follower.get("native_label") == "formula_number"
        ):
            tag = clean_formula_number(follower.get("content", "").strip())
            body = region.get("content", "")
            combined = deepcopy(region)
            if body.endswith("\n$$"):
                combined["content"] = body[:-3] + f" \\tag{{{tag}}}\n$$"
            result.append(combined)
            consumed.add(pos + 1)
            continue
        result.append(region)
    # Renumber reading order after merges.
    for pos, region in enumerate(result):
        region["index"] = pos
    return result
def _merge_text_blocks(self, items: list[dict]) -> list[dict]:
    """Merge hyphenated text blocks when the combined word is valid (wordfreq).

    A text block whose content ends with "-" (line-break hyphenation) is
    joined with a later text block starting with a lowercase letter, but
    only when the de-hyphenated word (last word of this block + first word
    of the next) has a wordfreq Zipf frequency >= 2.5, i.e. looks like a
    genuine English word. No-op when wordfreq is not installed.

    Args:
        items: Region dicts in reading order.

    Returns:
        New list with merged blocks; "index" values renumbered.
    """
    if not items or not _WORDFREQ_AVAILABLE:
        return items
    merged: list[dict] = []
    skip: set = set()  # indices of blocks already consumed by a merge
    for i, block in enumerate(items):
        if i in skip:
            continue
        # Only plain text participates in hyphenation merging.
        if block.get("label") != "text":
            merged.append(block)
            continue
        content = block.get("content", "")
        if not isinstance(content, str) or not content.rstrip().endswith("-"):
            merged.append(block)
            continue
        content_stripped = content.rstrip()
        did_merge = False
        # Scan forward for a text continuation; non-text blocks (figures,
        # tables) in between are skipped over, not consumed.
        for j in range(i + 1, len(items)):
            if items[j].get("label") != "text":
                continue
            next_content = items[j].get("content", "")
            if not isinstance(next_content, str):
                continue
            next_stripped = next_content.lstrip()
            # A hyphenated-word continuation starts lowercase.
            if next_stripped and next_stripped[0].islower():
                words_before = content_stripped[:-1].split()
                next_words = next_stripped.split()
                if words_before and next_words:
                    merged_word = words_before[-1] + next_words[0]
                    # Zipf >= 2.5 on wordfreq's scale — treat as a real word.
                    if zipf_frequency(merged_word.lower(), "en") >= 2.5:
                        merged_block = deepcopy(block)
                        # Drop the trailing hyphen and splice the two blocks.
                        merged_block["content"] = content_stripped[:-1] + next_content.lstrip()
                        merged.append(merged_block)
                        skip.add(j)
                        did_merge = True
                        break
        if not did_merge:
            merged.append(block)
    # Renumber reading order after merges.
    for i, block in enumerate(merged):
        block["index"] = i
    return merged
def _format_bullet_points(self, items: list[dict], left_align_threshold: float = 10.0) -> list[dict]:
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
if len(items) < 3:
return items
for i in range(1, len(items) - 1):
cur = items[i]
prev = items[i - 1]
nxt = items[i + 1]
if cur.get("native_label") != "text":
continue
if prev.get("native_label") != "text" or nxt.get("native_label") != "text":
continue
cur_content = cur.get("content", "")
if cur_content.startswith("- "):
continue
prev_content = prev.get("content", "")
nxt_content = nxt.get("content", "")
if not (prev_content.startswith("- ") and nxt_content.startswith("- ")):
continue
cur_bbox = cur.get("bbox_2d", [])
prev_bbox = prev.get("bbox_2d", [])
nxt_bbox = nxt.get("bbox_2d", [])
if not (cur_bbox and prev_bbox and nxt_bbox):
continue
if abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold:
cur["content"] = "- " + cur_content
return items