fix: image alpha error
This commit is contained in:
@@ -35,12 +35,10 @@ async def process_image_ocr(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 3. Perform OCR based on layout
|
|
||||||
ocr_result = ocr_service.recognize(image)
|
ocr_result = ocr_service.recognize(image)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=503, detail=str(e))
|
raise HTTPException(status_code=503, detail=str(e))
|
||||||
|
|
||||||
# 4. Return response
|
|
||||||
return ImageOCRResponse(
|
return ImageOCRResponse(
|
||||||
latex=ocr_result.get("latex", ""),
|
latex=ocr_result.get("latex", ""),
|
||||||
markdown=ocr_result.get("markdown", ""),
|
markdown=ocr_result.get("markdown", ""),
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ app = FastAPI(
|
|||||||
app.include_router(api_router, prefix=settings.api_prefix)
|
app.include_router(api_router, prefix=settings.api_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint."""
|
"""Health check endpoint."""
|
||||||
|
|||||||
@@ -25,6 +25,38 @@ class ImageProcessor:
|
|||||||
"""
|
"""
|
||||||
self.padding_ratio = padding_ratio or settings.image_padding_ratio
|
self.padding_ratio = padding_ratio or settings.image_padding_ratio
|
||||||
|
|
||||||
|
def _convert_to_bgr(self, pil_image: Image.Image) -> np.ndarray:
|
||||||
|
"""Convert PIL Image to BGR numpy array, handling alpha channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pil_image: PIL Image object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image as numpy array in BGR format.
|
||||||
|
"""
|
||||||
|
# Handle RGBA images (PNG with transparency)
|
||||||
|
if pil_image.mode == "RGBA":
|
||||||
|
# Create white background and paste image on top
|
||||||
|
background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||||
|
background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha as mask
|
||||||
|
pil_image = background
|
||||||
|
elif pil_image.mode == "LA":
|
||||||
|
# Grayscale with alpha
|
||||||
|
background = Image.new("L", pil_image.size, 255)
|
||||||
|
background.paste(pil_image, mask=pil_image.split()[1])
|
||||||
|
pil_image = background.convert("RGB")
|
||||||
|
elif pil_image.mode == "P":
|
||||||
|
# Palette mode, may have transparency
|
||||||
|
pil_image = pil_image.convert("RGBA")
|
||||||
|
background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||||
|
background.paste(pil_image, mask=pil_image.split()[3])
|
||||||
|
pil_image = background
|
||||||
|
elif pil_image.mode != "RGB":
|
||||||
|
# Convert other modes to RGB
|
||||||
|
pil_image = pil_image.convert("RGB")
|
||||||
|
|
||||||
|
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
def load_image_from_url(self, url: str) -> np.ndarray:
|
def load_image_from_url(self, url: str) -> np.ndarray:
|
||||||
"""Load image from URL.
|
"""Load image from URL.
|
||||||
|
|
||||||
@@ -40,8 +72,8 @@ class ImageProcessor:
|
|||||||
try:
|
try:
|
||||||
with urlopen(url, timeout=30) as response:
|
with urlopen(url, timeout=30) as response:
|
||||||
image_data = response.read()
|
image_data = response.read()
|
||||||
image = Image.open(io.BytesIO(image_data))
|
pil_image = Image.open(io.BytesIO(image_data))
|
||||||
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
return self._convert_to_bgr(pil_image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to load image from URL: {e}") from e
|
raise ValueError(f"Failed to load image from URL: {e}") from e
|
||||||
|
|
||||||
@@ -63,8 +95,8 @@ class ImageProcessor:
|
|||||||
base64_str = base64_str.split(",", 1)[1]
|
base64_str = base64_str.split(",", 1)[1]
|
||||||
|
|
||||||
image_data = base64.b64decode(base64_str)
|
image_data = base64.b64decode(base64_str)
|
||||||
image = Image.open(io.BytesIO(image_data))
|
pil_image = Image.open(io.BytesIO(image_data))
|
||||||
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
return self._convert_to_bgr(pil_image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to decode base64 image: {e}") from e
|
raise ValueError(f"Failed to decode base64 image: {e}") from e
|
||||||
|
|
||||||
|
|||||||
@@ -140,18 +140,39 @@ class LayoutDetector:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import cv2
|
import cv2
|
||||||
|
from app.core.config import get_settings
|
||||||
from app.services.image_processor import ImageProcessor
|
from app.services.image_processor import ImageProcessor
|
||||||
|
from app.services.converter import Converter
|
||||||
|
from app.services.ocr_service import OCRService
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Initialize dependencies
|
||||||
layout_detector = LayoutDetector()
|
layout_detector = LayoutDetector()
|
||||||
image_path = "test/timeout.png"
|
image_processor = ImageProcessor(padding_ratio=settings.image_padding_ratio)
|
||||||
|
converter = Converter()
|
||||||
|
|
||||||
|
# Initialize OCR service
|
||||||
|
ocr_service = OCRService(
|
||||||
|
vl_server_url=settings.paddleocr_vl_url,
|
||||||
|
layout_detector=layout_detector,
|
||||||
|
image_processor=image_processor,
|
||||||
|
converter=converter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load test image
|
||||||
|
image_path = "test/complex_formula.png"
|
||||||
image = cv2.imread(image_path)
|
image = cv2.imread(image_path)
|
||||||
image_processor = ImageProcessor(padding_ratio=0.15)
|
|
||||||
image = image_processor.add_padding(image)
|
if image is None:
|
||||||
|
print(f"Failed to load image: {image_path}")
|
||||||
# Save the padded image for debugging
|
else:
|
||||||
cv2.imwrite("debug_padded_image.png", image)
|
print(f"Image loaded: {image.shape}")
|
||||||
|
|
||||||
|
# Run OCR recognition
|
||||||
layout_info = layout_detector.detect(image)
|
result = ocr_service.recognize(image)
|
||||||
print(layout_info)
|
|
||||||
|
print("\n=== OCR Result ===")
|
||||||
|
print(f"Markdown:\n{result['markdown']}")
|
||||||
|
print(f"\nLaTeX:\n{result['latex']}")
|
||||||
|
print(f"\nMathML:\n{result['mathml']}")
|
||||||
@@ -35,6 +35,7 @@ class OCRService:
|
|||||||
self.layout_detector = layout_detector
|
self.layout_detector = layout_detector
|
||||||
self.image_processor = image_processor
|
self.image_processor = image_processor
|
||||||
self.converter = converter
|
self.converter = converter
|
||||||
|
|
||||||
def _get_pipeline(self):
|
def _get_pipeline(self):
|
||||||
"""Get or create PaddleOCR-VL pipeline.
|
"""Get or create PaddleOCR-VL pipeline.
|
||||||
|
|
||||||
@@ -127,15 +128,3 @@ class OCRService:
|
|||||||
return self.recognize_mixed(image)
|
return self.recognize_mixed(image)
|
||||||
else:
|
else:
|
||||||
return self.recognize_formula(image)
|
return self.recognize_formula(image)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import cv2
|
|
||||||
from app.services.image_processor import ImageProcessor
|
|
||||||
from app.services.layout_detector import LayoutDetector
|
|
||||||
image_processor = ImageProcessor(padding_ratio=0.15)
|
|
||||||
layout_detector = LayoutDetector()
|
|
||||||
ocr_service = OCRService(image_processor=image_processor, layout_detector=layout_detector)
|
|
||||||
image = cv2.imread("test/image.png")
|
|
||||||
ocr_result = ocr_service.recognize(image)
|
|
||||||
print(ocr_result)
|
|
||||||
Reference in New Issue
Block a user