266 lines
7.9 KiB
Python
266 lines
7.9 KiB
Python
"""Test OCR number error fixing in the complete pipeline."""
|
|
|
|
from app.services.ocr_service import _postprocess_markdown
|
|
|
|
|
|
def test_ocr_postprocessing():
    """Feed simulated OCR output through ``_postprocess_markdown`` and report.

    Every case is a dict with a ``name`` and an ``input`` string, plus any of:

    * ``should_have`` -- substrings that must appear in the output (the
      repaired numbers, e.g. ``"22.2"``);
    * ``should_not_have`` -- substrings that must no longer appear (the
      space-split digits, e.g. ``"2 2"``);
    * ``should_stay`` -- when truthy, the output must equal the input exactly.

    A ✓/✗ line is printed for every individual check.

    Returns:
        bool: True only if every check of every case passed.

    NOTE(review): the outcome is returned rather than asserted, so a pytest
    run will not fail this test when it returns False.
    """
    banner = "=" * 80
    print(banner)
    print("Testing OCR Postprocessing Pipeline")
    print(banner)

    # Simulated OCR output exhibiting the space-split digit errors.
    cases = [
        {
            "name": "Inline formula with decimal errors",
            "input": r"The value is $\gamma = 2 2. 2$ and $c = 3 0. 4$.",
            "should_have": ["22.2", "30.4"],
            "should_not_have": ["2 2", "3 0"],
        },
        {
            "name": "Display formula with decimal errors",
            "input": r"$$\phi = 2 5. 4 ^ {\circ}$$",
            "should_have": ["25.4"],
            "should_not_have": ["2 5"],
        },
        {
            "name": "Multiple formulas",
            "input": r"$a = 1 2. 5$, $b = 9. 8 7$, and $c = 1 5 0$",
            "should_have": ["12.5", "9.87", "150"],
            "should_not_have": ["1 2", "9. 8", "1 5"],
        },
        {
            "name": "Mixed content (text + formulas)",
            "input": r"The equation $x = 3. 14$ is approximately pi. Then $y = 2 7. 3$.",
            "should_have": ["3.14", "27.3"],
            "should_not_have": ["3. 14", "2 7"],
        },
        {
            "name": "Normal arithmetic (should not be affected)",
            "input": r"$2 + 3 = 5$ and $10 - 7 = 3$",
            "should_stay": True,
        },
    ]

    def check(ok, ok_msg, bad_msg):
        """Print the ✓/✗ line for one check; return 1 on failure, else 0."""
        print(ok_msg if ok else bad_msg)
        return 0 if ok else 1

    failures = 0
    for number, case in enumerate(cases, start=1):
        source = case['input']
        print(f"\nTest {number}: {case['name']}")
        print("-" * 80)
        print(f"Input: {source}")

        result = _postprocess_markdown(source)
        print(f"Output: {result}")

        for needle in case.get('should_have', ()):
            failures += check(
                needle in result,
                f" ✓ Contains '{needle}'",
                f" ✗ Missing '{needle}'",
            )

        for leftover in case.get('should_not_have', ()):
            failures += check(
                leftover not in result,
                f" ✓ Removed '{leftover}'",
                f" ✗ Still has '{leftover}'",
            )

        # Cases flagged should_stay demand a byte-for-byte identical result.
        if case.get('should_stay'):
            failures += check(
                result == source,
                " ✓ Correctly unchanged",
                " ✗ Should not change but did",
            )

    return failures == 0
|
|
|
|
|
|
def test_real_world_case():
    """Reproduce the exact formula from the original error report.

    The reported OCR output is passed through ``_postprocess_markdown``. The
    result is checked for the three joined decimals (``22.2``, ``30.4``,
    ``25.4``) and for the absence of their space-split forms (``2 2``,
    ``3 0``, ``2 5``).

    Returns:
        bool: True if every quality check passed.

    NOTE(review): the outcome is returned rather than asserted, so a pytest
    run will not fail this test when it returns False.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("Testing Real-World Error Case")
    print(rule)

    # Verbatim OCR output taken from the error report.
    reported = r"$$\gamma = 2 2. 2, c = 3 0. 4, \phi = 2 5. 4 ^ {\circ}$$"

    print("\nOCR Output (with errors):")
    print(f" {reported}")

    repaired = _postprocess_markdown(reported)

    print("\nAfter Postprocessing:")
    print(f" {repaired}")

    # Label -> outcome; dict insertion order is the order the report prints.
    checks = {
        "Has 22.2": "22.2" in repaired,
        "Has 30.4": "30.4" in repaired,
        "Has 25.4": "25.4" in repaired,
        "No '2 2'": "2 2" not in repaired,
        "No '3 0'": "3 0" not in repaired,
        "No '2 5'": "2 5" not in repaired,
    }

    print("\nQuality Checks:")
    print("-" * 80)
    for label, ok in checks.items():
        mark = "✓" if ok else "✗"
        print(f"{mark} {label}")

    success = all(checks.values())
    print(
        "\n✓ Real-world case fixed successfully!"
        if success
        else "\n✗ Real-world case still has issues"
    )
    return success
|
|
|
|
|
|
def test_edge_cases():
    """Check that well-formed formulas survive postprocessing structurally intact.

    Each case flagged ``should_stay`` passes when the output equals the
    input after all ASCII spaces are removed from both. The postprocessor may
    re-space a formula, but it must not add, drop, or reorder any other
    character.

    Returns:
        bool: True if every case preserved its structure.

    NOTE(review): the outcome is returned rather than asserted, so a pytest
    run will not fail this test when it returns False.
    """
    print("\n" + "=" * 80)
    print("Testing Edge Cases")
    print("=" * 80)

    cases = [
        {
            "name": "Arithmetic operations",
            "input": r"$2 + 3 = 5$ and $10 - 7 = 3$",
            "should_stay": True,
        },
        {
            "name": "Multiplication",
            "input": r"$2 \times 3 = 6$",
            "should_stay": True,
        },
        {
            "name": "Exponents",
            "input": r"$x ^ 2 + y ^ 2 = r ^ 2$",
            "should_stay": True,
        },
        {
            "name": "Fractions",
            "input": r"$\frac{1}{2} + \frac{3}{4}$",
            "should_stay": True,
        },
        {
            "name": "Subscripts",
            "input": r"$x _ 1 + x _ 2$",
            "should_stay": True,
        },
    ]

    intact = True
    for case in cases:
        source = case['input']
        print(f"\n{case['name']}")
        print(f" Input: {source}")

        produced = _postprocess_markdown(source)
        print(f" Output: {produced}")

        if not case.get('should_stay'):
            continue

        # Whitespace changes are tolerated; only ' ' is ignored, not tabs/newlines.
        if produced.replace(" ", "") != source.replace(" ", ""):
            print(" ✗ Structure changed unexpectedly")
            intact = False
        else:
            print(" ✓ Structure preserved")

    return intact
|
|
|
|
|
|
def test_performance():
    """Time one ``_postprocess_markdown`` pass over a large synthetic document.

    The document has 100 lines. Each line holds two inline formulas with
    space-split decimals, e.g. ``$x = 7 7. 7$ and $y = 14 14. 14$``. The pass
    must finish in under one second.

    Returns:
        bool: True if processing took less than one second.

    NOTE(review): the outcome is returned rather than asserted, so a pytest
    run will not fail this test when it returns False.
    """
    import time

    line_count = 100
    # Each generated line carries two inline formulas ($x = ...$ and $y = ...$).
    formula_count = 2 * line_count

    print("\n" + "=" * 80)
    print("Testing Performance")
    print("=" * 80)

    # Build with a single join instead of repeated str += (quadratic in general).
    large_content = "".join(
        f"Formula {i}: $x = {i} {i}. {i}$ and $y = {i*2} {i*2}. {i*2}$\n"
        for i in range(line_count)
    )

    print(f"\nContent size: {len(large_content)} characters")
    print(f"Number of formulas: {formula_count}")

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    _postprocess_markdown(large_content)
    elapsed = time.perf_counter() - start

    print(f"Processing time: {elapsed*1000:.2f}ms")

    if elapsed < 1.0:
        print("✓ Performance is acceptable (< 1s)")
        return True

    print("✗ Performance may need optimization")
    return False
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every check, print a summary, and report the
    # overall result through the process exit status so CI can detect it.
    import sys
    import traceback

    print("OCR Pipeline Integration Test Suite\n")

    try:
        test1 = test_ocr_postprocessing()
        test2 = test_real_world_case()
        test3 = test_edge_cases()
        test4 = test_performance()

        print("\n" + "=" * 80)
        print("SUMMARY")
        print("=" * 80)

        results = [
            ("OCR postprocessing", test1),
            ("Real-world case", test2),
            ("Edge cases", test3),
            ("Performance", test4),
        ]

        for name, passed in results:
            status = "✓ PASS" if passed else "✗ FAIL"
            print(f"{status}: {name}")

        all_passed = all(passed for _, passed in results)

        print("\n" + "-" * 80)

        if all_passed:
            print("✓✓✓ ALL TESTS PASSED ✓✓✓")
            print("\nOCR number error fixing is integrated into the pipeline!")
            print("\nFlow:")
            print(" 1. OCR recognizes image → produces Markdown with LaTeX")
            print(" 2. _postprocess_markdown() fixes number errors")
            print(" 3. Clean LaTeX is used for all conversions")
            print("\nBenefits:")
            print(" • Fixed once at the source")
            print(" • All output formats benefit (MathML, MML, OMML)")
            print(" • Better performance (no repeated fixes)")
        else:
            print("✗✗✗ SOME TESTS FAILED ✗✗✗")

        print("=" * 80)

    except KeyboardInterrupt:
        print("\n\nTests interrupted")
        # 130 = 128 + SIGINT, the conventional shell status for Ctrl-C.
        sys.exit(130)
    except Exception as e:
        # Still show the full traceback, but no longer exit with status 0.
        print(f"\n\nTest error: {e}")
        traceback.print_exc()
        sys.exit(1)

    # Previously the script always exited 0, so failures were invisible to CI.
    sys.exit(0 if all_passed else 1)
|