"""GLM-OCR postprocessing logic adapted for this project. Ported from glm-ocr/glmocr/postprocess/result_formatter.py and glm-ocr/glmocr/utils/result_postprocess_utils.py. Covers: - Repeated-content / hallucination detection - Per-region content cleaning and formatting (titles, bullets, formulas) - formula_number merging (→ \\tag{}) - Hyphenated text-block merging (via wordfreq) - Missing bullet-point detection """ from __future__ import annotations import logging import re import json logger = logging.getLogger(__name__) from collections import Counter from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple try: from wordfreq import zipf_frequency _WORDFREQ_AVAILABLE = True except ImportError: _WORDFREQ_AVAILABLE = False # --------------------------------------------------------------------------- # result_postprocess_utils (ported) # --------------------------------------------------------------------------- def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]: """Detect and truncate a consecutively-repeated pattern. Returns the string with the repeat removed, or None if not found. """ n = len(s) if n < min_unit_len * min_repeats: return None max_unit_len = n // min_repeats if max_unit_len < min_unit_len: return None pattern = re.compile( r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}", re.DOTALL, ) match = pattern.search(s) if match: return s[: match.start()] + match.group(1) return None def clean_repeated_content( content: str, min_len: int = 10, min_repeats: int = 10, line_threshold: int = 10, ) -> str: """Remove hallucination-style repeated content (consecutive or line-level).""" stripped = content.strip() if not stripped: return content # 1. Consecutive repeat (multi-line aware) if len(stripped) > min_len * min_repeats: result = find_consecutive_repeat(stripped, min_unit_len=min_len, min_repeats=min_repeats) if result is not None: return result # 2. Line-level repeat lines = [line.strip() for line in content.split("\n") if line.strip()] total_lines = len(lines) if total_lines >= line_threshold and lines: common, count = Counter(lines).most_common(1)[0] if count >= line_threshold and (count / total_lines) >= 0.8: for i, line in enumerate(lines): if line == common: consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common) if consecutive >= 3: original_lines = content.split("\n") non_empty_count = 0 for idx, orig_line in enumerate(original_lines): if orig_line.strip(): non_empty_count += 1 if non_empty_count == i + 1: return "\n".join(original_lines[: idx + 1]) break return content def clean_formula_number(number_content: str) -> str: """Strip delimiters from a formula number string, e.g. '(1)' → '1'. Also strips math-mode delimiters ($$, $, \\[...\\]) that vLLM may add when the region is processed with a formula prompt. """ s = number_content.strip() # Strip display math delimiters for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]: if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end): s = s[len(start):-len(end)].strip() break # Strip CJK/ASCII parentheses if s.startswith("(") and s.endswith(")"): return s[1:-1] if s.startswith("(") and s.endswith(")"): return s[1:-1] return s # --------------------------------------------------------------------------- # GLMResultFormatter # --------------------------------------------------------------------------- # Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping) _LABEL_TO_CATEGORY: Dict[str, str] = { # text "abstract": "text", "algorithm": "text", "content": "text", "doc_title": "text", "figure_title": "text", "paragraph_title": "text", "reference_content": "text", "text": "text", "vertical_text": "text", "vision_footnote": "text", "seal": "text", "formula_number": "text", # table "table": "table", # formula "display_formula": "formula", "inline_formula": "formula", # image (skip OCR) "chart": "image", "image": "image", } class GLMResultFormatter: """Port of GLM-OCR's ResultFormatter for use in our pipeline. Accepts a list of region dicts (each with label, native_label, content, bbox_2d) and returns a final Markdown string. """ # ------------------------------------------------------------------ # # Public entry-point # ------------------------------------------------------------------ # def process(self, regions: List[Dict[str, Any]]) -> str: """Run the full postprocessing pipeline and return Markdown. Args: regions: List of dicts with keys: - index (int) reading order from layout detection - label (str) mapped category: text/formula/table/figure - native_label (str) raw PP-DocLayout label (e.g. doc_title) - content (str) raw OCR output from vLLM - bbox_2d (list) [x1, y1, x2, y2] in 0-1000 normalised coords Returns: Markdown string. """ # Sort by reading order items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0)) # Per-region cleaning + formatting processed: List[Dict] = [] for item in items: item["native_label"] = item.get("native_label", item.get("label", "text")) item["label"] = self._map_label(item.get("label", "text"), item["native_label"]) item["content"] = self._format_content( item.get("content") or "", item["label"], item["native_label"], ) if not (item.get("content") or "").strip(): continue processed.append(item) # Re-index for i, item in enumerate(processed): item["index"] = i # Structural merges processed = self._merge_formula_numbers(processed) processed = self._merge_text_blocks(processed) processed = self._format_bullet_points(processed) # Assemble Markdown parts: List[str] = [] for item in processed: content = item.get("content") or "" if item["label"] == "image": parts.append(f"![](bbox={item.get('bbox_2d', [])})") elif content.strip(): parts.append(content) return "\n\n".join(parts) # ------------------------------------------------------------------ # # Label mapping # ------------------------------------------------------------------ # def _map_label(self, label: str, native_label: str) -> str: return _LABEL_TO_CATEGORY.get(native_label, _LABEL_TO_CATEGORY.get(label, "text")) # ------------------------------------------------------------------ # # Content cleaning # ------------------------------------------------------------------ # def _clean_content(self, content: str) -> str: """Remove artefacts: leading/trailing \\t, repeated punctuation, long repeats.""" if content is None: return "" content = re.sub(r"^(\\t)+", "", content).lstrip() content = re.sub(r"(\\t)+$", "", content).rstrip() content = re.sub(r"(\.)\1{2,}", r"\1\1\1", content) content = re.sub(r"(·)\1{2,}", r"\1\1\1", content) content = re.sub(r"(_)\1{2,}", r"\1\1\1", content) content = re.sub(r"(\\_)\1{2,}", r"\1\1\1", content) if len(content) >= 2048: content = clean_repeated_content(content) return content.strip() # ------------------------------------------------------------------ # # Per-region content formatting # ------------------------------------------------------------------ # def _format_content(self, content: Any, label: str, native_label: str) -> str: """Clean and format a single region's content.""" if content is None: return "" content = self._clean_content(str(content)) # Heading formatting if native_label == "doc_title": content = re.sub(r"^#+\s*", "", content) content = "# " + content elif native_label == "paragraph_title": if content.startswith("- ") or content.startswith("* "): content = content[2:].lstrip() content = re.sub(r"^#+\s*", "", content) content = "## " + content.lstrip() # Formula wrapping if label == "formula": content = content.strip() for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]: if content.startswith(s) and content.endswith(e): content = content[len(s) : -len(e)].strip() break if not content: logger.warning("Skipping formula region with empty content after stripping delimiters") return "" content = "$$\n" + content + "\n$$" # Text formatting if label == "text": if content.startswith("·") or content.startswith("•") or content.startswith("* "): content = "- " + content[1:].lstrip() match = re.match(r"^(\(|\()(\d+|[A-Za-z])(\)|\))(.*)$", content) if match: _, symbol, _, rest = match.groups() content = f"({symbol}) {rest.lstrip()}" match = re.match(r"^(\d+|[A-Za-z])(\.|\)|\))(.*)$", content) if match: symbol, sep, rest = match.groups() sep = ")" if sep == ")" else sep content = f"{symbol}{sep} {rest.lstrip()}" # Single newline → double newline content = re.sub(r"(? List[Dict]: """Merge formula_number region into adjacent formula with \\tag{}.""" if not items: return items merged: List[Dict] = [] skip: set = set() for i, block in enumerate(items): if i in skip: continue native = block.get("native_label", "") # Case 1: formula_number then formula if native == "formula_number": if i + 1 < len(items) and items[i + 1].get("label") == "formula": num_clean = clean_formula_number(block.get("content", "").strip()) formula_content = items[i + 1].get("content", "") merged_block = deepcopy(items[i + 1]) if formula_content.endswith("\n$$"): merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" merged.append(merged_block) skip.add(i + 1) continue # always skip the formula_number block itself # Case 2: formula then formula_number if block.get("label") == "formula": if i + 1 < len(items) and items[i + 1].get("native_label") == "formula_number": num_clean = clean_formula_number(items[i + 1].get("content", "").strip()) formula_content = block.get("content", "") merged_block = deepcopy(block) if formula_content.endswith("\n$$"): merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$" merged.append(merged_block) skip.add(i + 1) continue merged.append(block) for i, block in enumerate(merged): block["index"] = i return merged def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]: """Merge hyphenated text blocks when the combined word is valid (wordfreq).""" if not items or not _WORDFREQ_AVAILABLE: return items merged: List[Dict] = [] skip: set = set() for i, block in enumerate(items): if i in skip: continue if block.get("label") != "text": merged.append(block) continue content = block.get("content", "") if not isinstance(content, str) or not content.rstrip().endswith("-"): merged.append(block) continue content_stripped = content.rstrip() did_merge = False for j in range(i + 1, len(items)): if items[j].get("label") != "text": continue next_content = items[j].get("content", "") if not isinstance(next_content, str): continue next_stripped = next_content.lstrip() if next_stripped and next_stripped[0].islower(): words_before = content_stripped[:-1].split() next_words = next_stripped.split() if words_before and next_words: merged_word = words_before[-1] + next_words[0] if zipf_frequency(merged_word.lower(), "en") >= 2.5: merged_block = deepcopy(block) merged_block["content"] = content_stripped[:-1] + next_content.lstrip() merged.append(merged_block) skip.add(j) did_merge = True break if not did_merge: merged.append(block) for i, block in enumerate(merged): block["index"] = i return merged def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]: """Add missing bullet prefix when a text block is sandwiched between two bullet items.""" if len(items) < 3: return items for i in range(1, len(items) - 1): cur = items[i] prev = items[i - 1] nxt = items[i + 1] if cur.get("native_label") != "text": continue if prev.get("native_label") != "text" or nxt.get("native_label") != "text": continue cur_content = cur.get("content", "") if cur_content.startswith("- "): continue prev_content = prev.get("content", "") nxt_content = nxt.get("content", "") if not (prev_content.startswith("- ") and nxt_content.startswith("- ")): continue cur_bbox = cur.get("bbox_2d", []) prev_bbox = prev.get("bbox_2d", []) nxt_bbox = nxt.get("bbox_2d", []) if not (cur_bbox and prev_bbox and nxt_bbox): continue if ( abs(cur_bbox[0] - prev_bbox[0]) <= left_align_threshold and abs(cur_bbox[0] - nxt_bbox[0]) <= left_align_threshold ): cur["content"] = "- " + cur_content return items