Files
doc_processer/test_ocr_number_fix.py

295 lines
8.8 KiB
Python
Raw Normal View History

2026-02-04 16:04:18 +08:00
"""Test OCR number error fixing."""
from app.services.converter import Converter
def test_ocr_number_errors():
    """Run the OCR number fixer over sample LaTeX strings and report results.

    Every scenario's ``latex`` is passed to
    ``Converter._fix_ocr_number_errors``. The returned string must contain
    each ``expected_fixes`` substring and none of the ``should_not_have``
    substrings. Progress and per-check verdicts are printed to stdout.

    Returns:
        bool: True when every substring check of every scenario succeeded.
    """
    banner = "=" * 80
    print(banner)
    print("Testing OCR Number Error Fixes")
    print(banner)

    converter = Converter()

    # Scenarios taken from the originally reported error, plus variations.
    scenarios = [
        {
            "name": "Original error case",
            "latex": r"\gamma = 2 2. 2, c = 3 0. 4, \phi = 2 5. 4 ^ {\circ}",
            "expected_fixes": ["22.2", "30.4", "25.4"],
            "should_not_have": ["2 2", "3 0", "2 5"],
        },
        {
            "name": "Simple decimal with space",
            "latex": r"x = 3. 14",
            "expected_fixes": ["3.14"],
            "should_not_have": ["3. 14"],
        },
        {
            "name": "Multiple decimals",
            "latex": r"a = 1 2. 5, b = 9. 8 7",
            "expected_fixes": ["12.5", "9.87"],
            "should_not_have": ["1 2", "9. 8"],
        },
        {
            "name": "Large numbers with spaces",
            "latex": r"n = 1 5 0, m = 2 0 0 0",
            "expected_fixes": ["150", "2000"],
            "should_not_have": ["1 5", "2 0 0"],
        },
        {
            "name": "Don't merge across operators",
            "latex": r"2 + 3 = 5",
            "expected_fixes": ["2 + 3 = 5"],  # Should stay the same
            "should_not_have": ["23=5"],
        },
    ]

    success = True
    for index, scenario in enumerate(scenarios, 1):
        print(f"\nTest {index}: {scenario['name']}")
        print("-" * 80)
        print(f"Input: {scenario['latex']}")

        output = converter._fix_ocr_number_errors(scenario['latex'])
        print(f"Fixed: {output}")

        # Gather the verdict lines first so they are printed as one group.
        report = []
        for needle in scenario['expected_fixes']:
            if needle not in output:
                success = False
                report.append(f"✗ Missing '{needle}'")
                continue
            report.append(f"✓ Contains '{needle}'")

        for forbidden in scenario['should_not_have']:
            if forbidden in output:
                success = False
                report.append(f"✗ Still has '{forbidden}'")
                continue
            report.append(f"✓ Removed '{forbidden}'")

        for line in report:
            print(f" {line}")

    return success
def test_mathml_quality():
    """Check that OCR-damaged LaTeX converts to well-formed MathML.

    The sample LaTeX (digits split by spaces) is wrapped in ``$...$`` and
    passed to ``Converter.convert_to_formats``; the ``mathml`` attribute of
    the result is then inspected with simple substring checks. Each check is
    printed with a pass/fail marker, followed by a preview of the MathML.

    Returns:
        bool: True when every quality check passed.
    """
    print("\n" + "=" * 80)
    print("Testing MathML Quality After OCR Fix")
    print("=" * 80)

    converter = Converter()

    # The problematic LaTeX from the error
    latex = r"\gamma = 2 2. 2, c = 3 0. 4, \phi = 2 5. 4 ^ {\circ}"
    print(f"\nOriginal LaTeX: {latex}")

    # Convert to MathML
    result = converter.convert_to_formats(f"${latex}$")
    mathml = result.mathml
    print(f"\nMathML length: {len(mathml)} chars")

    # Check quality indicators
    print("\nQuality checks:")
    print("-" * 80)
    checks = {
        # A plain substring test already covers the "<mn>22.2</mn>" form.
        "No separate digits for decimals": "22.2" in mathml,
        "No dot as identifier": "<mi>.</mi>" not in mathml,
        "Properly formatted numbers": "30.4" in mathml,
        "Has namespace": 'xmlns=' in mathml,
        "Display block": 'display="block"' in mathml,
    }

    all_passed = True
    for check, passed in checks.items():
        # Distinct markers so a failed check is visible in the report.
        status = "✓" if passed else "✗"
        print(f"{status} {check}")
        if not passed:
            all_passed = False

    # Show a preview
    print("\n" + "-" * 80)
    print("MathML preview:")
    print("-" * 80)
    print(mathml[:400])
    if len(mathml) > 400:
        print("...")

    return all_passed
def test_edge_cases():
    """Exercise edge cases of ``Converter._fix_ocr_number_errors``.

    Each case carries an ``input`` plus one or more expectations:

    * ``should_stay``   - the output must equal this string exactly.
    * ``should_become`` - the output must contain this substring.
    * ``should_have``   - the output must contain this substring, or every
      substring when a list is given.

    Results are printed to stdout.

    Returns:
        bool: True when every expectation of every case was met.
    """
    print("\n" + "=" * 80)
    print("Testing Edge Cases")
    print("=" * 80)

    converter = Converter()

    test_cases = [
        {
            "name": "Should NOT merge: arithmetic",
            "input": r"2 + 3 = 5",
            "should_stay": "2 + 3 = 5",
        },
        {
            "name": "Should NOT merge: multiplication",
            "input": r"2 \times 3",
            "should_stay": r"2 \times 3",
        },
        {
            "name": "Should merge: decimal at end",
            "input": r"x = 1 2. 5",
            "should_become": "12.5",
        },
        {
            "name": "Should merge: multiple spaces",
            "input": r"n = 1 2 . 3 4",
            "should_have": "12.34",
        },
        {
            "name": "Complex: mixed scenarios",
            "input": r"a = 1 2. 3 + 4 5. 6 - 7",
            "should_have": ["12.3", "45.6", "- 7"],
        },
    ]

    all_passed = True
    for test in test_cases:
        print(f"\n{test['name']}")
        print(f" Input: {test['input']}")
        fixed = converter._fix_ocr_number_errors(test['input'])
        print(f" Output: {fixed}")

        if 'should_stay' in test:
            if fixed == test['should_stay']:
                print(" ✓ Correctly unchanged")
            else:
                print(f" ✗ Should stay '{test['should_stay']}' but got '{fixed}'")
                all_passed = False

        if 'should_become' in test:
            if test['should_become'] in fixed:
                print(f" ✓ Contains '{test['should_become']}'")
            else:
                print(f" ✗ Should contain '{test['should_become']}'")
                all_passed = False

        if 'should_have' in test:
            expected_items = test['should_have']
            # A bare string must be checked as one substring; iterating it
            # directly would test each character on its own.
            if isinstance(expected_items, str):
                expected_items = [expected_items]
            for expected in expected_items:
                if expected in fixed:
                    print(f" ✓ Contains '{expected}'")
                else:
                    print(f" ✗ Should contain '{expected}'")
                    all_passed = False

    return all_passed
def compare_before_after():
    """Compare MathML produced from OCR-damaged LaTeX and from clean LaTeX.

    Both strings are wrapped in ``$...$`` and converted with
    ``Converter.convert_to_formats``. The ``mathml`` of each result is
    checked for the merged decimal ``22.2`` and for a stray ``<mi>.</mi>``
    element. Findings for both inputs are printed; only the OCR result
    decides the return value.

    Returns:
        bool: True when the OCR result contains ``22.2`` and has no
        ``<mi>.</mi>`` element.
    """
    print("\n" + "=" * 80)
    print("Before/After Comparison")
    print("=" * 80)

    converter = Converter()

    # Simulate OCR error
    ocr_latex = r"\gamma = 2 2. 2, c = 3 0. 4"
    correct_latex = r"\gamma = 22.2, c = 30.4"
    print(f"\nOCR LaTeX: {ocr_latex}")
    print(f"Correct LaTeX: {correct_latex}")

    # Convert both
    ocr_result = converter.convert_to_formats(f"${ocr_latex}$")
    correct_result = converter.convert_to_formats(f"${correct_latex}$")

    print("\n" + "-" * 80)
    print("MathML comparison:")
    print("-" * 80)

    # Check if they produce similar quality output
    ocr_has_decimal = "22.2" in ocr_result.mathml
    correct_has_decimal = "22.2" in correct_result.mathml
    ocr_has_dot_error = "<mi>.</mi>" in ocr_result.mathml
    correct_has_dot_error = "<mi>.</mi>" in correct_result.mathml

    # For decimals "Yes" is the good outcome; for dot errors "No" is.
    print(f"OCR output has proper decimals: {'✓ Yes' if ocr_has_decimal else '✗ No'}")
    print(f"Correct output has proper decimals: {'✓ Yes' if correct_has_decimal else '✗ No'}")
    print(f"OCR output has dot errors: {'✗ Yes' if ocr_has_dot_error else '✓ No'}")
    print(f"Correct output has dot errors: {'✗ Yes' if correct_has_dot_error else '✓ No'}")

    if ocr_has_decimal and not ocr_has_dot_error:
        print("\n✓ OCR fix is working! Output quality matches correct input.")
        return True
    else:
        print("\n✗ OCR fix may need improvement.")
        return False
# Script entry point: run the four checks in order, then print a summary.
if __name__ == "__main__":
    print("OCR Number Error Fix Test Suite\n")
    try:
        test1 = test_ocr_number_errors()
        test2 = test_mathml_quality()
        test3 = test_edge_cases()
        test4 = compare_before_after()

        print("\n" + "=" * 80)
        print("SUMMARY")
        print("=" * 80)
        results = [
            ("OCR error fixes", test1),
            ("MathML quality", test2),
            ("Edge cases", test3),
            ("Before/after comparison", test4),
        ]
        for name, passed in results:
            status = "✓ PASS" if passed else "✗ FAIL"
            print(f"{status}: {name}")

        all_passed = all(r[1] for r in results)
        print("\n" + "-" * 80)
        if all_passed:
            print("✓✓✓ ALL TESTS PASSED ✓✓✓")
            print("\nOCR number errors are being fixed automatically!")
            print("Examples:")
            # Separator between the OCR input and the corrected output.
            print("'2 2. 2' → '22.2'")
            print("'3 0. 4' → '30.4'")
            print("'1 5 0' → '150'")
        else:
            print("✗✗✗ SOME TESTS FAILED ✗✗✗")
        print("=" * 80)
    except KeyboardInterrupt:
        print("\n\nTests interrupted")
    except Exception as e:
        # Top-level boundary: report the failure with a full traceback
        # instead of letting the interpreter abort mid-report.
        print(f"\n\nTest error: {e}")
        import traceback
        traceback.print_exc()