fix: image alpha error
This commit is contained in:
@@ -35,12 +35,10 @@ async def process_image_ocr(
|
||||
)
|
||||
|
||||
try:
|
||||
# 3. Perform OCR based on layout
|
||||
ocr_result = ocr_service.recognize(image)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=503, detail=str(e))
|
||||
|
||||
# 4. Return response
|
||||
return ImageOCRResponse(
|
||||
latex=ocr_result.get("latex", ""),
|
||||
markdown=ocr_result.get("markdown", ""),
|
||||
|
||||
@@ -33,6 +33,7 @@ app = FastAPI(
|
||||
app.include_router(api_router, prefix=settings.api_prefix)
|
||||
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
|
||||
@@ -25,6 +25,38 @@ class ImageProcessor:
|
||||
"""
|
||||
self.padding_ratio = padding_ratio or settings.image_padding_ratio
|
||||
|
||||
def _convert_to_bgr(self, pil_image: Image.Image) -> np.ndarray:
|
||||
"""Convert PIL Image to BGR numpy array, handling alpha channel.
|
||||
|
||||
Args:
|
||||
pil_image: PIL Image object.
|
||||
|
||||
Returns:
|
||||
Image as numpy array in BGR format.
|
||||
"""
|
||||
# Handle RGBA images (PNG with transparency)
|
||||
if pil_image.mode == "RGBA":
|
||||
# Create white background and paste image on top
|
||||
background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha as mask
|
||||
pil_image = background
|
||||
elif pil_image.mode == "LA":
|
||||
# Grayscale with alpha
|
||||
background = Image.new("L", pil_image.size, 255)
|
||||
background.paste(pil_image, mask=pil_image.split()[1])
|
||||
pil_image = background.convert("RGB")
|
||||
elif pil_image.mode == "P":
|
||||
# Palette mode, may have transparency
|
||||
pil_image = pil_image.convert("RGBA")
|
||||
background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
background.paste(pil_image, mask=pil_image.split()[3])
|
||||
pil_image = background
|
||||
elif pil_image.mode != "RGB":
|
||||
# Convert other modes to RGB
|
||||
pil_image = pil_image.convert("RGB")
|
||||
|
||||
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
def load_image_from_url(self, url: str) -> np.ndarray:
|
||||
"""Load image from URL.
|
||||
|
||||
@@ -40,8 +72,8 @@ class ImageProcessor:
|
||||
try:
|
||||
with urlopen(url, timeout=30) as response:
|
||||
image_data = response.read()
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
pil_image = Image.open(io.BytesIO(image_data))
|
||||
return self._convert_to_bgr(pil_image)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load image from URL: {e}") from e
|
||||
|
||||
@@ -63,8 +95,8 @@ class ImageProcessor:
|
||||
base64_str = base64_str.split(",", 1)[1]
|
||||
|
||||
image_data = base64.b64decode(base64_str)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
pil_image = Image.open(io.BytesIO(image_data))
|
||||
return self._convert_to_bgr(pil_image)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decode base64 image: {e}") from e
|
||||
|
||||
|
||||
@@ -140,18 +140,39 @@ class LayoutDetector:
|
||||
|
||||
if __name__ == "__main__":
|
||||
import cv2
|
||||
from app.core.config import get_settings
|
||||
from app.services.image_processor import ImageProcessor
|
||||
|
||||
from app.services.converter import Converter
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize dependencies
|
||||
layout_detector = LayoutDetector()
|
||||
image_path = "test/timeout.png"
|
||||
|
||||
image_processor = ImageProcessor(padding_ratio=settings.image_padding_ratio)
|
||||
converter = Converter()
|
||||
|
||||
# Initialize OCR service
|
||||
ocr_service = OCRService(
|
||||
vl_server_url=settings.paddleocr_vl_url,
|
||||
layout_detector=layout_detector,
|
||||
image_processor=image_processor,
|
||||
converter=converter,
|
||||
)
|
||||
|
||||
# Load test image
|
||||
image_path = "test/complex_formula.png"
|
||||
image = cv2.imread(image_path)
|
||||
image_processor = ImageProcessor(padding_ratio=0.15)
|
||||
image = image_processor.add_padding(image)
|
||||
|
||||
# Save the padded image for debugging
|
||||
cv2.imwrite("debug_padded_image.png", image)
|
||||
|
||||
|
||||
layout_info = layout_detector.detect(image)
|
||||
print(layout_info)
|
||||
|
||||
if image is None:
|
||||
print(f"Failed to load image: {image_path}")
|
||||
else:
|
||||
print(f"Image loaded: {image.shape}")
|
||||
|
||||
# Run OCR recognition
|
||||
result = ocr_service.recognize(image)
|
||||
|
||||
print("\n=== OCR Result ===")
|
||||
print(f"Markdown:\n{result['markdown']}")
|
||||
print(f"\nLaTeX:\n{result['latex']}")
|
||||
print(f"\nMathML:\n{result['mathml']}")
|
||||
@@ -35,6 +35,7 @@ class OCRService:
|
||||
self.layout_detector = layout_detector
|
||||
self.image_processor = image_processor
|
||||
self.converter = converter
|
||||
|
||||
def _get_pipeline(self):
|
||||
"""Get or create PaddleOCR-VL pipeline.
|
||||
|
||||
@@ -127,15 +128,3 @@ class OCRService:
|
||||
return self.recognize_mixed(image)
|
||||
else:
|
||||
return self.recognize_formula(image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import cv2
|
||||
from app.services.image_processor import ImageProcessor
|
||||
from app.services.layout_detector import LayoutDetector
|
||||
image_processor = ImageProcessor(padding_ratio=0.15)
|
||||
layout_detector = LayoutDetector()
|
||||
ocr_service = OCRService(image_processor=image_processor, layout_detector=layout_detector)
|
||||
image = cv2.imread("test/image.png")
|
||||
ocr_result = ocr_service.recognize(image)
|
||||
print(ocr_result)
|
||||
Reference in New Issue
Block a user