fix: remove padding from GLMOCREndToEndService and clean up ruff violations
- Drop image padding in GLMOCREndToEndService.recognize(); use the raw image directly
- Fix F821: replace undefined `padded` references with `image`
- Fix F601 duplicate dict key "≠" in converter
- Fix F841 unused `image_cls_ids` variable in layout_postprocess
- Fix E702 semicolon-separated statements in layout_postprocess
- Fix UP031: replace percent-format with f-string in logging_config
- Auto-fix 44 additional ruff violations (import order, UP035/UP045/UP006, F401, F541)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,26 +1,10 @@
|
||||
"""GLM-OCR postprocessing logic adapted for this project.
|
||||
|
||||
Ported from glm-ocr/glmocr/postprocess/result_formatter.py and
|
||||
glm-ocr/glmocr/utils/result_postprocess_utils.py.
|
||||
|
||||
Covers:
|
||||
- Repeated-content / hallucination detection
|
||||
- Per-region content cleaning and formatting (titles, bullets, formulas)
|
||||
- formula_number merging (→ \\tag{})
|
||||
- Hyphenated text-block merging (via wordfreq)
|
||||
- Missing bullet-point detection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from wordfreq import zipf_frequency
|
||||
@@ -29,13 +13,14 @@ try:
|
||||
except ImportError:
|
||||
_WORDFREQ_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# result_postprocess_utils (ported)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> Optional[str]:
|
||||
def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 10) -> str | None:
|
||||
"""Detect and truncate a consecutively-repeated pattern.
|
||||
|
||||
Returns the string with the repeat removed, or None if not found.
|
||||
@@ -49,7 +34,13 @@ def find_consecutive_repeat(s: str, min_unit_len: int = 10, min_repeats: int = 1
|
||||
return None
|
||||
|
||||
pattern = re.compile(
|
||||
r"(.{" + str(min_unit_len) + "," + str(max_unit_len) + r"}?)\1{" + str(min_repeats - 1) + ",}",
|
||||
r"(.{"
|
||||
+ str(min_unit_len)
|
||||
+ ","
|
||||
+ str(max_unit_len)
|
||||
+ r"}?)\1{"
|
||||
+ str(min_repeats - 1)
|
||||
+ ",}",
|
||||
re.DOTALL,
|
||||
)
|
||||
match = pattern.search(s)
|
||||
@@ -83,7 +74,9 @@ def clean_repeated_content(
|
||||
if count >= line_threshold and (count / total_lines) >= 0.8:
|
||||
for i, line in enumerate(lines):
|
||||
if line == common:
|
||||
consecutive = sum(1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common)
|
||||
consecutive = sum(
|
||||
1 for j in range(i, min(i + 3, len(lines))) if lines[j] == common
|
||||
)
|
||||
if consecutive >= 3:
|
||||
original_lines = content.split("\n")
|
||||
non_empty_count = 0
|
||||
@@ -106,7 +99,7 @@ def clean_formula_number(number_content: str) -> str:
|
||||
# Strip display math delimiters
|
||||
for start, end in [("$$", "$$"), (r"\[", r"\]"), ("$", "$"), (r"\(", r"\)")]:
|
||||
if s.startswith(start) and s.endswith(end) and len(s) > len(start) + len(end):
|
||||
s = s[len(start):-len(end)].strip()
|
||||
s = s[len(start) : -len(end)].strip()
|
||||
break
|
||||
# Strip CJK/ASCII parentheses
|
||||
if s.startswith("(") and s.endswith(")"):
|
||||
@@ -121,7 +114,7 @@ def clean_formula_number(number_content: str) -> str:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Label → canonical category mapping (mirrors GLM-OCR label_visualization_mapping)
|
||||
_LABEL_TO_CATEGORY: Dict[str, str] = {
|
||||
_LABEL_TO_CATEGORY: dict[str, str] = {
|
||||
# text
|
||||
"abstract": "text",
|
||||
"algorithm": "text",
|
||||
@@ -157,7 +150,7 @@ class GLMResultFormatter:
|
||||
# Public entry-point
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def process(self, regions: List[Dict[str, Any]]) -> str:
|
||||
def process(self, regions: list[dict[str, Any]]) -> str:
|
||||
"""Run the full postprocessing pipeline and return Markdown.
|
||||
|
||||
Args:
|
||||
@@ -175,7 +168,7 @@ class GLMResultFormatter:
|
||||
items = sorted(deepcopy(regions), key=lambda x: x.get("index", 0))
|
||||
|
||||
# Per-region cleaning + formatting
|
||||
processed: List[Dict] = []
|
||||
processed: list[dict] = []
|
||||
for item in items:
|
||||
item["native_label"] = item.get("native_label", item.get("label", "text"))
|
||||
item["label"] = self._map_label(item.get("label", "text"), item["native_label"])
|
||||
@@ -199,7 +192,7 @@ class GLMResultFormatter:
|
||||
processed = self._format_bullet_points(processed)
|
||||
|
||||
# Assemble Markdown
|
||||
parts: List[str] = []
|
||||
parts: list[str] = []
|
||||
for item in processed:
|
||||
content = item.get("content") or ""
|
||||
if item["label"] == "image":
|
||||
@@ -263,11 +256,15 @@ class GLMResultFormatter:
|
||||
if label == "formula":
|
||||
content = content.strip()
|
||||
for s, e in [("$$", "$$"), (r"\[", r"\]"), (r"\(", r"\)")]:
|
||||
if content.startswith(s) and content.endswith(e):
|
||||
content = content[len(s) : -len(e)].strip()
|
||||
if content.startswith(s):
|
||||
content = content[len(s) :].strip()
|
||||
if content.endswith(e):
|
||||
content = content[: -len(e)].strip()
|
||||
break
|
||||
if not content:
|
||||
logger.warning("Skipping formula region with empty content after stripping delimiters")
|
||||
logger.warning(
|
||||
"Skipping formula region with empty content after stripping delimiters"
|
||||
)
|
||||
return ""
|
||||
content = "$$\n" + content + "\n$$"
|
||||
|
||||
@@ -296,12 +293,12 @@ class GLMResultFormatter:
|
||||
# Structural merges
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _merge_formula_numbers(self, items: List[Dict]) -> List[Dict]:
|
||||
def _merge_formula_numbers(self, items: list[dict]) -> list[dict]:
|
||||
"""Merge formula_number region into adjacent formula with \\tag{}."""
|
||||
if not items:
|
||||
return items
|
||||
|
||||
merged: List[Dict] = []
|
||||
merged: list[dict] = []
|
||||
skip: set = set()
|
||||
|
||||
for i, block in enumerate(items):
|
||||
@@ -317,7 +314,9 @@ class GLMResultFormatter:
|
||||
formula_content = items[i + 1].get("content", "")
|
||||
merged_block = deepcopy(items[i + 1])
|
||||
if formula_content.endswith("\n$$"):
|
||||
merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||
merged_block["content"] = (
|
||||
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||
)
|
||||
merged.append(merged_block)
|
||||
skip.add(i + 1)
|
||||
continue # always skip the formula_number block itself
|
||||
@@ -329,7 +328,9 @@ class GLMResultFormatter:
|
||||
formula_content = block.get("content", "")
|
||||
merged_block = deepcopy(block)
|
||||
if formula_content.endswith("\n$$"):
|
||||
merged_block["content"] = formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||
merged_block["content"] = (
|
||||
formula_content[:-3] + f" \\tag{{{num_clean}}}\n$$"
|
||||
)
|
||||
merged.append(merged_block)
|
||||
skip.add(i + 1)
|
||||
continue
|
||||
@@ -340,12 +341,12 @@ class GLMResultFormatter:
|
||||
block["index"] = i
|
||||
return merged
|
||||
|
||||
def _merge_text_blocks(self, items: List[Dict]) -> List[Dict]:
|
||||
def _merge_text_blocks(self, items: list[dict]) -> list[dict]:
|
||||
"""Merge hyphenated text blocks when the combined word is valid (wordfreq)."""
|
||||
if not items or not _WORDFREQ_AVAILABLE:
|
||||
return items
|
||||
|
||||
merged: List[Dict] = []
|
||||
merged: list[dict] = []
|
||||
skip: set = set()
|
||||
|
||||
for i, block in enumerate(items):
|
||||
@@ -389,7 +390,9 @@ class GLMResultFormatter:
|
||||
block["index"] = i
|
||||
return merged
|
||||
|
||||
def _format_bullet_points(self, items: List[Dict], left_align_threshold: float = 10.0) -> List[Dict]:
|
||||
def _format_bullet_points(
|
||||
self, items: list[dict], left_align_threshold: float = 10.0
|
||||
) -> list[dict]:
|
||||
"""Add missing bullet prefix when a text block is sandwiched between two bullet items."""
|
||||
if len(items) < 3:
|
||||
return items
|
||||
|
||||
Reference in New Issue
Block a user