fix: image alpha error

This commit is contained in:
2026-01-01 23:38:52 +08:00
parent 35928c2484
commit 3870c108b2
5 changed files with 71 additions and 30 deletions

View File

@@ -35,12 +35,10 @@ async def process_image_ocr(
) )
try: try:
# 3. Perform OCR based on layout
ocr_result = ocr_service.recognize(image) ocr_result = ocr_service.recognize(image)
except RuntimeError as e: except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e)) raise HTTPException(status_code=503, detail=str(e))
# 4. Return response
return ImageOCRResponse( return ImageOCRResponse(
latex=ocr_result.get("latex", ""), latex=ocr_result.get("latex", ""),
markdown=ocr_result.get("markdown", ""), markdown=ocr_result.get("markdown", ""),

View File

@@ -33,6 +33,7 @@ app = FastAPI(
app.include_router(api_router, prefix=settings.api_prefix) app.include_router(api_router, prefix=settings.api_prefix)
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""Health check endpoint.""" """Health check endpoint."""

View File

@@ -25,6 +25,38 @@ class ImageProcessor:
""" """
self.padding_ratio = padding_ratio or settings.image_padding_ratio self.padding_ratio = padding_ratio or settings.image_padding_ratio
def _convert_to_bgr(self, pil_image: Image.Image) -> np.ndarray:
"""Convert PIL Image to BGR numpy array, handling alpha channel.
Args:
pil_image: PIL Image object.
Returns:
Image as numpy array in BGR format.
"""
# Handle RGBA images (PNG with transparency)
if pil_image.mode == "RGBA":
# Create white background and paste image on top
background = Image.new("RGB", pil_image.size, (255, 255, 255))
background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha as mask
pil_image = background
elif pil_image.mode == "LA":
# Grayscale with alpha
background = Image.new("L", pil_image.size, 255)
background.paste(pil_image, mask=pil_image.split()[1])
pil_image = background.convert("RGB")
elif pil_image.mode == "P":
# Palette mode, may have transparency
pil_image = pil_image.convert("RGBA")
background = Image.new("RGB", pil_image.size, (255, 255, 255))
background.paste(pil_image, mask=pil_image.split()[3])
pil_image = background
elif pil_image.mode != "RGB":
# Convert other modes to RGB
pil_image = pil_image.convert("RGB")
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
def load_image_from_url(self, url: str) -> np.ndarray: def load_image_from_url(self, url: str) -> np.ndarray:
"""Load image from URL. """Load image from URL.
@@ -40,8 +72,8 @@ class ImageProcessor:
try: try:
with urlopen(url, timeout=30) as response: with urlopen(url, timeout=30) as response:
image_data = response.read() image_data = response.read()
image = Image.open(io.BytesIO(image_data)) pil_image = Image.open(io.BytesIO(image_data))
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) return self._convert_to_bgr(pil_image)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to load image from URL: {e}") from e raise ValueError(f"Failed to load image from URL: {e}") from e
@@ -63,8 +95,8 @@ class ImageProcessor:
base64_str = base64_str.split(",", 1)[1] base64_str = base64_str.split(",", 1)[1]
image_data = base64.b64decode(base64_str) image_data = base64.b64decode(base64_str)
image = Image.open(io.BytesIO(image_data)) pil_image = Image.open(io.BytesIO(image_data))
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) return self._convert_to_bgr(pil_image)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to decode base64 image: {e}") from e raise ValueError(f"Failed to decode base64 image: {e}") from e

View File

@@ -140,18 +140,39 @@ class LayoutDetector:
if __name__ == "__main__": if __name__ == "__main__":
import cv2 import cv2
from app.core.config import get_settings
from app.services.image_processor import ImageProcessor from app.services.image_processor import ImageProcessor
from app.services.converter import Converter
from app.services.ocr_service import OCRService
settings = get_settings()
# Initialize dependencies
layout_detector = LayoutDetector() layout_detector = LayoutDetector()
image_path = "test/timeout.png" image_processor = ImageProcessor(padding_ratio=settings.image_padding_ratio)
converter = Converter()
# Initialize OCR service
ocr_service = OCRService(
vl_server_url=settings.paddleocr_vl_url,
layout_detector=layout_detector,
image_processor=image_processor,
converter=converter,
)
# Load test image
image_path = "test/complex_formula.png"
image = cv2.imread(image_path) image = cv2.imread(image_path)
image_processor = ImageProcessor(padding_ratio=0.15)
image = image_processor.add_padding(image) if image is None:
print(f"Failed to load image: {image_path}")
# Save the padded image for debugging else:
cv2.imwrite("debug_padded_image.png", image) print(f"Image loaded: {image.shape}")
# Run OCR recognition
layout_info = layout_detector.detect(image) result = ocr_service.recognize(image)
print(layout_info)
print("\n=== OCR Result ===")
print(f"Markdown:\n{result['markdown']}")
print(f"\nLaTeX:\n{result['latex']}")
print(f"\nMathML:\n{result['mathml']}")

View File

@@ -35,6 +35,7 @@ class OCRService:
self.layout_detector = layout_detector self.layout_detector = layout_detector
self.image_processor = image_processor self.image_processor = image_processor
self.converter = converter self.converter = converter
def _get_pipeline(self): def _get_pipeline(self):
"""Get or create PaddleOCR-VL pipeline. """Get or create PaddleOCR-VL pipeline.
@@ -127,15 +128,3 @@ class OCRService:
return self.recognize_mixed(image) return self.recognize_mixed(image)
else: else:
return self.recognize_formula(image) return self.recognize_formula(image)
if __name__ == "__main__":
import cv2
from app.services.image_processor import ImageProcessor
from app.services.layout_detector import LayoutDetector
image_processor = ImageProcessor(padding_ratio=0.15)
layout_detector = LayoutDetector()
ocr_service = OCRService(image_processor=image_processor, layout_detector=layout_detector)
image = cv2.imread("test/image.png")
ocr_result = ocr_service.recognize(image)
print(ocr_result)