[refactor] Init
This commit is contained in:
26
texteller/utils/__init__.py
Normal file
26
texteller/utils/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Public helper re-exports for :mod:`texteller.utils`."""

from .device import get_device, cuda_available, mps_available, str2device
from .image import readimgs, transform
from .latex import change_all, remove_style, add_newlines
from .path import mkdir, resolve_path
from .misc import lines_dedent
from .bbox import mask_img, bbox_merge, split_conflict, slice_from_image, draw_bboxes

__all__ = [
    "get_device",
    "cuda_available",
    "mps_available",
    "str2device",
    "readimgs",
    "transform",
    "change_all",
    "remove_style",
    "add_newlines",
    "mkdir",
    "resolve_path",
    "lines_dedent",
    "mask_img",
    "bbox_merge",
    "split_conflict",
    "slice_from_image",
    "draw_bboxes",
]
|
||||
142
texteller/utils/bbox.py
Normal file
142
texteller/utils/bbox.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import heapq
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from texteller.types import Bbox
|
||||
|
||||
_MAXV = 999999999
|
||||
|
||||
|
||||
def mask_img(img, bboxes: list[Bbox], bg_color: np.ndarray) -> np.ndarray:
    """Return a copy of ``img`` with every bbox region filled with ``bg_color``.

    Args:
        img: Image array indexed as ``img[y, x]``.
        bboxes: Regions to blank out; each provides ``p.x``, ``p.y``, ``w``, ``h``.
        bg_color: Fill value broadcast into each masked region.

    Returns:
        A new array; the input image is not modified.
    """
    # Renamed the local from `mask_img` — it shadowed the function itself.
    masked = img.copy()
    for bbox in bboxes:
        masked[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
    return masked
|
||||
|
||||
|
||||
def bbox_merge(sorted_bboxes: list[Bbox]) -> list[Bbox]:
    """Merge horizontally overlapping bboxes that lie on the same row.

    Args:
        sorted_bboxes: Bboxes already sorted in reading order. The list
            itself is copied, but its elements may be mutated in place
            (``prev.w`` is widened when two boxes merge).

    Returns:
        A new list in which same-row boxes with overlapping x-ranges have
        been merged into single, wider boxes.
    """
    if len(sorted_bboxes) == 0:
        return []
    bboxes = sorted_bboxes.copy()
    # Sentinel far to the right on the last row: guarantees the final real
    # box gets flushed into `res` by the loop, and is itself never appended.
    guard = Bbox(_MAXV, bboxes[-1].p.y, -1, -1, label="guard")
    bboxes.append(guard)
    res = []
    prev = bboxes[0]
    # NOTE(review): the first iteration compares `prev` with itself; for a
    # box with w > 0 this lands in the merge branch and is a no-op, but a
    # zero-width first box would be appended twice — confirm upstream boxes
    # always have positive width.
    for curr in bboxes:
        if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
            # No horizontal overlap (or a new row): keep prev as-is.
            res.append(prev)
            prev = curr
        else:
            # Overlap on the same row: widen prev so it covers curr too.
            prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
    return res
|
||||
|
||||
|
||||
def split_conflict(ocr_bboxes: list[Bbox], latex_bboxes: list[Bbox]) -> list[Bbox]:
    """Resolve horizontal overlaps between OCR text bboxes and latex bboxes.

    Sweeps all boxes in sorted order; wherever a ``"text"`` box overlaps a
    latex box on the same row, the text box is trimmed or split so the latex
    box keeps its area. Two overlapping text boxes are merged into one.

    Args:
        ocr_bboxes: Text bboxes (label ``"text"``); elements may be mutated.
        latex_bboxes: Formula bboxes; elements may be mutated.

    Returns:
        The conflict-free combined list, or ``ocr_bboxes`` unchanged when
        either input is trivially small.
    """
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes

    bboxes = sorted(ocr_bboxes + latex_bboxes)

    assert len(bboxes) > 1

    # Min-heap keeps boxes in sorted (reading) order even when split
    # remainders are pushed back in mid-sweep.
    heapq.heapify(bboxes)
    res = []
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0  # iteration counter; unused — presumably leftover debugging aid
    while len(bboxes) > 0:
        idx += 1
        assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)

        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            # No horizontal conflict: emit candidate and advance the window.
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            # Partial overlap: curr sticks out past candidate's right edge.
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                # Two overlapping text boxes: widen candidate to absorb curr.
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    # Trim the text box to end where the latex box starts.
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    # Latex box comes first: shift the text box so it starts
                    # just after the latex box, then re-insert it.
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)

        elif candidate.ur_point.x >= curr.ur_point.x:
            # curr is fully contained (in x) within candidate.
            assert not (candidate.label != "text" and curr.label != "text")

            if candidate.label == "text":
                assert curr.label != "text"
                # Split the text box around the contained latex box: push the
                # right-hand remainder back into the heap...
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None,
                    ),
                )
                # ...and keep the left-hand part, trimmed to the latex start.
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                # A text box fully inside a latex box is discarded.
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False  # unreachable: the branches above are exhaustive
    res.append(candidate)
    res.append(curr)

    return res
|
||||
|
||||
|
||||
def slice_from_image(img: np.ndarray, ocr_bboxes: list[Bbox]) -> list[np.ndarray]:
    """Crop each bbox region out of *img* and return the crops in order.

    Coordinates and sizes are truncated to ``int`` before slicing; the
    returned arrays are views into *img*.
    """
    crops = []
    for box in ocr_bboxes:
        left, top = int(box.p.x), int(box.p.y)
        right, bottom = left + int(box.w), top + int(box.h)
        crops.append(img[top:bottom, left:right])
    return crops
|
||||
|
||||
|
||||
def draw_bboxes(img: Image.Image, bboxes: list[Bbox], name="annotated_image.png"):
    """Draw *bboxes* onto *img* and save the result under ``./logs/<name>``.

    Mutates *img* in place (green rectangles plus optional blue label and
    red content text) and writes it into a ``logs`` directory inside the
    current working directory, creating that directory if needed.
    """
    out_dir = Path(os.getcwd()) / "logs"
    out_dir.mkdir(exist_ok=True)
    drawer = ImageDraw.Draw(img)

    for bbox in bboxes:
        # Rectangle corners for this bbox.
        x0, y0 = bbox.p.x, bbox.p.y
        x1, y1 = bbox.p.x + bbox.w, bbox.p.y + bbox.h

        drawer.rectangle([x0, y0, x1, y1], outline="green", width=1)

        # Optional annotations: label at the top, content at the bottom.
        if bbox.label:
            drawer.text((x0, y0), bbox.label, fill="blue")
        if bbox.content:
            # Only the first few characters, to keep the overlay readable.
            drawer.text((x0, y1 - 10), bbox.content[:10], fill="red")

    img.save(out_dir / name)
|
||||
41
texteller/utils/device.py
Normal file
41
texteller/utils/device.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def str2device(device_str: Literal["cpu", "cuda", "mps"]) -> torch.device:
    """Convert a device name into a :class:`torch.device`.

    Raises:
        ValueError: If *device_str* is not one of ``"cpu"``, ``"cuda"``,
            ``"mps"``.
    """
    if device_str not in ("cpu", "cuda", "mps"):
        raise ValueError(f"Invalid device: {device_str}")
    return torch.device(device_str)
|
||||
|
||||
|
||||
def get_device(device_index: int | None = None) -> torch.device:
    """
    Automatically detect the best available device for inference.

    Args:
        device_index: Index of the CUDA device to use when multiple GPUs
            are available. Defaults to None, which selects the current
            default CUDA device. Ignored for MPS/CPU, which are not indexed.

    Returns:
        torch.device: Selected device for model inference.
    """
    if cuda_available():
        # Honor an explicit GPU index — previously this argument was
        # accepted and documented but silently ignored.
        if device_index is not None:
            return torch.device(f"cuda:{device_index}")
        return str2device("cuda")
    elif mps_available():
        return str2device("mps")
    else:
        return str2device("cpu")
|
||||
|
||||
|
||||
def cuda_available() -> bool:
    """Return True when torch can use a CUDA device."""
    return torch.cuda.is_available()
|
||||
|
||||
|
||||
def mps_available() -> bool:
    """Return True when torch can use the Apple MPS backend."""
    return torch.backends.mps.is_available()
|
||||
121
texteller/utils/image.py
Normal file
121
texteller/utils/image.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from collections import Counter
|
||||
from typing import List, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from texteller.constants import (
|
||||
FIXED_IMG_SIZE,
|
||||
IMG_CHANNELS,
|
||||
IMAGE_MEAN,
|
||||
IMAGE_STD,
|
||||
)
|
||||
from texteller.logger import get_logger
|
||||
|
||||
|
||||
_logger = get_logger()
|
||||
|
||||
|
||||
def readimgs(image_paths: list[str]) -> list[np.ndarray]:
    """
    Read and preprocess a list of images from their file paths.

    Each image is read with ``cv2.IMREAD_UNCHANGED``, converted from 16-bit
    to 8-bit if necessary, and normalized to RGB format regardless of the
    original color space (BGR, BGRA, or grayscale).

    Args:
        image_paths (list[str]): A list of file paths to the images to be read.

    Returns:
        list[np.ndarray]: A list of NumPy arrays containing the preprocessed
            images in RGB format.

    Raises:
        ValueError: If any image could not be read. (The previous docstring
            claimed unreadable images were skipped; the code raises instead.)
    """
    processed_images = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError(f"Image at {path} could not be read.")
        if image.dtype == np.uint16:
            # Rescale the full 16-bit range down to 8 bits.
            _logger.warning(f'Converting {path} to 8-bit, image may be lossy.')
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))

        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): a 2-channel image would fall through unconverted —
        # presumably never occurs in practice; confirm against callers.
        processed_images.append(image)

    return processed_images
|
||||
|
||||
|
||||
def trim_white_border(image: np.ndarray) -> np.ndarray:
    """Crop the uniform background border from an RGB image.

    The background color is estimated as the most common of the four corner
    pixels; pixels whose difference from that color falls below a small
    threshold are treated as border and cropped away.

    Args:
        image: HxWx3 uint8 array in RGB channel order.

    Returns:
        The cropped image (a view into *image*).

    Raises:
        ValueError: If the image is not 3-channel or not uint8.
    """
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        # Fixed: was an f-string with no placeholders and broken grammar
        # ("Image should stored in uint8").
        raise ValueError("Image should be stored in uint8")

    # Estimate the background color from the four corner pixels.
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    diff = cv2.absdiff(image, bg)
    # NOTE(review): the flag is BGR2GRAY although the data is RGB; for an
    # absolute difference this only changes per-channel weighting — confirm
    # it is intentional before "fixing".
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)

    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    # Bounding box of all pixels that differ from the background.
    x, y, w, h = cv2.boundingRect(diff)

    trimmed_image = image[y : y + h, x : x + w]

    return trimmed_image
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
    """Right/bottom pad each CxHxW tensor into a ``required_size`` square."""
    padded = []
    for img in images:
        pad_right = required_size - img.shape[2]
        pad_bottom = required_size - img.shape[1]
        padded.append(v2.functional.pad(img, padding=[0, 0, pad_right, pad_bottom]))
    return padded
|
||||
|
||||
|
||||
def transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    """Convert raw images into normalized grayscale tensors for the model.

    PIL images are first converted to RGB arrays; every image is then
    border-trimmed, grayscaled, resized to fit ``FIXED_IMG_SIZE``,
    normalized, and right/bottom padded to a ``FIXED_IMG_SIZE`` square.
    """
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    pipeline = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.uint8, scale=True),
            v2.Grayscale(),
            v2.Resize(
                size=FIXED_IMG_SIZE - 1,
                interpolation=v2.InterpolationMode.BICUBIC,
                max_size=FIXED_IMG_SIZE,
                antialias=True,
            ),
            v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
            v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
        ]
    )

    # Normalize inputs to ndarrays before the torchvision pipeline.
    arrays = [
        np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
    ]
    trimmed = [trim_white_border(arr) for arr in arrays]
    tensors = [pipeline(arr) for arr in trimmed]
    return padding(tensors, FIXED_IMG_SIZE)
|
||||
128
texteller/utils/latex.py
Normal file
128
texteller/utils/latex.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import re
|
||||
|
||||
|
||||
def _change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    """Rewrite every ``old_inst + old_surr_l ... old_surr_r`` span in a string.

    Scans *input_str* for ``old_inst`` immediately followed by the opening
    delimiter ``old_surr_l``, finds the matching closing delimiter
    ``old_surr_r`` (nesting-aware, skipping backslash-escaped delimiters),
    and emits ``new_inst + new_surr_l + content + new_surr_r`` instead.
    When the instruction text itself changes, the function recurses until no
    replaceable pattern remains in the result.
    """
    result = ""
    i = 0
    n = len(input_str)

    while i < n:
        if input_str[i : i + len(old_inst)] == old_inst:
            # Found the instruction; check whether old_surr_l follows it.
            start = i + len(old_inst)
        else:
            result += input_str[i]
            i += 1
            continue

        if start < n and input_str[start] == old_surr_l:
            # Found old_inst followed by old_surr_l; now scan for the
            # matching old_surr_r, tracking nesting depth in `count`.
            count = 1
            j = start + 1
            escaped = False
            while j < n and count > 0:
                if input_str[j] == '\\' and not escaped:
                    # Backslash escapes the next character.
                    escaped = True
                    j += 1
                    continue
                if input_str[j] == old_surr_r and not escaped:
                    count -= 1
                    if count == 0:
                        break
                elif input_str[j] == old_surr_l and not escaped:
                    count += 1
                escaped = False
                j += 1

            if count == 0:
                # Balanced pair found: rewrite the whole span.
                assert j < n
                assert input_str[start] == old_surr_l
                assert input_str[j] == old_surr_r
                inner_content = input_str[start + 1 : j]
                result += new_inst + new_surr_l + inner_content + new_surr_r
                i = j + 1
                continue
            else:
                # Unbalanced delimiters: emit the rewritten opener and keep
                # scanning after it. NOTE(review): warning goes to stdout via
                # print rather than the project logger.
                assert count >= 1
                assert j == n
                print("Warning: unbalanced surrogate pair in input string")
                result += new_inst + new_surr_l
                i = start + 1
                continue
        else:
            # Instruction not followed by the opening delimiter: copy it
            # through verbatim.
            result += input_str[i:start]
            i = start

    # A changed instruction may have produced new replaceable spans (e.g.
    # when the content itself contained the pattern); recurse until stable.
    if old_inst != new_inst and (old_inst + old_surr_l) in result:
        return _change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
    else:
        return result
|
||||
|
||||
|
||||
def _find_substring_positions(string, substring):
    """Return the start index of each (non-overlapping) occurrence of *substring*."""
    matches = re.finditer(re.escape(substring), string)
    return [match.start() for match in matches]
|
||||
|
||||
|
||||
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    """Apply :func:`_change` starting at every ``old_inst + old_surr_l`` site.

    Occurrences are processed right-to-left so that earlier positions remain
    valid while the suffix of the working buffer is rewritten in place.
    """
    pos = _find_substring_positions(input_str, old_inst + old_surr_l)
    res = list(input_str)
    for p in pos[::-1]:
        # Rewrite the whole suffix that begins at this occurrence.
        res[p:] = list(
            _change(
                ''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
            )
        )
    res = ''.join(res)
    return res
|
||||
|
||||
|
||||
def remove_style(input_str: str) -> str:
    r"""Strip common LaTeX style commands, keeping their inner content.

    Each ``\bm{...}``, ``\boldsymbol{...}``, ``\textit{...}``,
    ``\textbf{...}`` and ``\mathbf{...}`` is replaced by its content padded
    with spaces; the result is stripped of surrounding whitespace.
    """
    # The original applied \textbf twice; a single pass already removes
    # every occurrence, so the duplicate call was dropped.
    for style in (r"\bm", r"\boldsymbol", r"\textit", r"\textbf", r"\mathbf"):
        input_str = change_all(input_str, style, r" ", r"{", r"}", r"", r" ")
    return input_str.strip()
|
||||
|
||||
|
||||
def add_newlines(latex_str: str) -> str:
    """
    Adds newlines to a LaTeX string based on specific patterns, ensuring no
    duplicate newlines are added around begin/end environments.
    - After \\\\ (if not already followed by newline)
    - Before/after \\begin{...} (if not already separated by a newline)
    - Before/after \\end{...} (if not already separated by a newline)

    Args:
        latex_str: The input LaTeX string.

    Returns:
        The LaTeX string with added newlines, avoiding duplicates.
    """
    # Surround \begin{...} with newlines, swallowing adjacent whitespace.
    out = re.sub(r"\s*(\\begin\{[^}]*\})\s*", r"\n\g<1>\n", latex_str)
    # Same treatment for \end{...}.
    out = re.sub(r"\s*(\\end\{[^}]*\})\s*", r"\n\g<1>\n", out)
    # Break lines after \\ unless a newline already follows.
    out = re.sub(r"\\\\(?!\n| )|\\\\ ", r"\\\\\n", out)
    # Collapse any \n\n runs the substitutions above may have produced,
    # then trim leading/trailing whitespace from the whole result.
    out = re.sub(r'\n{2,}', '\n', out)
    return out.strip()
|
||||
5
texteller/utils/misc.py
Normal file
5
texteller/utils/misc.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
def lines_dedent(s: str) -> str:
    """Remove the common leading indent from *s* and strip surrounding whitespace."""
    dedented = dedent(s)
    return dedented.strip()
|
||||
52
texteller/utils/path.py
Normal file
52
texteller/utils/path.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from texteller.logger import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
|
||||
def resolve_path(path: str | Path) -> str:
    """Expand ``~`` and resolve *path* to an absolute path string."""
    return str(Path(path).expanduser().resolve())
|
||||
|
||||
|
||||
def touch(path: str | Path) -> None:
    """Create the file at *path* if it does not already exist."""
    Path(path).touch(exist_ok=True)
|
||||
|
||||
|
||||
def mkdir(path: str | Path) -> None:
    """Create *path* as a directory, including parents; no-op if it exists."""
    Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def rmfile(path: str | Path) -> None:
    """Delete the file at *path*; raises ``FileNotFoundError`` if missing."""
    Path(path).unlink(missing_ok=False)
|
||||
|
||||
|
||||
def rmdir(path: str | Path, mode: Literal["empty", "recursive"] = "empty") -> None:
    """Remove a directory.

    Args:
        path: Path to the directory to remove.
        mode: "empty" to remove only an empty directory (``OSError`` if it
            is not empty), "recursive" to remove the directory and all of
            its contents. (The docs and error message previously said
            ``'all'``, which was never an accepted value.)

    Raises:
        ValueError: If *mode* is neither "empty" nor "recursive".
    """
    if isinstance(path, str):
        path = Path(path)

    if mode == "empty":
        path.rmdir()
        _logger.info(f"Removed empty directory: {path}")
    elif mode == "recursive":
        import shutil

        shutil.rmtree(path)
        _logger.info(f"Recursively removed directory and all contents: {path}")
    else:
        # Message kept in sync with the accepted literals above.
        raise ValueError(f"Invalid mode: {mode}. Must be 'empty' or 'recursive'")
|
||||
Reference in New Issue
Block a user