From 3870c108b26f4afdb5f2232fb89400ba2d3fc9f4 Mon Sep 17 00:00:00 2001 From: yogeliu Date: Thu, 1 Jan 2026 23:38:52 +0800 Subject: [PATCH] fix: image alpha error --- app/api/v1/endpoints/image.py | 2 -- app/main.py | 1 + app/services/image_processor.py | 40 ++++++++++++++++++++++++++--- app/services/layout_detector.py | 45 ++++++++++++++++++++++++--------- app/services/ocr_service.py | 13 +--------- 5 files changed, 71 insertions(+), 30 deletions(-) diff --git a/app/api/v1/endpoints/image.py b/app/api/v1/endpoints/image.py index 635ebf7..2c2e141 100644 --- a/app/api/v1/endpoints/image.py +++ b/app/api/v1/endpoints/image.py @@ -35,12 +35,10 @@ async def process_image_ocr( ) try: - # 3. Perform OCR based on layout ocr_result = ocr_service.recognize(image) except RuntimeError as e: raise HTTPException(status_code=503, detail=str(e)) - # 4. Return response return ImageOCRResponse( latex=ocr_result.get("latex", ""), markdown=ocr_result.get("markdown", ""), diff --git a/app/main.py b/app/main.py index 88d9fe2..d879399 100644 --- a/app/main.py +++ b/app/main.py @@ -33,6 +33,7 @@ app = FastAPI( app.include_router(api_router, prefix=settings.api_prefix) + @app.get("/health") async def health_check(): """Health check endpoint.""" diff --git a/app/services/image_processor.py b/app/services/image_processor.py index d7abed1..b57dff6 100644 --- a/app/services/image_processor.py +++ b/app/services/image_processor.py @@ -25,6 +25,38 @@ class ImageProcessor: """ self.padding_ratio = padding_ratio or settings.image_padding_ratio + def _convert_to_bgr(self, pil_image: Image.Image) -> np.ndarray: + """Convert PIL Image to BGR numpy array, handling alpha channel. + + Args: + pil_image: PIL Image object. + + Returns: + Image as numpy array in BGR format. + """ + # Handle RGBA images (PNG with transparency) + if pil_image.mode == "RGBA": + # Create white background and paste image on top + background = Image.new("RGB", pil_image.size, (255, 255, 255)) + background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha as mask + pil_image = background + elif pil_image.mode == "LA": + # Grayscale with alpha + background = Image.new("L", pil_image.size, 255) + background.paste(pil_image, mask=pil_image.split()[1]) + pil_image = background.convert("RGB") + elif pil_image.mode == "P": + # Palette mode, may have transparency + pil_image = pil_image.convert("RGBA") + background = Image.new("RGB", pil_image.size, (255, 255, 255)) + background.paste(pil_image, mask=pil_image.split()[3]) + pil_image = background + elif pil_image.mode != "RGB": + # Convert other modes to RGB + pil_image = pil_image.convert("RGB") + + return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) + def load_image_from_url(self, url: str) -> np.ndarray: """Load image from URL. @@ -40,8 +72,8 @@ class ImageProcessor: try: with urlopen(url, timeout=30) as response: image_data = response.read() - image = Image.open(io.BytesIO(image_data)) - return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + pil_image = Image.open(io.BytesIO(image_data)) + return self._convert_to_bgr(pil_image) except Exception as e: raise ValueError(f"Failed to load image from URL: {e}") from e @@ -63,8 +95,8 @@ class ImageProcessor: base64_str = base64_str.split(",", 1)[1] image_data = base64.b64decode(base64_str) - image = Image.open(io.BytesIO(image_data)) - return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + pil_image = Image.open(io.BytesIO(image_data)) + return self._convert_to_bgr(pil_image) except Exception as e: raise ValueError(f"Failed to decode base64 image: {e}") from e diff --git a/app/services/layout_detector.py b/app/services/layout_detector.py index 3cd8446..8c3756b 100644 --- a/app/services/layout_detector.py +++ b/app/services/layout_detector.py @@ -140,18 +140,39 @@ class LayoutDetector: if __name__ == "__main__": import cv2 + from app.core.config import get_settings from app.services.image_processor import ImageProcessor - + from app.services.converter import Converter + from app.services.ocr_service import OCRService + + settings = get_settings() + + # Initialize dependencies layout_detector = LayoutDetector() - image_path = "test/timeout.png" - + image_processor = ImageProcessor(padding_ratio=settings.image_padding_ratio) + converter = Converter() + + # Initialize OCR service + ocr_service = OCRService( + vl_server_url=settings.paddleocr_vl_url, + layout_detector=layout_detector, + image_processor=image_processor, + converter=converter, + ) + + # Load test image + image_path = "test/complex_formula.png" image = cv2.imread(image_path) - image_processor = ImageProcessor(padding_ratio=0.15) - image = image_processor.add_padding(image) - - # Save the padded image for debugging - cv2.imwrite("debug_padded_image.png", image) - - - layout_info = layout_detector.detect(image) - print(layout_info) \ No newline at end of file + + if image is None: + print(f"Failed to load image: {image_path}") + else: + print(f"Image loaded: {image.shape}") + + # Run OCR recognition + result = ocr_service.recognize(image) + + print("\n=== OCR Result ===") + print(f"Markdown:\n{result['markdown']}") + print(f"\nLaTeX:\n{result['latex']}") + print(f"\nMathML:\n{result['mathml']}") \ No newline at end of file diff --git a/app/services/ocr_service.py b/app/services/ocr_service.py index 5b65798..741290b 100644 --- a/app/services/ocr_service.py +++ b/app/services/ocr_service.py @@ -35,6 +35,7 @@ class OCRService: self.layout_detector = layout_detector self.image_processor = image_processor self.converter = converter + def _get_pipeline(self): """Get or create PaddleOCR-VL pipeline. @@ -127,15 +128,3 @@ class OCRService: return self.recognize_mixed(image) else: return self.recognize_formula(image) - - -if __name__ == "__main__": - import cv2 - from app.services.image_processor import ImageProcessor - from app.services.layout_detector import LayoutDetector - image_processor = ImageProcessor(padding_ratio=0.15) - layout_detector = LayoutDetector() - ocr_service = OCRService(image_processor=image_processor, layout_detector=layout_detector) - image = cv2.imread("test/image.png") - ocr_result = ocr_service.recognize(image) - print(ocr_result) \ No newline at end of file