[chore] exclude paddleocr directory from pre-commit hooks

This commit is contained in:
三洋三洋
2025-02-28 19:56:49 +08:00
parent a8a005ae10
commit 3d546f9993
130 changed files with 592 additions and 739 deletions

View File

@@ -0,0 +1,60 @@
import torch
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
def left_move(x: torch.Tensor, pad_val):
assert len(x.shape) == 2, 'x should be 2-dimensional'
lefted_x = torch.ones_like(x)
lefted_x[:, :-1] = x[:, 1:]
lefted_x[:, -1] = pad_val
return lefted_x
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
tokenized_formula['pixel_values'] = samples['image']
return tokenized_formula
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
pixel_values = [dic.pop('pixel_values') for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
batch['pixel_values'] = pixel_values
batch['decoder_input_ids'] = batch.pop('input_ids')
batch['decoder_attention_mask'] = batch.pop('attention_mask')
# 左移labels和decoder_attention_mask
batch['labels'] = left_move(batch['labels'], -100)
# 把list of Image转成一个tensor with (B, C, H, W)
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
return batch
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = inference_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample['image'].height > MIN_HEIGHT
and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
)

View File

@@ -0,0 +1,26 @@
import cv2
import numpy as np
from typing import List
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
processed_images = []
for path in image_paths:
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
if image.dtype == np.uint16:
print(f'Converting {path} to 8-bit, image may be lossy.')
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
channels = 1 if len(image.shape) == 2 else image.shape[2]
if channels == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(image)
return processed_images

View File

@@ -0,0 +1,49 @@
import torch
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Union
from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
accelerator: str = 'cpu',
num_beams: int = 1,
max_tokens=None,
) -> List[str]:
if imgs == []:
return []
if hasattr(model, 'eval'):
# not onnx session, turn model.eval()
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already numpy array(rgb format)
assert isinstance(imgs[0], np.ndarray)
imgs = imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if hasattr(model, 'eval'):
# not onnx session, move weights to device
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
pred = model.generate(pixel_values.to(model.device), generation_config=generate_config)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
return res

View File

@@ -0,0 +1,25 @@
import evaluate
import numpy as np
import os
from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
cur_dir = Path(os.getcwd())
os.chdir(Path(__file__).resolve().parent)
metric = evaluate.load(
'google_bleu'
) # Will download the metric from huggingface if not already downloaded
os.chdir(cur_dir)
logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits
labels = np.where(labels == -100, 1, labels)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
return metric.compute(predictions=preds, references=labels)

View File

@@ -0,0 +1,152 @@
from augraphy import *
import random
def ocr_augmentation_pipeline():
pre_phase = []
ink_phase = [
InkColorSwap(
ink_swap_color="random",
ink_swap_sequence_number_range=(5, 10),
ink_swap_min_width_range=(2, 3),
ink_swap_max_width_range=(100, 120),
ink_swap_min_height_range=(2, 3),
ink_swap_max_height_range=(100, 120),
ink_swap_min_area_range=(10, 20),
ink_swap_max_area_range=(400, 500),
# p=0.2
p=0.4,
),
LinesDegradation(
line_roi=(0.0, 0.0, 1.0, 1.0),
line_gradient_range=(32, 255),
line_gradient_direction=(0, 2),
line_split_probability=(0.2, 0.4),
line_replacement_value=(250, 255),
line_min_length=(30, 40),
line_long_to_short_ratio=(5, 7),
line_replacement_probability=(0.4, 0.5),
line_replacement_thickness=(1, 3),
# p=0.2
p=0.4,
),
# ============================
OneOf(
[
Dithering(
dither="floyd-steinberg",
order=(3, 5),
),
InkBleed(
intensity_range=(0.1, 0.2),
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
severity=(0.4, 0.6),
),
],
# p=0.2
p=0.4,
),
# ============================
# ============================
InkShifter(
text_shift_scale_range=(18, 27),
text_shift_factor_range=(1, 4),
text_fade_range=(0, 2),
blur_kernel_size=(5, 5),
blur_sigma=0,
noise_type="perlin",
# p=0.2
p=0.4,
),
# ============================
]
paper_phase = [
NoiseTexturize( # tested
sigma_range=(3, 10),
turbulence_range=(2, 5),
texture_width_range=(300, 500),
texture_height_range=(300, 500),
# p=0.2
p=0.4,
),
BrightnessTexturize( # tested
texturize_range=(0.9, 0.99),
deviation=0.03,
# p=0.2
p=0.4,
),
]
post_phase = [
ColorShift( # tested
color_shift_offset_x_range=(3, 5),
color_shift_offset_y_range=(3, 5),
color_shift_iterations=(2, 3),
color_shift_brightness_range=(0.9, 1.1),
color_shift_gaussian_kernel_range=(3, 3),
# p=0.2
p=0.4,
),
DirtyDrum( # tested
line_width_range=(1, 6),
line_concentration=random.uniform(0.05, 0.15),
direction=random.randint(0, 2),
noise_intensity=random.uniform(0.6, 0.95),
noise_value=(64, 224),
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
sigmaX=0,
# p=0.2
p=0.4,
),
# =====================================
OneOf(
[
LightingGradient(
light_position=None,
direction=None,
max_brightness=255,
min_brightness=0,
mode="gaussian",
linear_decay_rate=None,
transparency=None,
),
Brightness(
brightness_range=(0.9, 1.1),
min_brightness=0,
min_brightness_value=(120, 150),
),
Gamma(
gamma_range=(0.9, 1.1),
),
],
# p=0.2
p=0.4,
),
# =====================================
# =====================================
OneOf(
[
SubtleNoise(
subtle_range=random.randint(5, 10),
),
Jpeg(
quality_range=(70, 95),
),
],
# p=0.2
p=0.4,
),
# =====================================
]
pipeline = AugraphyPipeline(
ink_phase=ink_phase,
paper_phase=paper_phase,
post_phase=post_phase,
pre_phase=pre_phase,
log=False,
)
return pipeline

View File

@@ -0,0 +1,184 @@
import re
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i : i + len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1 : j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count >= 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
result += input_str[i:start]
i = start
if old_inst != new_inst and (old_inst + old_surr_l) in result:
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def find_substring_positions(string, substring):
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
return positions
def rm_dollar_surr(content):
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
matches = pattern.findall(content)
for match in matches:
if not re.match(r'\\[a-zA-Z]+', match):
new_match = match.strip('$')
content = content.replace(match, ' ' + new_match + ' ')
return content
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
pos = find_substring_positions(input_str, old_inst + old_surr_l)
res = list(input_str)
for p in pos[::-1]:
res[p:] = list(
change(
''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
)
)
res = ''.join(res)
return res
def to_katex(formula: str) -> str:
res = formula
# remove mbox surrounding
res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
# remove hbox surrounding
res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
# remove raise surrounding
res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
# remove makebox
res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
# remove vbox surrounding, scalebox surrounding
res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')
origin_instructions = [
r'\Huge',
r'\huge',
r'\LARGE',
r'\Large',
r'\large',
r'\normalsize',
r'\small',
r'\footnotesize',
r'\tiny',
]
for old_ins, new_ins in zip(origin_instructions, origin_instructions):
res = change_all(res, old_ins, new_ins, r'$', r'$', '{', '}')
res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')
origin_instructions = [
r'\left',
r'\middle',
r'\right',
r'\big',
r'\Big',
r'\bigg',
r'\Bigg',
r'\bigl',
r'\Bigl',
r'\biggl',
r'\Biggl',
r'\bigm',
r'\Bigm',
r'\biggm',
r'\Biggm',
r'\bigr',
r'\Bigr',
r'\biggr',
r'\Biggr',
]
for origin_ins in origin_instructions:
res = change_all(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
if res.endswith(r'\newline'):
res = res[:-8]
# remove multiple spaces
res = re.sub(r'(\\,){1,}', ' ', res)
res = re.sub(r'(\\!){1,}', ' ', res)
res = re.sub(r'(\\;){1,}', ' ', res)
res = re.sub(r'(\\:){1,}', ' ', res)
res = re.sub(r'\\vspace\{.*?}', '', res)
# merge consecutive text
def merge_texts(match):
texts = match.group(0)
merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
return f'\\text{{{merged_content}}}'
res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)
res = res.replace(r'\bf ', '')
res = rm_dollar_surr(res)
# remove extra spaces (keeping only one)
res = re.sub(r' +', ' ', res)
return res.strip()

View File

@@ -0,0 +1,177 @@
import torch
import random
import numpy as np
import cv2
from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from collections import Counter
from ...globals import (
IMG_CHANNELS,
FIXED_IMG_SIZE,
IMAGE_MEAN,
IMAGE_STD,
MAX_RESIZE_RATIO,
MIN_RESIZE_RATIO,
)
from .ocr_aug import ocr_augmentation_pipeline
# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()
general_transform_pipeline = v2.Compose(
[
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True,
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
# v2.ToPILImage()
]
)
def trim_white_border(image: np.ndarray):
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
h, w = image.shape[:2]
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
threshold = 15
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
x, y, w, h = cv2.boundingRect(diff)
trimmed_image = image[y : y + h, x : x + w]
return trimmed_image
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
pad_width_size = randi[0] + randi[2]
if pad_height_size + image.shape[0] < 30:
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
randi[1] += compensate_height
randi[3] += compensate_height
if pad_width_size + image.shape[1] < 30:
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
randi[0] += compensate_width
randi[2] += compensate_width
return v2.functional.pad(
torch.from_numpy(image).permute(2, 0, 1),
padding=randi,
padding_mode='constant',
fill=(255, 255, 255),
)
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
v2.functional.pad(
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
)
for img in images
]
return images
def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
return [
cv2.resize(
img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
) # 抗锯齿
for img, r in zip(images, ratios)
]
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
# Get the center of the image to define the point of rotation
image_center = tuple(np.array(image.shape[1::-1]) / 2)
# Generate a random angle within the specified range
angle = random.randint(min_angle, max_angle)
# Get the rotation matrix for rotating the image around its center
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
# Determine the size of the rotated image
cos = np.abs(rotation_mat[0, 0])
sin = np.abs(rotation_mat[0, 1])
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
# Adjust the rotation matrix to take into account translation
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
# Rotate the image with the specified border color (white in this case)
rotated_image = cv2.warpAffine(
image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
)
return rotated_image
def ocr_aug(image: np.ndarray) -> np.ndarray:
if random.random() < 0.2:
image = rotate(image, -5, 5)
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
image = train_pipeline(image)
return image
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) for img in images]
# random resize first
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
images = [trim_white_border(image) for image in images]
# OCR augmentation
images = [ocr_aug(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
images = [
np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
]
images = [trim_white_border(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images