[refactor] Init

This commit is contained in:
OleehyO
2025-04-16 14:23:02 +00:00
parent 0e32f3f3bf
commit 06edd104e2
101 changed files with 1854 additions and 2758 deletions

View File

@@ -0,0 +1,26 @@
from .device import get_device, cuda_available, mps_available, str2device
from .image import readimgs, transform
from .latex import change_all, remove_style, add_newlines
from .path import mkdir, resolve_path
from .misc import lines_dedent
from .bbox import mask_img, bbox_merge, split_conflict, slice_from_image, draw_bboxes
# Public API of `texteller.utils`; keep this list in sync with the imports above.
__all__ = [
    "get_device",
    "cuda_available",
    "mps_available",
    "str2device",
    "readimgs",
    "transform",
    "change_all",
    "remove_style",
    "add_newlines",
    "mkdir",
    "resolve_path",
    "lines_dedent",
    "mask_img",
    "bbox_merge",
    "split_conflict",
    "slice_from_image",
    "draw_bboxes",
]

142
texteller/utils/bbox.py Normal file
View File

@@ -0,0 +1,142 @@
import heapq
import os
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from texteller.types import Bbox
_MAXV = 999999999
def mask_img(img: np.ndarray, bboxes: list[Bbox], bg_color: np.ndarray) -> np.ndarray:
    """Return a copy of `img` with every bbox region filled with `bg_color`.

    Args:
        img: Source image as an (H, W, C) array; it is not modified.
        bboxes: Regions to blank out; each has origin `p` (x, y), width `w`,
            and height `h` in pixel coordinates.
        bg_color: Fill value broadcastable over the masked region (e.g. an RGB triple).

    Returns:
        A new array of the same shape with the bbox regions overwritten.
    """
    # Renamed local: the original variable shadowed the function name `mask_img`.
    masked = img.copy()
    for bbox in bboxes:
        masked[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
    return masked
def bbox_merge(sorted_bboxes: list[Bbox]) -> list[Bbox]:
    """Fuse horizontally overlapping boxes that lie on the same row.

    Expects `sorted_bboxes` in reading order. When consecutive boxes on one
    row overlap, the earlier box is widened in place to cover both; note that
    the list is only shallow-copied, so input elements may be mutated.
    """
    if not sorted_bboxes:
        return []
    pending = sorted_bboxes.copy()
    # Sentinel far to the right guarantees the last real box gets emitted.
    pending.append(Bbox(_MAXV, pending[-1].p.y, -1, -1, label="guard"))
    merged = []
    acc = pending[0]
    for box in pending:
        if acc.ur_point.x <= box.p.x or not acc.same_row(box):
            merged.append(acc)
            acc = box
        else:
            acc.w = max(acc.w, box.ur_point.x - acc.p.x)
    return merged
def split_conflict(ocr_bboxes: list[Bbox], latex_bboxes: list[Bbox]) -> list[Bbox]:
    """Resolve horizontal overlaps between OCR text boxes and LaTeX boxes.

    Text boxes ("text" label) that collide with a formula box on the same row
    are trimmed or split so the formula box keeps its extent; overlapping text
    boxes are merged with each other. Boxes are processed left-to-right per
    row via a min-heap. NOTE(review): boxes from the input lists may be
    mutated in place (width/origin adjusted) — confirm callers don't reuse them.

    Args:
        ocr_bboxes: Text boxes from the OCR detector.
        latex_bboxes: Formula boxes; if empty, `ocr_bboxes` is returned as-is.

    Returns:
        The combined, conflict-free list of boxes.
    """
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes
    bboxes = sorted(ocr_bboxes + latex_bboxes)
    assert len(bboxes) > 1
    heapq.heapify(bboxes)
    res = []
    # `candidate` is the current left-most unresolved box; `curr` is the next one.
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0  # iteration counter; currently unused beyond bookkeeping
    while len(bboxes) > 0:
        idx += 1
        assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            # No overlap (or different rows): emit candidate and advance.
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            # Partial overlap: curr extends past candidate's right edge.
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                # Two text boxes overlap: merge curr into candidate.
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    # Text vs formula: clip the text box at the formula's left edge.
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    # Formula vs text: shift the text box to start after the formula.
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x >= curr.ur_point.x:
            # curr is fully contained in candidate.
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text":
                assert curr.label != "text"
                # Formula inside a text box: split the text box around it,
                # pushing the right-hand remainder back onto the heap.
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None,
                    ),
                )
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                # Text box fully inside a formula box: drop the text box.
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False
    res.append(candidate)
    res.append(curr)
    return res
def slice_from_image(img: np.ndarray, ocr_bboxes: list[Bbox]) -> list[np.ndarray]:
    """Crop each bbox region out of `img`, returning one sub-array per box."""

    def crop(bbox: Bbox) -> np.ndarray:
        # One-line purpose: slice the bbox rectangle out of the source image.
        x0, y0 = int(bbox.p.x), int(bbox.p.y)
        return img[y0 : y0 + int(bbox.h), x0 : x0 + int(bbox.w)]

    return [crop(box) for box in ocr_bboxes]
def draw_bboxes(img: Image.Image, bboxes: list[Bbox], name="annotated_image.png"):
    """Annotate `img` with one green rectangle per bbox and save it.

    Labels (blue) and the first 10 characters of content (red) are drawn when
    present. The image is mutated in place and written to ./logs/<name> under
    the current working directory, creating the directory if needed.
    """
    out_dir = Path(os.getcwd()) / "logs"
    out_dir.mkdir(exist_ok=True)
    pen = ImageDraw.Draw(img)
    for box in bboxes:
        # Rectangle corners in pixel coordinates.
        x0, y0 = box.p.x, box.p.y
        x1, y1 = x0 + box.w, y0 + box.h
        pen.rectangle([x0, y0, x1, y1], outline="green", width=1)
        if box.label:
            pen.text((x0, y0), box.label, fill="blue")
        if box.content:
            pen.text((x0, y1 - 10), box.content[:10], fill="red")
    img.save(out_dir / name)

41
texteller/utils/device.py Normal file
View File

@@ -0,0 +1,41 @@
from typing import Literal
import torch
def str2device(device_str: Literal["cpu", "cuda", "mps"]) -> torch.device:
    """Convert a device-name string into the corresponding ``torch.device``.

    Raises:
        ValueError: If ``device_str`` is not one of "cpu", "cuda" or "mps".
    """
    if device_str not in ("cpu", "cuda", "mps"):
        raise ValueError(f"Invalid device: {device_str}")
    return torch.device(device_str)
def get_device(device_index: int | None = None) -> torch.device:
    """
    Automatically detect the best available device for inference.

    Preference order: CUDA, then Apple MPS, then CPU.

    Args:
        device_index: The index of the CUDA device to use if multiple are
            available. Defaults to None, which selects the default (first)
            CUDA device. Ignored when CUDA is unavailable.

    Returns:
        torch.device: Selected device for model inference.
    """
    if cuda_available():
        # Fixed: device_index was previously accepted but silently ignored.
        if device_index is not None:
            return torch.device(f"cuda:{device_index}")
        return str2device("cuda")
    elif mps_available():
        return str2device("mps")
    else:
        return str2device("cpu")
def cuda_available() -> bool:
    """Return True when a CUDA-capable GPU is usable by this torch build."""
    return bool(torch.cuda.is_available())
def mps_available() -> bool:
    """Return True when the Apple MPS backend is usable by this torch build."""
    return bool(torch.backends.mps.is_available())

121
texteller/utils/image.py Normal file
View File

@@ -0,0 +1,121 @@
from collections import Counter
from typing import List, Union
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2
from texteller.constants import (
FIXED_IMG_SIZE,
IMG_CHANNELS,
IMAGE_MEAN,
IMAGE_STD,
)
from texteller.logger import get_logger
_logger = get_logger()
def readimgs(image_paths: list[str]) -> list[np.ndarray]:
    """
    Read and preprocess a list of images from their file paths.

    Each image is loaded with OpenCV, converted from 16-bit to 8-bit when
    necessary (logging a lossy-conversion warning), and normalized to RGB
    regardless of the original color layout (BGR, BGRA, or grayscale).

    Args:
        image_paths (list[str]): A list of file paths to the images to be read.

    Returns:
        list[np.ndarray]: The preprocessed images in RGB format, in the same
        order as `image_paths`.

    Raises:
        ValueError: If any path cannot be read as an image.
        (Note: the previous docstring claimed unreadable images were skipped;
        the code raises instead.)
    """
    processed_images = []
    for path in image_paths:
        # IMREAD_UNCHANGED preserves alpha channels and 16-bit depth so they
        # can be normalized explicitly below.
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError(f"Image at {path} could not be read.")
        if image.dtype == np.uint16:
            _logger.warning(f'Converting {path} to 8-bit, image may be lossy.')
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
        # A 2-D array is a single-channel (grayscale) image.
        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): other channel counts (e.g. 2) pass through unconverted.
        processed_images.append(image)
    return processed_images
def trim_white_border(image: np.ndarray) -> np.ndarray:
    """Crop uniform background margins from an RGB image.

    The background color is estimated as the most common of the four corner
    pixels; pixels whose difference from it exceeds a fixed threshold are
    treated as foreground, and the image is cropped to the foreground's
    bounding rectangle.

    Args:
        image: RGB image stored as an (H, W, 3) uint8 array.

    Returns:
        np.ndarray: The cropped image (a view into `image`).

    Raises:
        ValueError: If the image is not 3-channel or not uint8.
    """
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")
    if image.dtype != np.uint8:
        # Fixed: message was an f-string with no placeholders and ungrammatical.
        raise ValueError("Image should be stored in uint8")
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)
    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
    diff = cv2.absdiff(image, bg)
    # Channel order is irrelevant here: we only need a scalar per-pixel
    # difference, so BGR2GRAY applied to RGB data is acceptable.
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
    x, y, w, h = cv2.boundingRect(diff)
    trimmed_image = image[y : y + h, x : x + w]
    return trimmed_image
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
    """Zero-pad each CHW tensor on the right/bottom to required_size x required_size."""
    padded = []
    for tensor in images:
        pad_right = required_size - tensor.shape[2]
        pad_bottom = required_size - tensor.shape[1]
        padded.append(v2.functional.pad(tensor, padding=[0, 0, pad_right, pad_bottom]))
    return padded
def transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    """Preprocess images for the recognition model.

    Steps: convert PIL inputs to RGB arrays, trim uniform borders, run the
    grayscale/resize/normalize pipeline, then zero-pad every tensor to a
    square of side FIXED_IMG_SIZE.
    """
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    pipeline = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.uint8, scale=True),
            v2.Grayscale(),
            v2.Resize(
                size=FIXED_IMG_SIZE - 1,
                interpolation=v2.InterpolationMode.BICUBIC,
                max_size=FIXED_IMG_SIZE,
                antialias=True,
            ),
            v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
            v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
        ]
    )

    as_arrays = [
        np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
    ]
    trimmed = [trim_white_border(arr) for arr in as_arrays]
    tensors = [pipeline(arr) for arr in trimmed]
    return padding(tensors, FIXED_IMG_SIZE)

128
texteller/utils/latex.py Normal file
View File

@@ -0,0 +1,128 @@
import re
def _change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    """Rewrite the first layer of `old_inst old_surr_l ... old_surr_r` groups.

    Scans `input_str` left to right; whenever `old_inst` immediately followed
    by the opening delimiter `old_surr_l` is found, the balanced group up to
    the matching `old_surr_r` (honoring backslash escapes) is replaced by
    `new_inst new_surr_l <inner> new_surr_r`. If the instruction text changed,
    recurses until no `old_inst old_surr_l` remains (handles nested groups).
    Unbalanced groups emit a warning and are replaced best-effort.
    """
    result = ""
    i = 0
    n = len(input_str)
    while i < n:
        if input_str[i : i + len(old_inst)] == old_inst:
            # check if the old_inst is followed by old_surr_l
            start = i + len(old_inst)
        else:
            result += input_str[i]
            i += 1
            continue
        if start < n and input_str[start] == old_surr_l:
            # found an old_inst followed by old_surr_l, now look for the matching old_surr_r
            count = 1
            j = start + 1
            escaped = False
            while j < n and count > 0:
                # A backslash escapes the next character, so an escaped
                # delimiter does not affect the nesting count.
                if input_str[j] == '\\' and not escaped:
                    escaped = True
                    j += 1
                    continue
                if input_str[j] == old_surr_r and not escaped:
                    count -= 1
                    if count == 0:
                        break
                elif input_str[j] == old_surr_l and not escaped:
                    count += 1
                escaped = False
                j += 1
            if count == 0:
                assert j < n
                assert input_str[start] == old_surr_l
                assert input_str[j] == old_surr_r
                inner_content = input_str[start + 1 : j]
                # Replace the content with new pattern
                result += new_inst + new_surr_l + inner_content + new_surr_r
                i = j + 1
                continue
            else:
                # Ran off the end without closing the group.
                assert count >= 1
                assert j == n
                print("Warning: unbalanced surrogate pair in input string")
                result += new_inst + new_surr_l
                i = start + 1
                continue
        else:
            # old_inst not followed by the opening delimiter: copy it verbatim.
            result += input_str[i:start]
            i = start
    # Recurse to rewrite groups that were nested inside replaced content.
    if old_inst != new_inst and (old_inst + old_surr_l) in result:
        return _change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
    else:
        return result
def _find_substring_positions(string, substring):
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
return positions
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    """Rewrite every `old_inst old_surr_l ... old_surr_r` group in `input_str`.

    Occurrences are processed right-to-left so that earlier replacements do
    not invalidate the offsets of later ones; the actual rewriting of each
    suffix is delegated to `_change`.
    """
    result = input_str
    for start in reversed(_find_substring_positions(input_str, old_inst + old_surr_l)):
        result = result[:start] + _change(
            result[start:], old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
        )
    return result
def remove_style(input_str: str) -> str:
    """Strip LaTeX style commands, keeping their contents.

    Each `\\bm{...}`, `\\boldsymbol{...}`, `\\textit{...}`, `\\textbf{...}`
    and `\\mathbf{...}` wrapper is replaced by its inner content surrounded
    by spaces; the result is stripped of leading/trailing whitespace.
    """
    # Fixed: the original repeated the \textbf call twice; a loop also keeps
    # the identical change_all arguments in one place.
    for style_cmd in (r"\bm", r"\boldsymbol", r"\textit", r"\textbf", r"\mathbf"):
        input_str = change_all(input_str, style_cmd, r" ", r"{", r"}", r"", r" ")
    return input_str.strip()
def add_newlines(latex_str: str) -> str:
    """
    Insert newlines into a LaTeX string around structural tokens, ensuring
    no duplicate newlines appear around begin/end environments.

    Newlines are placed:
      - around every \\begin{...} and \\end{...}
      - after every \\ line break (when not already followed by one)

    Args:
        latex_str: The input LaTeX string.

    Returns:
        The reformatted LaTeX string, stripped of leading/trailing whitespace.
    """
    begin_env = re.compile(r"\s*(\\begin\{[^}]*\})\s*")
    end_env = re.compile(r"\s*(\\end\{[^}]*\})\s*")
    # Matches \\ not already followed by a newline; a single trailing space
    # after \\ is consumed by the second alternative.
    line_break = re.compile(r"\\\\(?!\n| )|\\\\ ")

    result = begin_env.sub(r"\n\g<1>\n", latex_str)
    result = end_env.sub(r"\n\g<1>\n", result)
    result = line_break.sub(r"\\\\\n", result)
    # Collapse accidental double newlines created by adjacent replacements.
    result = re.sub(r"\n{2,}", "\n", result)
    return result.strip()

5
texteller/utils/misc.py Normal file
View File

@@ -0,0 +1,5 @@
from textwrap import dedent
def lines_dedent(s: str) -> str:
    """Dedent a (typically triple-quoted) multiline string and strip the
    surrounding blank lines and whitespace."""
    normalized = dedent(s)
    return normalized.strip()

52
texteller/utils/path.py Normal file
View File

@@ -0,0 +1,52 @@
from pathlib import Path
from typing import Literal
from texteller.logger import get_logger
_logger = get_logger(__name__)
def resolve_path(path: str | Path) -> str:
    """Expand `~` and resolve `path` to an absolute path, returned as a string."""
    # Path(...) accepts both str and Path, so no explicit isinstance branch needed.
    return str(Path(path).expanduser().resolve())
def touch(path: str | Path) -> None:
    """Create an empty file at `path`; no error if it already exists."""
    Path(path).touch(exist_ok=True)
def mkdir(path: str | Path) -> None:
    """Create directory `path` including any missing parents; idempotent."""
    Path(path).mkdir(parents=True, exist_ok=True)
def rmfile(path: str | Path) -> None:
    """Delete the file at `path`; raises FileNotFoundError when it is missing."""
    Path(path).unlink(missing_ok=False)
def rmdir(path: str | Path, mode: Literal["empty", "recursive"] = "empty") -> None:
    """Remove a directory.

    Args:
        path: Path to directory to remove.
        mode: "empty" to remove only an already-empty directory (an OSError
            propagates otherwise), "recursive" to remove the directory and
            all of its contents.

    Raises:
        ValueError: If `mode` is neither "empty" nor "recursive".
    """
    if isinstance(path, str):
        path = Path(path)
    if mode == "empty":
        path.rmdir()
        _logger.info(f"Removed empty directory: {path}")
    elif mode == "recursive":
        import shutil

        shutil.rmtree(path)
        _logger.info(f"Recursively removed directory and all contents: {path}")
    else:
        # Fixed: the message (and docstring) previously named 'all', which
        # was never an accepted mode.
        raise ValueError(f"Invalid mode: {mode}. Must be 'empty' or 'recursive'")