[refactor] Init

This commit is contained in:
OleehyO
2025-04-16 14:23:02 +00:00
parent 0e32f3f3bf
commit 06edd104e2
101 changed files with 1854 additions and 2758 deletions

View File

@@ -0,0 +1,3 @@
from .texteller import TexTeller
__all__ = ['TexTeller']

View File

@@ -1,89 +0,0 @@
import os
from PIL import Image, ImageDraw
from typing import List
from pathlib import Path
class Point:
def __init__(self, x: int, y: int):
self.x = int(x)
self.y = int(y)
def __repr__(self) -> str:
return f"Point(x={self.x}, y={self.y})"
class Bbox:
THRESHOLD = 0.4
def __init__(self, x, y, h, w, label: str = None, confidence: float = 0, content: str = None):
self.p = Point(x, y)
self.h = int(h)
self.w = int(w)
self.label = label
self.confidence = confidence
self.content = content
@property
def ul_point(self) -> Point:
return self.p
@property
def ur_point(self) -> Point:
return Point(self.p.x + self.w, self.p.y)
@property
def ll_point(self) -> Point:
return Point(self.p.x, self.p.y + self.h)
@property
def lr_point(self) -> Point:
return Point(self.p.x + self.w, self.p.y + self.h)
def same_row(self, other) -> bool:
if (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y) or (
self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y
):
return True
if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
return False
return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THRESHOLD
def __lt__(self, other) -> bool:
'''
from top to bottom, from left to right
'''
if not self.same_row(other):
return self.p.y < other.p.y
else:
return self.p.x < other.p.x
def __repr__(self) -> str:
return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
curr_work_dir = Path(os.getcwd())
log_dir = curr_work_dir / "logs"
log_dir.mkdir(exist_ok=True)
drawer = ImageDraw.Draw(img)
for bbox in bboxes:
# Calculate the coordinates for the rectangle to be drawn
left = bbox.p.x
top = bbox.p.y
right = bbox.p.x + bbox.w
bottom = bbox.p.y + bbox.h
# Draw the rectangle on the image
drawer.rectangle([left, top, right, bottom], outline="green", width=1)
# Optionally, add text label if it exists
if bbox.label:
drawer.text((left, top), bbox.label, fill="blue")
if bbox.content:
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
# Save the image with drawn rectangles
img.save(log_dir / name)
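A minimal sketch of the reading-order sort that same_row and __lt__ implement (boxes on roughly the same line are ordered left to right, otherwise top to bottom); the coordinates below are made up for illustration:

boxes = [
    Bbox(120, 12, 20, 40, label="embedding", confidence=0.9),
    Bbox(10, 10, 20, 40, label="embedding", confidence=0.8),
    Bbox(10, 60, 20, 40, label="isolated", confidence=0.7),
]
# The first two boxes fall within the same-row threshold, so sorting yields
# (10, 10) -> (120, 12) -> (10, 60): top to bottom, then left to right.
for b in sorted(boxes):
    print(b.p)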

View File

@@ -1,226 +0,0 @@
import os
import time
import yaml
import numpy as np
import cv2
from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox
# Global dictionary
SUPPORT_MODELS = {
'YOLO',
'PPYOLOE',
'RCNN',
'SSD',
'Face',
'FCOS',
'SOLOv2',
'TTFNet',
'S2ANet',
'JDE',
'FairMOT',
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
'TOOD',
'RetinaNet',
'StrongBaseline',
'STGCN',
'YOLOX',
'HRNet',
'DETR',
}
class PredictConfig(object):
"""set config of preprocess, postprocess and visualize
Args:
infer_config (str): path of infer_cfg.yml
"""
def __init__(self, infer_config):
# parsing Yaml config for Preprocess
with open(infer_config) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.label_list = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
self.mask = yml_conf.get("mask", False)
self.tracker = yml_conf.get("tracker", None)
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
self.colors = {
label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print('The RCNN model exported for ONNX only supports batch_size = 1')
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def draw_bbox(image, outputs, infer_config):
for output in outputs:
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(
image,
"{}: {:.2f}".format(label, score),
(int(xmin), int(ymin - 5)),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
color,
2,
)
return image
def predict_image(imgsave_dir, infer_config, predictor, img_list):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
errImgList = []
# Check and create subimg_save_dir if not exist
subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
os.makedirs(subimg_save_dir, exist_ok=True)
first_image_skipped = False
total_time = 0
num_images = 0
# predict image
for img_path in tqdm(img_list):
img = cv2.imread(img_path)
if img is None:
print(f"Warning: Could not read image {img_path}. Skipping...")
errImgList.append(img_path)
continue
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
# Start timing
start_time = time.time()
outputs = predictor.run(output_names=None, input_feed=inputs)
# Stop timing
end_time = time.time()
inference_time = end_time - start_time
if not first_image_skipped:
first_image_skipped = True
else:
total_time += inference_time
num_images += 1
print(
f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
)
print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]:
print(np.array(outputs[0]))
else:
bboxes = np.array(outputs[0])
for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
# Save the subimages (crop from the original image)
subimg_counter = 1
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
if len(subimg) == 0:
continue
subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
subimg_path = os.path.join(subimg_save_dir, subimg_filename)
cv2.imwrite(subimg_path, subimg)
subimg_counter += 1
# Draw bounding boxes and save the image with bounding boxes
img_with_mask = img.copy()
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
cv2.rectangle(
img_with_mask,
(int(xmin), int(ymin)),
(int(xmax), int(ymax)),
(255, 255, 255),
-1,
) # mask the detected region with white
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = imgsave_dir
os.makedirs(output_dir, exist_ok=True)
draw_box_dir = os.path.join(output_dir, 'draw_box')
mask_white_dir = os.path.join(output_dir, 'mask_white')
os.makedirs(draw_box_dir, exist_ok=True)
os.makedirs(mask_white_dir, exist_ok=True)
output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
cv2.imwrite(output_file_mask, img_with_mask)
cv2.imwrite(output_file_bbox, img_with_bbox)
avg_time_per_image = total_time / num_images if num_images > 0 else 0
print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
print(f"Average time per image: {avg_time_per_image:.4f} seconds")
print("ErrorImgs:")
print(errImgList)
def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
transforms = Compose(infer_config.preprocess_infos)
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = []
for output in outputs:
cls_name = infer_config.label_list[int(output[0])]
score = output[1]
xmin = int(max(output[2], 0))
ymin = int(max(output[3], 0))
xmax = int(output[4])
ymax = int(output[5])
if score > infer_config.draw_threshold:
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
return res
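A minimal usage sketch of predict() with an ONNX Runtime session; the model and config paths are placeholders for illustration, not files from this commit:

import onnxruntime as ort

infer_config = PredictConfig("models/det/infer_cfg.yml")
predictor = ort.InferenceSession("models/det/model.onnx", providers=["CPUExecutionProvider"])
bboxes = predict("page.png", predictor, infer_config)
for bbox in bboxes:
    print(bbox.label, bbox.confidence, bbox.p, bbox.w, bbox.h)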

View File

@@ -1,27 +0,0 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: DETR
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 1600
- 1600
type: Resize
- mean:
- 0.0
- 0.0
- 0.0
norm_type: none
std:
- 1.0
- 1.0
- 1.0
type: NormalizeImage
- type: Permute
label_list:
- isolated
- embedding

View File

@@ -1,485 +0,0 @@
import numpy as np
import cv2
import copy
def decode_image(img_path):
if isinstance(img_path, str):
with open(img_path, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
else:
assert isinstance(img_path, np.ndarray)
data = img_path
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
"im_shape": np.array(im.shape[:2], dtype=np.float32),
"scale_factor": np.array([1.0, 1.0], dtype=np.float32),
}
return im, img_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(
self,
):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
"""padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
class LetterBoxResize(object):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
) # padded rectangular
return img, ratio, padw, padh
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
height, width = self.target_size
h, w = im.shape[:2]
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
new_shape = [round(h * ratio), round(w * ratio)]
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
return im, im_info
class Pad(object):
def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.fill_value = fill_value
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
h, w = self.size
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
im = canvas
return im, im_info
def rotate_point(pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
input_size (np.ndarray[2, ]): Size of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
dst_dir = np.array([0.0, dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
class WarpAffine(object):
"""Warp affine the image"""
def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
return inp, im_info
# keypoint preprocess
def get_warp_matrix(theta, size_input, size_dst, size_target):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta)
+ 0.5 * size_input[1] * np.sin(theta)
+ 0.5 * size_target[0]
)
matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta)
- 0.5 * size_input[1] * np.cos(theta)
+ 0.5 * size_target[1]
)
return matrix
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
use_udp (bool): whether to use Unbiased Data Processing.
records(dict): the dict contained the image and coords
Returns:
records (dict): contain the image and coords after tranformed
"""
def __init__(self, trainsize, use_udp=False):
self.trainsize = trainsize
self.use_udp = use_udp
def __call__(self, image, im_info):
rot = 0
imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2.0
scale = im_info['scale'] if 'scale' in im_info else imshape
if self.use_udp:
trans = get_warp_matrix(
rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
)
image = cv2.warpAffine(
image,
trans,
(int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR,
)
else:
trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine(
image,
trans,
(int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR,
)
return image, im_info
class Compose:
def __init__(self, transforms):
self.transforms = []
for op_info in transforms:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
self.transforms.append(eval(op_type)(**new_op_info))
def __call__(self, img_path):
img, im_info = decode_image(img_path)
for t in self.transforms:
img, im_info = t(img, im_info)
inputs = copy.deepcopy(im_info)
inputs['image'] = img
return inputs
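A short sketch of how the Preprocess list from infer_cfg.yml (shown earlier in this diff) drives Compose; the image path is an assumption:

preprocess_infos = [
    {"type": "Resize", "target_size": [1600, 1600], "keep_ratio": False, "interp": 2},
    {"type": "NormalizeImage", "mean": [0.0, 0.0, 0.0], "std": [1.0, 1.0, 1.0], "norm_type": "none"},
    {"type": "Permute"},
]
transforms = Compose(preprocess_infos)
inputs = transforms("page.png")   # dict with 'image', 'im_shape', 'scale_factor'
print(inputs["image"].shape)      # (3, 1600, 1600): CHW after Permute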

View File

@@ -1,23 +0,0 @@
# Formula image (grayscale) mean and standard deviation
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445
# Vocabulary size for TexTeller
VOCAB_SIZE = 15000
# Fixed size for input image for TexTeller
FIXED_IMG_SIZE = 448
# Image channel for TexTeller
IMG_CHANNELS = 1 # grayscale image
# Max size of token for embedding
MAX_TOKEN_SIZE = 1024
# Scaling ratio for random resizing when training
MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75
# Minimum height and width for input image for TexTeller
MIN_HEIGHT = 12
MIN_WIDTH = 30
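These grayscale statistics are presumably consumed by the image transforms (utils/transforms.py, not shown in this section); a rough, assumption-level sketch of how a formula crop would be normalized with them:

import numpy as np

def normalize_formula_image(gray: np.ndarray) -> np.ndarray:
    """gray: (H, W) uint8 grayscale crop of a formula image."""
    x = gray.astype(np.float32) / 255.0
    x = (x - IMAGE_MEAN) / IMAGE_STD
    return x[None, ...]   # (IMG_CHANNELS, H, W) with IMG_CHANNELS == 1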

View File

@@ -1,43 +0,0 @@
from pathlib import Path
from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
class TexTeller(VisionEncoderDecoderModel):
REPO_NAME = 'OleehyO/TexTeller'
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained(
Path(__file__).resolve().parent / "config.json"
)
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config)
@classmethod
def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
if model_path is None or model_path == 'default':
if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
else:
from optimum.onnxruntime import ORTModelForVision2Seq
use_gpu = onnx_provider == 'cuda'
return ORTModelForVision2Seq.from_pretrained(
cls.REPO_NAME,
provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
)
model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod
def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
if tokenizer_path is None or tokenizer_path == 'default':
return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
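A minimal loading sketch (the local paths in the comments are placeholders, not files from this repo):

model = TexTeller.from_pretrained()        # pulls OleehyO/TexTeller from the Hugging Face Hub
tokenizer = TexTeller.get_tokenizer()
# With an exported ONNX checkpoint on GPU:
# model = TexTeller.from_pretrained(use_onnx=True, onnx_provider='cuda')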

View File

@@ -1,168 +0,0 @@
{
"_name_or_path": "OleehyO/TexTeller",
"architectures": [
"VisionEncoderDecoderModel"
],
"decoder": {
"_name_or_path": "",
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": 768,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.1,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"init_std": 0.02,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layernorm_embedding": true,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 1024,
"min_length": 0,
"model_type": "trocr",
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"scale_embedding": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": false,
"use_learned_position_embeddings": true,
"vocab_size": 15000
},
"encoder": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": null,
"attention_probs_dropout_prob": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"encoder_stride": 16,
"eos_token_id": null,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 448,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-12,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "vit",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_channels": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"patch_size": 16,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"qkv_bias": false,
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"typical_p": 1.0,
"use_bfloat16": false
},
"is_encoder_decoder": true,
"model_type": "vision-encoder-decoder",
"tie_word_embeddings": false,
"transformers_version": "4.41.2",
"use_cache": true
}

35 binary image files deleted (previews not shown).

View File

@@ -1,35 +0,0 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

View File

@@ -1,114 +0,0 @@
import os
from functools import partial
from pathlib import Path
from datasets import load_dataset
from transformers import (
Trainer,
TrainingArguments,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
GenerationConfig,
)
from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import (
tokenize_fn,
collate_fn,
img_train_transform,
img_inf_transform,
filter_fn,
)
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG)
trainer = Trainer(
model,
training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn_with_tokenizer,
)
trainer.train(resume_from_checkpoint=None)
def evaluate(model, tokenizer, eval_dataset, collate_fn):
eval_config = CONFIG.copy()
eval_config['predict_with_generate'] = True
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
eval_config['generation_config'] = generate_config
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
trainer = Seq2SeqTrainer(
model,
seq2seq_config,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn,
compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
)
eval_res = trainer.evaluate()
print(eval_res)
if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
# dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
dataset = dataset.filter(
lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
)
dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices()
tokenizer = TexTeller.get_tokenizer()
# If you want to use your own tokenizer, modify the path below:
# tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(
map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
)
# Split dataset into train and eval, ratio 9:1
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# Train from scratch
model = TexTeller()
# or train from the TexTeller pre-trained model: model = TexTeller.from_pretrained()
# To train from your own pre-trained checkpoint, pass its path, e.g.:
# model = TexTeller.from_pretrained(
#     '/path/to/your/model_checkpoint'
# )
enable_train = True
enable_evaluate = False
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate and len(eval_dataset) > 0:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)

View File

@@ -1,31 +0,0 @@
CONFIG = {
"seed": 42, # Random seed for reproducibility
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
"learning_rate": 5e-5, # Learning rate
"num_train_epochs": 10, # Total number of training epochs
"per_device_train_batch_size": 4, # Batch size per GPU for training
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
"output_dir": "train_result", # Output directory
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
"report_to": ["tensorboard"], # Report logs to TensorBoard
"save_strategy": "steps", # Strategy to save checkpoints
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
"logging_strategy": "steps", # Log every certain number of steps
"logging_steps": 500, # Number of steps between each log
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
"optim": "adamw_torch", # Optimizer
"lr_scheduler_type": "cosine", # Learning rate scheduler
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
"eval_steps": 500, # If evaluation_strategy="step"
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
}

View File

@@ -1,60 +0,0 @@
import torch
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
def left_move(x: torch.Tensor, pad_val):
assert len(x.shape) == 2, 'x should be 2-dimensional'
lefted_x = torch.ones_like(x)
lefted_x[:, :-1] = x[:, 1:]
lefted_x[:, -1] = pad_val
return lefted_x
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
tokenized_formula['pixel_values'] = samples['image']
return tokenized_formula
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
pixel_values = [dic.pop('pixel_values') for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
batch['pixel_values'] = pixel_values
batch['decoder_input_ids'] = batch.pop('input_ids')
batch['decoder_attention_mask'] = batch.pop('attention_mask')
# Left-shift the labels so they serve as next-token prediction targets
batch['labels'] = left_move(batch['labels'], -100)
# Convert the list of images into a single tensor of shape (B, C, H, W)
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
return batch
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = inference_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample['image'].height > MIN_HEIGHT
and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
)
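A small check of the label shift performed in collate_fn: left_move drops the first token and pads the tail, so the labels line up with next-token prediction:

import torch
x = torch.tensor([[1, 5, 7, 2]])
print(left_move(x, -100))   # -> [[5, 7, 2, -100]]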

View File

@@ -1,26 +0,0 @@
import cv2
import numpy as np
from typing import List
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
processed_images = []
for path in image_paths:
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
if image.dtype == np.uint16:
print(f'Converting {path} to 8-bit, image may be lossy.')
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
channels = 1 if len(image.shape) == 2 else image.shape[2]
if channels == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(image)
return processed_images

View File

@@ -1,116 +0,0 @@
import torch
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria
from typing import List, Union
from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE
class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
"""
Stops generation efficiently if any n-gram repeats.
This criterion maintains a set of encountered n-grams.
At each step, it checks if the *latest* n-gram is already in the set.
If yes, it stops generation. If no, it adds the n-gram to the set.
"""
def __init__(self, n: int):
"""
Args:
n (int): The size of the n-gram to check for repetition.
"""
if n <= 0:
raise ValueError("n-gram size 'n' must be positive.")
self.n = n
# Stores tuples of token IDs representing seen n-grams
self.seen_ngrams = set()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores.
Return:
`bool`: `True` if generation should stop, `False` otherwise.
"""
batch_size, seq_length = input_ids.shape
# Need at least n tokens to form the first n-gram
if seq_length < self.n:
return False
# --- Efficient Check ---
# Consider only the first sequence in the batch for simplicity
if batch_size > 1:
# If handling batch_size > 1, you'd need a list of sets, one per batch item.
# Or decide on a stopping policy (e.g., stop if *any* sequence repeats).
# For now, we'll focus on the first sequence.
pass # No warning needed every step, maybe once in __init__ if needed.
sequence = input_ids[0] # Get the first sequence
# Get the latest n-gram (the one ending at the last token)
last_ngram_tensor = sequence[-self.n :]
# Convert to a hashable tuple for set storage and lookup
last_ngram_tuple = tuple(last_ngram_tensor.tolist())
# Check if this n-gram has been seen before *at any prior step*
if last_ngram_tuple in self.seen_ngrams:
return True # Stop generation
else:
# It's a new n-gram, add it to the set and continue
self.seen_ngrams.add(last_ngram_tuple)
return False # Continue generation
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
accelerator: str = 'cpu',
num_beams: int = 1,
max_tokens=None,
) -> List[str]:
if imgs == []:
return []
if hasattr(model, 'eval'):
# not an ONNX session, so switch the model to eval mode
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already a numpy array (RGB format)
assert isinstance(imgs[0], np.ndarray)
imgs = imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if hasattr(model, 'eval'):
# not an ONNX session, so move the weights to the target device
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
# no_repeat_ngram_size=10,
)
pred = model.generate(
pixel_values.to(model.device),
generation_config=generate_config,
# stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)],
)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
return res
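A minimal end-to-end sketch tying this together (the image path is an assumption):

model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
latex = inference(model, tokenizer, ["formula.png"], accelerator="cpu", num_beams=1)
print(latex[0])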

View File

@@ -1,698 +0,0 @@
#!/usr/bin/env python3
"""
Python implementation of tex-fmt, a LaTeX formatter.
Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt
"""
import re
import argparse
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict, Set
# Constants
LINE_END = "\n"
ITEM = "\\item"
DOC_BEGIN = "\\begin{document}"
DOC_END = "\\end{document}"
ENV_BEGIN = "\\begin{"
ENV_END = "\\end{"
TEXT_LINE_START = ""
COMMENT_LINE_START = "% "
# Opening and closing delimiters
OPENS = ['{', '(', '[']
CLOSES = ['}', ')', ']']
# Names of LaTeX verbatim environments
VERBATIMS = ["verbatim", "Verbatim", "lstlisting", "minted", "comment"]
VERBATIMS_BEGIN = [f"\\begin{{{v}}}" for v in VERBATIMS]
VERBATIMS_END = [f"\\end{{{v}}}" for v in VERBATIMS]
# Regex patterns for sectioning commands
SPLITTING = [
r"\\begin\{",
r"\\end\{",
r"\\item(?:$|[^a-zA-Z])",
r"\\(?:sub){0,2}section\*?\{",
r"\\chapter\*?\{",
r"\\part\*?\{",
]
# Compiled regexes
SPLITTING_STRING = f"({'|'.join(SPLITTING)})"
RE_NEWLINES = re.compile(f"{LINE_END}{LINE_END}({LINE_END})+")
RE_TRAIL = re.compile(f" +{LINE_END}")
RE_SPLITTING = re.compile(SPLITTING_STRING)
RE_SPLITTING_SHARED_LINE = re.compile(f"(?:\\S.*?)(?:{SPLITTING_STRING}.*)")
RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P<prev>\\S.*?)(?P<env>{SPLITTING_STRING}.*)")
@dataclass
class Args:
"""Command line arguments and configuration."""
tabchar: str = " "
tabsize: int = 4
wrap: bool = False
wraplen: int = 80
wrapmin: int = 40
lists: List[str] = None
verbosity: int = 0
def __post_init__(self):
if self.lists is None:
self.lists = []
@dataclass
class Ignore:
"""Information on the ignored state of a line."""
actual: bool = False
visual: bool = False
@classmethod
def new(cls):
return cls(False, False)
@dataclass
class Verbatim:
"""Information on the verbatim state of a line."""
actual: int = 0
visual: bool = False
@classmethod
def new(cls):
return cls(0, False)
@dataclass
class Indent:
"""Information on the indentation state of a line."""
actual: int = 0
visual: int = 0
@classmethod
def new(cls):
return cls(0, 0)
@dataclass
class State:
"""Information on the current state during formatting."""
linum_old: int = 1
linum_new: int = 1
ignore: Ignore = None
indent: Indent = None
verbatim: Verbatim = None
linum_last_zero_indent: int = 1
def __post_init__(self):
if self.ignore is None:
self.ignore = Ignore.new()
if self.indent is None:
self.indent = Indent.new()
if self.verbatim is None:
self.verbatim = Verbatim.new()
@dataclass
class Pattern:
"""Record whether a line contains certain patterns."""
contains_env_begin: bool = False
contains_env_end: bool = False
contains_item: bool = False
contains_splitting: bool = False
contains_comment: bool = False
@classmethod
def new(cls, s: str):
"""Check if a string contains patterns."""
if RE_SPLITTING.search(s):
return cls(
contains_env_begin=ENV_BEGIN in s,
contains_env_end=ENV_END in s,
contains_item=ITEM in s,
contains_splitting=True,
contains_comment='%' in s,
)
else:
return cls(
contains_env_begin=False,
contains_env_end=False,
contains_item=False,
contains_splitting=False,
contains_comment='%' in s,
)
@dataclass
class Log:
"""Log message."""
level: str
file: str
message: str
linum_new: Optional[int] = None
linum_old: Optional[int] = None
line: Optional[str] = None
def find_comment_index(line: str, pattern: Pattern) -> Optional[int]:
"""Find the index of a comment in a line."""
if not pattern.contains_comment:
return None
in_command = False
for i, c in enumerate(line):
if c == '\\':
in_command = True
elif in_command and not c.isalpha():
in_command = False
elif c == '%' and not in_command:
return i
return None
def contains_ignore_skip(line: str) -> bool:
"""Check if a line contains a skip directive."""
return line.endswith("% tex-fmt: skip")
def contains_ignore_begin(line: str) -> bool:
"""Check if a line contains the start of an ignore block."""
return line.endswith("% tex-fmt: off")
def contains_ignore_end(line: str) -> bool:
"""Check if a line contains the end of an ignore block."""
return line.endswith("% tex-fmt: on")
def get_ignore(line: str, state: State, logs: List[Log], file: str, warn: bool) -> Ignore:
"""Determine whether a line should be ignored."""
skip = contains_ignore_skip(line)
begin = contains_ignore_begin(line)
end = contains_ignore_end(line)
if skip:
actual = state.ignore.actual
visual = True
elif begin:
actual = True
visual = True
if warn and state.ignore.actual:
logs.append(
Log(
level="WARN",
file=file,
message="Cannot begin ignore block:",
linum_new=state.linum_new,
linum_old=state.linum_old,
line=line,
)
)
elif end:
actual = False
visual = True
if warn and not state.ignore.actual:
logs.append(
Log(
level="WARN",
file=file,
message="No ignore block to end.",
linum_new=state.linum_new,
linum_old=state.linum_old,
line=line,
)
)
else:
actual = state.ignore.actual
visual = state.ignore.actual
return Ignore(actual=actual, visual=visual)
def get_verbatim_diff(line: str, pattern: Pattern) -> int:
"""Calculate total verbatim depth change."""
if pattern.contains_env_begin and any(r in line for r in VERBATIMS_BEGIN):
return 1
elif pattern.contains_env_end and any(r in line for r in VERBATIMS_END):
return -1
else:
return 0
def get_verbatim(
line: str, state: State, logs: List[Log], file: str, warn: bool, pattern: Pattern
) -> Verbatim:
"""Determine whether a line is in a verbatim environment."""
diff = get_verbatim_diff(line, pattern)
actual = state.verbatim.actual + diff
visual = actual > 0 or state.verbatim.actual > 0
if warn and actual < 0:
logs.append(
Log(
level="WARN",
file=file,
message="Verbatim count is negative.",
linum_new=state.linum_new,
linum_old=state.linum_old,
line=line,
)
)
return Verbatim(actual=actual, visual=visual)
def get_diff(line: str, pattern: Pattern, lists_begin: List[str], lists_end: List[str]) -> int:
"""Calculate total indentation change due to the current line."""
diff = 0
# Other environments get single indents
if pattern.contains_env_begin and ENV_BEGIN in line:
# Documents get no global indentation
if DOC_BEGIN in line:
return 0
diff += 1
diff += 1 if any(r in line for r in lists_begin) else 0
elif pattern.contains_env_end and ENV_END in line:
# Documents get no global indentation
if DOC_END in line:
return 0
diff -= 1
diff -= 1 if any(r in line for r in lists_end) else 0
# Indent for delimiters
for c in line:
if c in OPENS:
diff += 1
elif c in CLOSES:
diff -= 1
return diff
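# A small worked example of this bookkeeping (illustrative, assuming itemize is
# configured as a list environment): for the line "\begin{itemize}" the
# environment adds 1, the list adds 1, and the opening/closing braces cancel,
# so the net indent change is 2.
#
#   line = r"\begin{itemize}"
#   pattern = Pattern.new(line)
#   get_diff(line, pattern, [r"\begin{itemize}"], [r"\end{itemize}"])  # -> 2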
def get_back(line: str, pattern: Pattern, state: State, lists_end: List[str]) -> int:
"""Calculate dedentation for the current line."""
# Only need to dedent if indentation is present
if state.indent.actual == 0:
return 0
if pattern.contains_env_end and ENV_END in line:
# Documents get no global indentation
if DOC_END in line:
return 0
# List environments get double indents for indenting items
for r in lists_end:
if r in line:
return 2
return 1
# Items get dedented
if pattern.contains_item and ITEM in line:
return 1
return 0
def get_indent(
line: str,
prev_indent: Indent,
pattern: Pattern,
state: State,
lists_begin: List[str],
lists_end: List[str],
) -> Indent:
"""Calculate the indent for a line."""
diff = get_diff(line, pattern, lists_begin, lists_end)
back = get_back(line, pattern, state, lists_end)
actual = prev_indent.actual + diff
visual = max(0, prev_indent.actual - back)
return Indent(actual=actual, visual=visual)
def calculate_indent(
line: str,
state: State,
logs: List[Log],
file: str,
args: Args,
pattern: Pattern,
lists_begin: List[str],
lists_end: List[str],
) -> Indent:
"""Calculate the indent for a line and update the state."""
indent = get_indent(line, state.indent, pattern, state, lists_begin, lists_end)
# Update the state
state.indent = indent
# Record the last line with zero indent
if indent.visual == 0:
state.linum_last_zero_indent = state.linum_new
return indent
def apply_indent(line: str, indent: Indent, args: Args, indent_char: str) -> str:
"""Apply indentation to a line."""
if not line.strip():
return ""
# One tab per indent level, or `tabsize` spaces per level
width = indent.visual if indent_char == '\t' else indent.visual * args.tabsize
indent_str = indent_char * width
return indent_str + line.lstrip()
def needs_wrap(line: str, indent_length: int, args: Args) -> bool:
"""Check if a line needs wrapping."""
return args.wrap and (len(line) + indent_length > args.wraplen)
def find_wrap_point(line: str, indent_length: int, args: Args) -> Optional[int]:
"""Find the best place to break a long line."""
wrap_point = None
after_char = False
prev_char = None
line_width = 0
wrap_boundary = args.wrapmin - indent_length
for i, c in enumerate(line):
line_width += 1
if line_width > wrap_boundary and wrap_point is not None:
break
if c == ' ' and prev_char != '\\':
if after_char:
wrap_point = i
elif c != '%':
after_char = True
prev_char = c
return wrap_point
def apply_wrap(
line: str,
indent_length: int,
state: State,
file: str,
args: Args,
logs: List[Log],
pattern: Pattern,
) -> Optional[List[str]]:
"""Wrap a long line into a short prefix and a suffix."""
if args.verbosity >= 3: # Trace level
logs.append(
Log(
level="TRACE",
file=file,
message="Wrapping long line.",
linum_new=state.linum_new,
linum_old=state.linum_old,
line=line,
)
)
wrap_point = find_wrap_point(line, indent_length, args)
comment_index = find_comment_index(line, pattern)
if wrap_point is None or wrap_point > args.wraplen:
logs.append(
Log(
level="WARN",
file=file,
message="Line cannot be wrapped.",
linum_new=state.linum_new,
linum_old=state.linum_old,
line=line,
)
)
return None
this_line = line[:wrap_point]
if comment_index is not None and wrap_point > comment_index:
next_line_start = COMMENT_LINE_START
else:
next_line_start = TEXT_LINE_START
next_line = line[wrap_point + 1 :]
return [this_line, next_line_start, next_line]
def needs_split(line: str, pattern: Pattern) -> bool:
"""Check if line contains content which should be split onto a new line."""
# Check if we should format this line and if we've matched an environment
contains_splittable_env = (
pattern.contains_splitting and RE_SPLITTING_SHARED_LINE.search(line) is not None
)
# If we're not ignoring and we've matched an environment...
if contains_splittable_env:
# Return True if the comment index is None (which implies the split point must be in text),
# otherwise compare the index of the comment with the split point
comment_index = find_comment_index(line, pattern)
if comment_index is None:
return True
match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
if match and match.start(2) > comment_index:
# If split point is past the comment index, don't split
return False
else:
# Otherwise, split point is before comment and we do split
return True
else:
# If ignoring or didn't match an environment, don't need a new line
return False
def split_line(line: str, state: State, file: str, args: Args, logs: List[Log]) -> Tuple[str, str]:
"""Ensure lines are split correctly."""
match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
if not match:
return line, ""
prev = match.group('prev')
rest = match.group('env')
if args.verbosity >= 3: # Trace level
logs.append(
Log(
level="TRACE",
file=file,
message="Placing environment on new line.",
linum_new=state.linum_new,
linum_old=state.linum_old,
line=line,
)
)
return prev, rest
def set_ignore_and_report(
line: str, temp_state: State, logs: List[Log], file: str, pattern: Pattern
) -> bool:
"""Sets the ignore and verbatim flags in the given State based on line and returns whether line should be ignored."""
temp_state.ignore = get_ignore(line, temp_state, logs, file, True)
temp_state.verbatim = get_verbatim(line, temp_state, logs, file, True, pattern)
return temp_state.verbatim.visual or temp_state.ignore.visual
def clean_text(text: str, args: Args) -> str:
"""Cleans the given text by removing extra line breaks and trailing spaces."""
# Remove extra newlines
text = RE_NEWLINES.sub(f"{LINE_END}{LINE_END}", text)
# Remove tabs if they shouldn't be used
if args.tabchar != '\t':
text = text.replace('\t', ' ' * args.tabsize)
# Remove trailing spaces
text = RE_TRAIL.sub(LINE_END, text)
return text
def remove_trailing_spaces(text: str) -> str:
"""Remove trailing spaces from line endings."""
return RE_TRAIL.sub(LINE_END, text)
def remove_trailing_blank_lines(text: str) -> str:
"""Remove trailing blank lines from file."""
return text.rstrip() + LINE_END
def indents_return_to_zero(state: State) -> bool:
"""Check if indentation returns to zero at the end of the file."""
return state.indent.actual == 0
def format_latex(
old_text: str, file: str = "input.tex", args: Optional[Args] = None
) -> Tuple[str, List[Log]]:
"""Central function to format a LaTeX string."""
if args is None:
args = Args()
logs = []
logs.append(Log(level="INFO", file=file, message="Formatting started."))
# Clean the source file
old_text = clean_text(old_text, args)
old_lines = list(enumerate(old_text.splitlines(), 1))
# Initialize
state = State()
queue = []
new_text = ""
# Select the character used for indentation
indent_char = '\t' if args.tabchar == '\t' else ' '
# Get any extra environments to be indented as lists
lists_begin = [f"\\begin{{{l}}}" for l in args.lists]
lists_end = [f"\\end{{{l}}}" for l in args.lists]
while True:
if queue:
linum_old, line = queue.pop(0)
# Read the patterns present on this line
pattern = Pattern.new(line)
# Temporary state for working on this line
temp_state = State(
linum_old=linum_old,
linum_new=state.linum_new,
ignore=Ignore(state.ignore.actual, state.ignore.visual),
indent=Indent(state.indent.actual, state.indent.visual),
verbatim=Verbatim(state.verbatim.actual, state.verbatim.visual),
linum_last_zero_indent=state.linum_last_zero_indent,
)
# If the line should not be ignored...
if not set_ignore_and_report(line, temp_state, logs, file, pattern):
# Check if the line should be split because of a pattern that should begin on a new line
if needs_split(line, pattern):
# Split the line into two...
this_line, next_line = split_line(line, temp_state, file, args, logs)
# ...and queue the second part for formatting
if next_line:
queue.insert(0, (linum_old, next_line))
line = this_line
# Calculate the indent based on the current state and the patterns in the line
indent = calculate_indent(
line, temp_state, logs, file, args, pattern, lists_begin, lists_end
)
indent_length = indent.visual * args.tabsize
# Wrap the line before applying the indent, and loop back if the line needed wrapping
if needs_wrap(line.lstrip(), indent_length, args):
wrapped_lines = apply_wrap(
line.lstrip(), indent_length, temp_state, file, args, logs, pattern
)
if wrapped_lines:
this_line, next_line_start, next_line = wrapped_lines
queue.insert(0, (linum_old, next_line_start + next_line))
queue.insert(0, (linum_old, this_line))
continue
# Lastly, apply the indent if the line didn't need wrapping
line = apply_indent(line, indent, args, indent_char)
# Add line to new text
state = temp_state
new_text += line + LINE_END
state.linum_new += 1
elif old_lines:
linum_old, line = old_lines.pop(0)
queue.append((linum_old, line))
else:
break
if not indents_return_to_zero(state):
msg = f"Indent does not return to zero. Last non-indented line is line {state.linum_last_zero_indent}"
logs.append(Log(level="WARN", file=file, message=msg))
new_text = remove_trailing_spaces(new_text)
new_text = remove_trailing_blank_lines(new_text)
logs.append(Log(level="INFO", file=file, message="Formatting complete."))
return new_text, logs
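# Minimal usage sketch for format_latex (illustrative input; Args fields as defined above):
#
#   args = Args(wrap=True, wraplen=80, wrapmin=40, tabsize=4)
#   formatted, logs = format_latex(r"\begin{itemize}\item First item\end{itemize}", "example.tex", args)
#   for log in logs:
#       print(log.level, log.message)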
def main():
"""Command-line entry point."""
parser = argparse.ArgumentParser(description="Format LaTeX files")
parser.add_argument("file", help="LaTeX file to format")
parser.add_argument(
"--tabchar",
choices=["space", "tab"],
default="space",
help="Character to use for indentation",
)
parser.add_argument("--tabsize", type=int, default=4, help="Number of spaces per indent level")
parser.add_argument("--wrap", action="store_true", help="Enable line wrapping")
parser.add_argument("--wraplen", type=int, default=80, help="Maximum line length")
parser.add_argument(
"--wrapmin", type=int, default=40, help="Minimum line length before wrapping"
)
parser.add_argument(
"--lists", nargs="+", default=[], help="Additional environments to indent as lists"
)
parser.add_argument("--verbose", "-v", action="count", default=0, help="Increase verbosity")
parser.add_argument("--output", "-o", help="Output file (default: overwrite input)")
args_parsed = parser.parse_args()
# Convert command line args to our Args class
args = Args(
tabchar="\t" if args_parsed.tabchar == "tab" else " ",
tabsize=args_parsed.tabsize,
wrap=args_parsed.wrap,
wraplen=args_parsed.wraplen,
wrapmin=args_parsed.wrapmin,
lists=args_parsed.lists,
verbosity=args_parsed.verbose,
)
# Read input file
with open(args_parsed.file, "r", encoding="utf-8") as f:
text = f.read()
# Format the text
formatted_text, logs = format_latex(text, args_parsed.file, args)
# Print logs if verbose
if args.verbosity > 0:
for log in logs:
if log.linum_new is not None:
print(f"{log.level} {log.file}:{log.linum_new}:{log.linum_old}: {log.message}")
else:
print(f"{log.level} {log.file}: {log.message}")
# Write output
output_file = args_parsed.output or args_parsed.file
with open(output_file, "w", encoding="utf-8") as f:
f.write(formatted_text)
if __name__ == "__main__":
main()

View File

@@ -1,25 +0,0 @@
import evaluate
import numpy as np
import os
from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
cur_dir = Path(os.getcwd())
os.chdir(Path(__file__).resolve().parent)
metric = evaluate.load(
'google_bleu'
) # Will download the metric from huggingface if not already downloaded
os.chdir(cur_dir)
logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits
labels = np.where(labels == -100, 1, labels)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
return metric.compute(predictions=preds, references=labels)
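# Minimal usage sketch (assumes a HuggingFace Seq2SeqTrainer, model, training_args and tokenizer
# are set up elsewhere; shown only to illustrate how this metric is typically wired in):
#
#   trainer = Seq2SeqTrainer(
#       model=model,
#       args=training_args,
#       compute_metrics=lambda eval_preds: bleu_metric(eval_preds, tokenizer),
#   )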

View File

@@ -1,152 +0,0 @@
from augraphy import *
import random
def ocr_augmentation_pipeline():
pre_phase = []
ink_phase = [
InkColorSwap(
ink_swap_color="random",
ink_swap_sequence_number_range=(5, 10),
ink_swap_min_width_range=(2, 3),
ink_swap_max_width_range=(100, 120),
ink_swap_min_height_range=(2, 3),
ink_swap_max_height_range=(100, 120),
ink_swap_min_area_range=(10, 20),
ink_swap_max_area_range=(400, 500),
# p=0.2
p=0.4,
),
LinesDegradation(
line_roi=(0.0, 0.0, 1.0, 1.0),
line_gradient_range=(32, 255),
line_gradient_direction=(0, 2),
line_split_probability=(0.2, 0.4),
line_replacement_value=(250, 255),
line_min_length=(30, 40),
line_long_to_short_ratio=(5, 7),
line_replacement_probability=(0.4, 0.5),
line_replacement_thickness=(1, 3),
# p=0.2
p=0.4,
),
# ============================
OneOf(
[
Dithering(
dither="floyd-steinberg",
order=(3, 5),
),
InkBleed(
intensity_range=(0.1, 0.2),
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
severity=(0.4, 0.6),
),
],
# p=0.2
p=0.4,
),
# ============================
# ============================
InkShifter(
text_shift_scale_range=(18, 27),
text_shift_factor_range=(1, 4),
text_fade_range=(0, 2),
blur_kernel_size=(5, 5),
blur_sigma=0,
noise_type="perlin",
# p=0.2
p=0.4,
),
# ============================
]
paper_phase = [
NoiseTexturize( # tested
sigma_range=(3, 10),
turbulence_range=(2, 5),
texture_width_range=(300, 500),
texture_height_range=(300, 500),
# p=0.2
p=0.4,
),
BrightnessTexturize( # tested
texturize_range=(0.9, 0.99),
deviation=0.03,
# p=0.2
p=0.4,
),
]
post_phase = [
ColorShift( # tested
color_shift_offset_x_range=(3, 5),
color_shift_offset_y_range=(3, 5),
color_shift_iterations=(2, 3),
color_shift_brightness_range=(0.9, 1.1),
color_shift_gaussian_kernel_range=(3, 3),
# p=0.2
p=0.4,
),
DirtyDrum( # tested
line_width_range=(1, 6),
line_concentration=random.uniform(0.05, 0.15),
direction=random.randint(0, 2),
noise_intensity=random.uniform(0.6, 0.95),
noise_value=(64, 224),
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
sigmaX=0,
# p=0.2
p=0.4,
),
# =====================================
OneOf(
[
LightingGradient(
light_position=None,
direction=None,
max_brightness=255,
min_brightness=0,
mode="gaussian",
linear_decay_rate=None,
transparency=None,
),
Brightness(
brightness_range=(0.9, 1.1),
min_brightness=0,
min_brightness_value=(120, 150),
),
Gamma(
gamma_range=(0.9, 1.1),
),
],
# p=0.2
p=0.4,
),
# =====================================
# =====================================
OneOf(
[
SubtleNoise(
subtle_range=random.randint(5, 10),
),
Jpeg(
quality_range=(70, 95),
),
],
# p=0.2
p=0.4,
),
# =====================================
]
pipeline = AugraphyPipeline(
ink_phase=ink_phase,
paper_phase=paper_phase,
post_phase=post_phase,
pre_phase=pre_phase,
log=False,
)
return pipeline
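# Minimal usage sketch (illustrative image path; the pipeline is called directly on an
# RGB uint8 numpy array and returns the augmented image):
#
#   import cv2
#   pipeline = ocr_augmentation_pipeline()
#   image = cv2.imread("sample.png")
#   augmented = pipeline(image)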

View File

@@ -1,194 +0,0 @@
import re
from .latex_formatter import format_latex
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i : i + len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1 : j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count >= 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
result += input_str[i:start]
i = start
if old_inst != new_inst and (old_inst + old_surr_l) in result:
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def find_substring_positions(string, substring):
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
return positions
def rm_dollar_surr(content):
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
matches = pattern.findall(content)
for match in matches:
if not re.match(r'\\[a-zA-Z]+', match):
new_match = match.strip('$')
content = content.replace(match, ' ' + new_match + ' ')
return content
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
pos = find_substring_positions(input_str, old_inst + old_surr_l)
res = list(input_str)
for p in pos[::-1]:
res[p:] = list(
change(
''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
)
)
res = ''.join(res)
return res
def to_katex(formula: str) -> str:
res = formula
# remove mbox surrounding
res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
# remove hbox surrounding
res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
# remove raise surrounding
res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
# remove makebox
res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
# remove vbox surrounding, scalebox surrounding
res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')
origin_instructions = [
r'\Huge',
r'\huge',
r'\LARGE',
r'\Large',
r'\large',
r'\normalsize',
r'\small',
r'\footnotesize',
r'\tiny',
]
for ins in origin_instructions:
res = change_all(res, ins, ins, r'$', r'$', '{', '}')
res = change_all(res, r'\mathbf', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')
# remove bold command
res = change_all(res, r'\bm', r' ', r'{', r'}', r'', r'')
origin_instructions = [
r'\left',
r'\middle',
r'\right',
r'\big',
r'\Big',
r'\bigg',
r'\Bigg',
r'\bigl',
r'\Bigl',
r'\biggl',
r'\Biggl',
r'\bigm',
r'\Bigm',
r'\biggm',
r'\Biggm',
r'\bigr',
r'\Bigr',
r'\biggr',
r'\Biggr',
]
for origin_ins in origin_instructions:
res = change_all(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
if res.endswith(r'\newline'):
res = res[:-8]
# collapse LaTeX spacing commands (\,, \!, \;, \:) into a single space
res = re.sub(r'(\\,){1,}', ' ', res)
res = re.sub(r'(\\!){1,}', ' ', res)
res = re.sub(r'(\\;){1,}', ' ', res)
res = re.sub(r'(\\:){1,}', ' ', res)
res = re.sub(r'\\vspace\{.*?}', '', res)
# merge consecutive text
def merge_texts(match):
texts = match.group(0)
merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
return f'\\text{{{merged_content}}}'
res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)
res = res.replace(r'\bf ', '')
res = rm_dollar_surr(res)
# remove extra spaces (keeping only one)
res = re.sub(r' +', ' ', res)
# format latex
res = res.strip()
res, logs = format_latex(res)
return res
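# Minimal usage sketch (illustrative formula):
#
#   katex_str = to_katex(r"\mbox{area of a circle: } \pi r^{2}")
#   print(katex_str)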

View File

@@ -1,177 +0,0 @@
import torch
import random
import numpy as np
import cv2
from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from collections import Counter
from ...globals import (
IMG_CHANNELS,
FIXED_IMG_SIZE,
IMAGE_MEAN,
IMAGE_STD,
MAX_RESIZE_RATIO,
MIN_RESIZE_RATIO,
)
from .ocr_aug import ocr_augmentation_pipeline
# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()
general_transform_pipeline = v2.Compose(
[
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True), # optional, most inputs are already uint8 at this point
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True,
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
# v2.ToPILImage()
]
)
def trim_white_border(image: np.ndarray):
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
h, w = image.shape[:2]
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
threshold = 15
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
x, y, w, h = cv2.boundingRect(diff)
trimmed_image = image[y : y + h, x : x + w]
return trimmed_image
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
pad_width_size = randi[0] + randi[2]
if pad_height_size + image.shape[0] < 30:
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
randi[1] += compensate_height
randi[3] += compensate_height
if pad_width_size + image.shape[1] < 30:
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
randi[0] += compensate_width
randi[2] += compensate_width
return v2.functional.pad(
torch.from_numpy(image).permute(2, 0, 1),
padding=randi,
padding_mode='constant',
fill=(255, 255, 255),
)
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
v2.functional.pad(
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
)
for img in images
]
return images
def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
return [
cv2.resize(
img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
) # anti-aliasing
for img, r in zip(images, ratios)
]
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
# Get the center of the image to define the point of rotation
image_center = tuple(np.array(image.shape[1::-1]) / 2)
# Generate a random angle within the specified range
angle = random.randint(min_angle, max_angle)
# Get the rotation matrix for rotating the image around its center
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
# Determine the size of the rotated image
cos = np.abs(rotation_mat[0, 0])
sin = np.abs(rotation_mat[0, 1])
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
# Adjust the rotation matrix to take into account translation
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
# Rotate the image with the specified border color (white in this case)
rotated_image = cv2.warpAffine(
image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
)
return rotated_image
def ocr_aug(image: np.ndarray) -> np.ndarray:
if random.random() < 0.2:
image = rotate(image, -5, 5)
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
image = train_pipeline(image)
return image
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"
images = [np.array(img.convert('RGB')) for img in images]
# random resize first
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
images = [trim_white_border(image) for image in images]
# OCR augmentation
images = [ocr_aug(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"
images = [
np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
]
images = [trim_white_border(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
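# Minimal usage sketch (illustrative image path):
#
#   from PIL import Image
#   tensors = inference_transform([Image.open("formula.png")])
#   batch = torch.stack(tensors)  # shape: (N, 1, FIXED_IMG_SIZE, FIXED_IMG_SIZE)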

View File

@@ -0,0 +1,48 @@
from pathlib import Path
from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from texteller.constants import (
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE,
VOCAB_SIZE,
)
from texteller.globals import Globals
from texteller.types import TexTellerModel
from texteller.utils import cuda_available
class TexTeller(VisionEncoderDecoderModel):
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name)
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config)
@classmethod
def from_pretrained(cls, model_dir: str | None = None, use_onnx=False) -> TexTellerModel:
if model_dir is None or model_dir == Globals().repo_name:
if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(Globals().repo_name)
else:
from optimum.onnxruntime import ORTModelForVision2Seq
return ORTModelForVision2Seq.from_pretrained(
Globals().repo_name,
provider="CUDAExecutionProvider"
if cuda_available()
else "CPUExecutionProvider",
)
model_dir = Path(model_dir).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_dir))
@classmethod
def get_tokenizer(cls, tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
if tokenizer_dir is None or tokenizer_dir == Globals().repo_name:
return RobertaTokenizerFast.from_pretrained(Globals().repo_name)
tokenizer_dir = Path(tokenizer_dir).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_dir))
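# Minimal usage sketch (loads the default checkpoint referenced by Globals().repo_name):
#
#   model = TexTeller.from_pretrained()
#   tokenizer = TexTeller.get_tokenizer()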

View File

@@ -1,212 +0,0 @@
import os
import re
from pathlib import Path
import numpy as np
class BaseRecLabelDecode(object):
"""Convert between text-label and text-index"""
def __init__(self, character_dict_path=None, use_space_char=False):
cur_path = os.getcwd()
scriptDir = Path(__file__).resolve().parent
os.chdir(scriptDir)
character_dict_path = str(Path(scriptDir / "ppocr_keys_v1.txt"))
self.beg_str = "sos"
self.end_str = "eos"
self.reverse = False
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
if "arabic" in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
os.chdir(cur_path)
def pred_reverse(self, pred):
pred_re = []
c_current = ""
for c in pred:
if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
if c_current != "":
pred_re.append(c_current)
pred_re.append(c)
c_current = ""
else:
c_current += c
if c_current != "":
pred_re.append(c_current)
return "".join(pred_re[::-1])
def add_special_char(self, dict_character):
return dict_character
def get_word_info(self, text, selection):
"""
Group the decoded characters and record the corresponding decoded positions.
Args:
text: the decoded text
selection: the bool array that identifies which columns of features are decoded as non-separated characters
Returns:
word_list: list of the grouped words
word_col_list: list of decoding positions corresponding to each character in the grouped word
state_list: list of markers identifying the type of each grouped word; there are two types:
- 'cn': continuous Chinese characters (e.g., 你好啊)
- 'en&num': continuous English characters (e.g., hello), numbers (e.g., 123, 1.123), or a mix of them connected by '-' (e.g., VGG-16)
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
"""
state = None
word_content = []
word_col_content = []
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection == True)[0]
for c_i, char in enumerate(text):
if "\u4e00" <= char <= "\u9fff":
c_state = "cn"
elif bool(re.search("[a-zA-Z0-9]", char)):
c_state = "en&num"
else:
c_state = "splitter"
if (
char == "."
and state == "en&num"
and c_i + 1 < len(text)
and bool(re.search("[0-9]", text[c_i + 1]))
): # grouping floating-point numbers
c_state = "en&num"
if (
char == "-" and state == "en&num"
): # grouping word with '-', such as 'state-of-the-art'
c_state = "en&num"
if state is None:
state = c_state
if state != c_state:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
word_content = []
word_col_content = []
state = c_state
if state != "splitter":
word_content.append(char)
word_col_content.append(valid_col[c_i])
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
return word_list, word_col_list, state_list
def decode(
self,
text_index,
text_prob=None,
is_remove_duplicate=False,
return_word_box=False,
):
"""convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = "".join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(text, selection)
result_list.append(
(
text,
np.mean(conf_list).tolist(),
[
len(text_index[batch_idx]),
word_list,
word_col_list,
state_list,
],
)
)
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
"""Convert between text-label and text-index"""
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
assert isinstance(preds, np.ndarray)
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(
preds_idx,
preds_prob,
is_remove_duplicate=True,
return_word_box=return_word_box,
)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs["wh_ratio_list"][rec_idx]
max_wh_ratio = kwargs["max_wh_ratio"]
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ["blank"] + dict_character
return dict_character
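# Minimal usage sketch (`preds` stands for CTC output probabilities of shape (N, T, num_classes)):
#
#   decoder = CTCLabelDecode(use_space_char=True)
#   results = decoder(preds)  # list of (text, mean_confidence) pairs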

View File

@@ -1,221 +0,0 @@
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(
self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
box_type="quad",
**kwargs,
):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.box_type = box_type
assert score_mode in [
"slow",
"fast",
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
"""
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours(
(bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
)
for contour in contours[: self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
"""
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours(
(bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype("int32"))
scores.append(score)
return np.array(boxes, dtype="int32"), scores
def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
"""
box_score_fast: use bbox mean score as the mean score
"""
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
"""
box_score_slow: use polygon mean score as the mean score
"""
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict["maps"]
assert isinstance(pred, np.ndarray)
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel,
)
else:
mask = segmentation[batch_index]
if self.box_type == "poly":
boxes, scores = self.polygons_from_bitmap(pred[batch_index], mask, src_w, src_h)
elif self.box_type == "quad":
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
else:
raise ValueError("box_type can only be one of ['quad', 'poly']")
boxes_batch.append({"points": boxes})
return boxes_batch
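# Minimal usage sketch (`pred_maps` and `shape_list` stand for the detector output maps and the
# per-image [src_h, src_w, ratio_h, ratio_w] entries produced during preprocessing):
#
#   postprocess = DBPostProcess(thresh=0.3, box_thresh=0.6, box_type="quad")
#   boxes_batch = postprocess({"maps": pred_maps}, shape_list)
#   boxes = boxes_batch[0]["points"]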

View File

@@ -1,186 +0,0 @@
import numpy as np
import cv2
import math
import sys
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
if "image_shape" in kwargs:
self.image_shape = kwargs["image_shape"]
self.resize_type = 1
if "keep_ratio" in kwargs:
self.keep_ratio = kwargs["keep_ratio"]
elif "limit_side_len" in kwargs:
self.limit_side_len = kwargs["limit_side_len"]
self.limit_type = kwargs.get("limit_type", "min")
elif "resize_long" in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get("resize_long", 960)
else:
self.limit_side_len = 736
self.limit_type = "min"
def __call__(self, data):
img = data["image"]
src_h, src_w, _ = img.shape
if sum([src_h, src_w]) < 64:
img = self.image_padding(img)
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data["image"] = img
data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def image_padding(self, im, value=0):
h, w, c = im.shape
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
im_pad[:h, :w, :] = im
return im_pad
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size that is a multiple of 32, as required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.0
elif self.limit_type == "min":
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.0
elif self.limit_type == "resize_long":
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception("not support limit type, image ")
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except: # noqa: E722
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class NormalizeImage(object):
"""normalize image such as substract mean, divide std"""
def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype("float32")
self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
img = data["image"]
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
"""convert hwc image to chw image"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data["image"]
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data["image"] = img.transpose((2, 0, 1))
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
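# Minimal usage sketch (mirrors the preprocessing chain used by TextDetector below; `img` stands
# for an HWC BGR uint8 numpy image):
#
#   ops = [
#       DetResizeForTest(limit_side_len=960, limit_type="max"),
#       NormalizeImage(scale=1.0 / 255.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], order="hwc"),
#       ToCHWImage(),
#       KeepKeys(keep_keys=["image", "shape"]),
#   ]
#   data = {"image": img}
#   for op in ops:
#       data = op(data)
#   img_chw, shape = data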

File diff suppressed because it is too large

View File

@@ -1,286 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import sys
import time
import cv2
import numpy as np
# import tools.infer.utility as utility
import utility
from DBPostProcess import DBPostProcess
from operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
from utility import get_logger
def transform(data, ops=None):
"""transform"""
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
logger = get_logger()
class TextDetector(object):
def __init__(self, args):
self.args = args
self.det_algorithm = args.det_algorithm
self.use_onnx = args.use_onnx
postprocess_params = {}
assert self.det_algorithm == "DB"
postprocess_params["name"] = "DBPostProcess"
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
self.preprocess_op = [
DetResizeForTest(
limit_side_len=args.det_limit_side_len, limit_type=args.det_limit_type
),
NormalizeImage(
std=[0.229, 0.224, 0.225],
mean=[0.485, 0.456, 0.406],
scale=1.0 / 255.0,
order="hwc",
),
ToCHWImage(),
KeepKeys(keep_keys=["image", "shape"]),
]
self.postprocess_op = DBPostProcess(**postprocess_params)
(
self.predictor,
self.input_tensor,
self.output_tensors,
self.config,
) = utility.create_predictor(args, "det", logger)
assert self.use_onnx
if self.use_onnx:
img_h, img_w = self.input_tensor.shape[2:]
if isinstance(img_h, str) or isinstance(img_w, str):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
self.preprocess_op[0] = DetResizeForTest(image_shape=[img_h, img_w])
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def predict(self, img):
ori_im = img.copy()
data = {"image": img}
st = time.time()
if self.args.benchmark:
self.autolog.times.start()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
if self.args.benchmark:
self.autolog.times.stamp()
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = img
outputs = self.predictor.run(self.output_tensors, input_dict)
else:
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.args.benchmark:
self.autolog.times.stamp()
preds = {}
if self.det_algorithm == "EAST":
preds["f_geo"] = outputs[0]
preds["f_score"] = outputs[1]
elif self.det_algorithm == "SAST":
preds["f_border"] = outputs[0]
preds["f_score"] = outputs[1]
preds["f_tco"] = outputs[2]
preds["f_tvo"] = outputs[3]
elif self.det_algorithm in ["DB", "PSE", "DB++"]:
preds["maps"] = outputs[0]
elif self.det_algorithm == "FCE":
for i, output in enumerate(outputs):
preds["level_{}".format(i)] = output
elif self.det_algorithm == "CT":
preds["maps"] = outputs[0]
preds["score"] = outputs[1]
else:
raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]["points"]
if self.args.det_box_type == "poly":
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
if self.args.benchmark:
self.autolog.times.end(stamp=True)
et = time.time()
return dt_boxes, et - st
def __call__(self, img):
# For images like posters where one side is much longer than the other,
# split recursively and process with overlap to improve performance.
MIN_BOUND_DISTANCE = 50
dt_boxes = np.zeros((0, 4, 2), dtype=np.float32)
elapse = 0
if img.shape[0] / img.shape[1] > 2 and img.shape[0] > self.args.det_limit_side_len:
start_h = 0
end_h = 0
while end_h <= img.shape[0]:
end_h = start_h + img.shape[1] * 3 // 4
subimg = img[start_h:end_h, :]
if len(subimg) == 0:
break
sub_dt_boxes, sub_elapse = self.predict(subimg)
offset = start_h
# To prevent text blocks from being cut off, roll back a certain buffer area.
if (
len(sub_dt_boxes) == 0
or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE
):
start_h = end_h
else:
sorted_indices = np.argsort(sub_dt_boxes[:, 2, 1])
sub_dt_boxes = sub_dt_boxes[sorted_indices]
bottom_line = (
0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 2, 1]))
)
if bottom_line > 0:
start_h += bottom_line
sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 2, 1] <= bottom_line]
else:
start_h = end_h
if len(sub_dt_boxes) > 0:
if dt_boxes.shape[0] == 0:
dt_boxes = sub_dt_boxes + np.array([0, offset], dtype=np.float32)
else:
dt_boxes = np.append(
dt_boxes,
sub_dt_boxes + np.array([0, offset], dtype=np.float32),
axis=0,
)
elapse += sub_elapse
elif img.shape[1] / img.shape[0] > 3 and img.shape[1] > self.args.det_limit_side_len * 3:
start_w = 0
end_w = 0
while end_w <= img.shape[1]:
end_w = start_w + img.shape[0] * 3 // 4
subimg = img[:, start_w:end_w]
if len(subimg) == 0:
break
sub_dt_boxes, sub_elapse = self.predict(subimg)
offset = start_w
if (
len(sub_dt_boxes) == 0
or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE
):
start_w = end_w
else:
sorted_indices = np.argsort(sub_dt_boxes[:, 2, 0])
sub_dt_boxes = sub_dt_boxes[sorted_indices]
right_line = (
0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 1, 0]))
)
if right_line > 0:
start_w += right_line
sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 1, 0] <= right_line]
else:
start_w = end_w
if len(sub_dt_boxes) > 0:
if dt_boxes.shape[0] == 0:
dt_boxes = sub_dt_boxes + np.array([offset, 0], dtype=np.float32)
else:
dt_boxes = np.append(
dt_boxes,
sub_dt_boxes + np.array([offset, 0], dtype=np.float32),
axis=0,
)
elapse += sub_elapse
else:
dt_boxes, elapse = self.predict(img)
return dt_boxes, elapse
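# Minimal usage sketch (`args` stands for the parsed argument namespace expected by
# utility.create_predictor, with use_onnx=True and det_algorithm="DB"; `img` is an HWC BGR image):
#
#   detector = TextDetector(args)
#   dt_boxes, elapse = detector(img)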

View File

@@ -1,379 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import numpy as np
import math
import time
import utility
from utility import get_logger
from CTCLabelDecode import CTCLabelDecode
logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
self.postprocess_op = CTCLabelDecode(
character_dict_path=args.rec_char_dict_path, use_space_char=args.use_space_char
)
(
self.predictor,
self.input_tensor,
self.output_tensors,
self.config,
) = utility.create_predictor(args, "rec", logger)
self.benchmark = args.benchmark
self.use_onnx = args.use_onnx
self.return_word_box = args.return_word_box
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
if self.rec_algorithm == "NRTR" or self.rec_algorithm == "ViTSTR":
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
if self.rec_algorithm == "ViTSTR":
img = image_pil.resize([imgW, imgH], Image.BICUBIC)
else:
img = image_pil.resize([imgW, imgH], Image.Resampling.LANCZOS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
if self.rec_algorithm == "ViTSTR":
norm_img = norm_img.astype(np.float32) / 255.0
else:
norm_img = norm_img.astype(np.float32) / 128.0 - 1.0
return norm_img
elif self.rec_algorithm == "RFL":
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
resized_image = resized_image.astype("float32")
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
resized_image -= 0.5
resized_image /= 0.5
return resized_image
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
if self.use_onnx:
w = self.input_tensor.shape[3:][0]
if isinstance(w, str):
pass
elif w is not None and w > 0:
imgW = w
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
if self.rec_algorithm == "RARE":
if resized_w > self.rec_image_shape[2]:
resized_w = self.rec_image_shape[2]
imgW = self.rec_image_shape[2]
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0 : img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
gsrm_word_pos = (
np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype("int64")
)
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length]
)
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype(
"float32"
) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length]
)
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype(
"float32"
) * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos,
gsrm_word_pos,
gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2,
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[
encoder_word_pos,
gsrm_word_pos,
gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2,
] = self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (
norm_img,
encoder_word_pos,
gsrm_word_pos,
gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2,
)
def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype("float32")
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_spin(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
mean = [127.5]
std = [127.5]
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
mean = np.float32(mean.reshape(1, -1))
stdinv = 1 / np.float32(std.reshape(1, -1))
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_cppd_padding(
self, img, image_shape, padding=True, interpolation=cv2.INTER_LINEAR
):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(img, (imgW, imgH), interpolation=interpolation)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype("float32")
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_abinet(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype("float32")
resized_image = resized_image / 255.0
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype("float32")
return resized_image
def norm_img_can(self, img, image_shape):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # CAN only predicts grayscale images
if self.inverse:
img = 255 - img
if self.rec_image_shape[0] == 1:
h, w = img.shape
_, imgH, imgW = self.rec_image_shape
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
img_padded = np.pad(
img,
((0, padding_h), (0, padding_w)),
"constant",
constant_values=(255),
)
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w -> 1,h,w
img = img.astype("float32")
return img
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [["", 0.0]] * img_num
batch_num = self.rec_batch_num
st = time.time()
if self.benchmark:
self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
wh_ratio_list = []
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
wh_ratio_list.append(wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
if self.benchmark:
self.autolog.times.stamp()
assert self.use_onnx
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
rec_result = self.postprocess_op(
preds,
return_word_box=self.return_word_box,
wh_ratio_list=wh_ratio_list,
max_wh_ratio=max_wh_ratio,
)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
self.autolog.times.end(stamp=True)
return rec_res, time.time() - st
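# --- Illustrative usage sketch (not part of the original file) ---
# Assuming `text_recognizer` is an already constructed instance of this recognizer
# class and `crops` is a list of BGR ndarrays cropped around detected text lines:
#     crops = [cv2.imread("line_1.png"), cv2.imread("line_2.png")]  # hypothetical paths
#     rec_res, elapsed = text_recognizer(crops)
#     for text, score in rec_res:
#         print(f"{score:.3f}\t{text}")
# Although __call__ sorts the crops by aspect ratio for faster batching, results are
# written back through `indices`, so rec_res keeps the original input order.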

View File

@@ -1,689 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import functools
import logging
import cv2
import numpy as np
import PIL
from PIL import Image, ImageDraw, ImageFont
import math
import random
def str2bool(v):
return v.lower() in ("true", "yes", "t", "y", "1")
def str2int_tuple(v):
return tuple([int(i.strip()) for i in v.split(",")])
def init_args():
parser = argparse.ArgumentParser()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--use_xpu", type=str2bool, default=False)
parser.add_argument("--use_npu", type=str2bool, default=False)
parser.add_argument("--use_mlu", type=str2bool, default=False)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=15)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
parser.add_argument("--gpu_id", type=int, default=0)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--page_num", type=int, default=0)
parser.add_argument("--det_algorithm", type=str, default="DB")
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default="max")
parser.add_argument("--det_box_type", type=str, default="quad")
# DB params
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST params
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# SAST params
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
# PSE params
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_scale", type=int, default=1)
# FCE params
parser.add_argument("--scales", type=list, default=[8, 16, 32])
parser.add_argument("--alpha", type=float, default=1.0)
parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--fourier_degree", type=int, default=5)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet")
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument("--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for e2e
parser.add_argument("--e2e_algorithm", type=str, default="PGNet")
parser.add_argument("--e2e_model_dir", type=str)
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
parser.add_argument("--e2e_limit_type", type=str, default="max")
# PGNet params
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
parser.add_argument("--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
parser.add_argument("--e2e_pgnet_valid_set", type=str, default="totaltext")
parser.add_argument("--e2e_pgnet_mode", type=str, default="fast")
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=["0", "180"])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=False)
# SR params
parser.add_argument("--sr_model_dir", type=str)
parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
parser.add_argument("--sr_batch_num", type=int, default=1)
#
parser.add_argument("--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")
# multi-process
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--benchmark", type=str2bool, default=False)
parser.add_argument("--save_log_path", type=str, default="./log_output/")
parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)
# extended function
parser.add_argument(
"--return_word_box",
type=str2bool,
default=False,
help="Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery",
)
return parser
def parse_args():
parser = init_args()
return parser.parse_args([])
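# Illustrative sketch (not part of the original file): parse_args() above always uses
# the argument defaults because it passes an empty list; a thin wrapper can feed an
# explicit argv list instead, e.g. to point at local ONNX models (paths below are
# hypothetical).
def parse_args_from(argv):
    return init_args().parse_args(argv)
# args = parse_args_from(["--use_onnx", "true", "--rec_model_dir", "models/rec.onnx"])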
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
elif mode == "cls":
model_dir = args.cls_model_dir
elif mode == "rec":
model_dir = args.rec_model_dir
elif mode == "table":
model_dir = args.table_model_dir
elif mode == "ser":
model_dir = args.ser_model_dir
elif mode == "re":
model_dir = args.re_model_dir
elif mode == "sr":
model_dir = args.sr_model_dir
elif mode == "layout":
model_dir = args.layout_model_dir
else:
model_dir = args.e2e_model_dir
if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir))
sys.exit(0)
assert args.use_onnx
import onnxruntime as ort
model_file_path = model_dir
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(model_file_path))
if args.use_gpu:
sess = ort.InferenceSession(model_file_path, providers=["CUDAExecutionProvider"])
else:
sess = ort.InferenceSession(model_file_path)
return sess, sess.get_inputs()[0], None, None
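# Illustrative sketch (assumption, not part of the original file): the session
# returned by create_predictor() is a plain onnxruntime InferenceSession, so a forward
# pass uses the standard run() API with the input name taken from the returned input
# tensor handle. `args` and `norm_img_batch` are assumed to exist already.
#     sess, input_tensor, _, _ = create_predictor(args, "rec", get_logger())
#     outputs = sess.run(None, {input_tensor.name: norm_img_batch})  # NCHW float32 batch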
def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet", "SVTR_HGNet"]:
output_name = "softmax_0.tmp_0"
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
return output_tensors
def draw_e2e_res(dt_boxes, strs, img_path):
src_im = cv2.imread(img_path)
for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(
src_im,
str,
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1,
)
return src_im
def draw_text_det_res(dt_boxes, img):
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
return img
def resize_img(img, input_size=600):
"""
resize img and limit the longest side of the image to input_size
"""
img = np.array(img)
im_shape = img.shape
im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max)
img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
return img
def draw_ocr(
image,
boxes,
txts=None,
scores=None,
drop_score=0.5,
font_path="./doc/fonts/simfang.ttf",
):
"""
Visualize the results of OCR detection and recognition
args:
image(Image|array): RGB image
boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts
scores(list): the corresponding score of each txt
drop_score(float): only boxes with scores greater than drop_score will be visualized
font_path: the path of font which is used to draw text
return(array):
the visualized img
"""
if scores is None:
scores = [1] * len(boxes)
box_num = len(boxes)
for i in range(box_num):
if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])):
continue
box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
if txts is not None:
img = np.array(resize_img(image, input_size=600))
txt_img = text_visual(
txts,
scores,
img_h=img.shape[0],
img_w=600,
threshold=drop_score,
font_path=font_path,
)
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
return img
return image
def draw_ocr_box_txt(
image,
boxes,
txts=None,
scores=None,
drop_score=0.5,
font_path="./doc/fonts/simfang.ttf",
):
h, w = image.height, image.width
img_left = image.copy()
img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
random.seed(0)
draw_left = ImageDraw.Draw(img_left)
if txts is None or len(txts) != len(boxes):
txts = [None] * len(boxes)
for idx, (box, txt) in enumerate(zip(boxes, txts)):
if scores is not None and scores[idx] < drop_score:
continue
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
draw_left.polygon(box, fill=color)
img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
pts = np.array(box, np.int32).reshape((-1, 1, 2))
cv2.polylines(img_right_text, [pts], True, color, 1)
img_right = cv2.bitwise_and(img_right, img_right_text)
img_left = Image.blend(image, img_left, 0.5)
img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
return np.array(img_show)
def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
box_height = int(math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2))
box_width = int(math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2))
if box_height > 2 * box_width and box_height > 30:
img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
draw_text = ImageDraw.Draw(img_text)
if txt:
font = create_font(txt, (box_height, box_width), font_path)
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
img_text = img_text.transpose(Image.ROTATE_270)
else:
img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
draw_text = ImageDraw.Draw(img_text)
if txt:
font = create_font(txt, (box_width, box_height), font_path)
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
pts2 = np.array(box, dtype=np.float32)
M = cv2.getPerspectiveTransform(pts1, pts2)
img_text = np.array(img_text, dtype=np.uint8)
img_right_text = cv2.warpPerspective(
img_text,
M,
img_size,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(255, 255, 255),
)
return img_right_text
def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
font_size = int(sz[1] * 0.99)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
if int(PIL.__version__.split(".")[0]) < 10:
length = font.getsize(txt)[0]
else:
length = font.getlength(txt)
if length > sz[0]:
font_size = int(font_size * sz[0] / length)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
return font
def str_count(s):
"""
Count the display width of the string in units of Chinese characters;
a single English letter or digit counts as half a Chinese character.
args:
s(string): the input string
return(int):
the equivalent number of Chinese-width characters
"""
import string
count_zh = count_pu = 0
s_len = len(s)
en_dg_count = 0
for c in s:
if c in string.ascii_letters or c.isdigit() or c.isspace():
en_dg_count += 1
elif c.isalpha():
count_zh += 1
else:
count_pu += 1
return s_len - math.ceil(en_dg_count / 2)
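# Worked example (illustrative): str_count("ab你好") == 4 - ceil(2/2) == 3, i.e. the two
# ASCII characters together occupy one Chinese-width slot when text_visual() decides
# where to wrap a line.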
def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.0, font_path="./doc/simfang.ttf"):
"""
create a new blank img and draw the txts on it
args:
texts(list): the texts to be drawn
scores(list|None): the corresponding score of each txt
img_h(int): the height of the blank img
img_w(int): the width of the blank img
threshold(float): txts with a score below this threshold are skipped
font_path: the path of the font used to draw text
return(array):
the image (or horizontally concatenated images) with the txts drawn on it
"""
if scores is not None:
assert len(texts) == len(scores), "The number of txts and corresponding scores must match"
def create_blank_img():
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
blank_img[:, img_w - 1 :] = 0
blank_img = Image.fromarray(blank_img).convert("RGB")
draw_txt = ImageDraw.Draw(blank_img)
return blank_img, draw_txt
blank_img, draw_txt = create_blank_img()
font_size = 20
txt_color = (0, 0, 0)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
gap = font_size + 5
txt_img_list = []
count, index = 1, 0
for idx, txt in enumerate(texts):
index += 1
if scores[idx] < threshold or math.isnan(scores[idx]):
index -= 1
continue
first_line = True
while str_count(txt) >= img_w // font_size - 4:
tmp = txt
txt = tmp[: img_w // font_size - 4]
if first_line:
new_txt = str(index) + ": " + txt
first_line = False
else:
new_txt = " " + txt
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
txt = tmp[img_w // font_size - 4 :]
if count >= img_h // gap - 1:
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
if first_line:
new_txt = str(index) + ": " + txt + " " + "%.3f" % (scores[idx])
else:
new_txt = " " + txt + " " + "%.3f" % (scores[idx])
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
# whether add new blank img or not
if count >= img_h // gap - 1 and idx + 1 < len(texts):
txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img()
count = 0
count += 1
txt_img_list.append(np.array(blank_img))
if len(txt_img_list) == 1:
blank_img = np.array(txt_img_list[0])
else:
blank_img = np.concatenate(txt_img_list, axis=1)
return np.array(blank_img)
def base64_to_cv2(b64str):
import base64
data = base64.b64decode(b64str.encode("utf8"))
data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
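# Illustrative round-trip sketch (not part of the original file; the file path is
# hypothetical):
#     import base64
#     with open("sample.jpg", "rb") as f:
#         b64 = base64.b64encode(f.read()).decode("utf8")
#     img = base64_to_cv2(b64)  # decoded back into a BGR ndarray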
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
if scores is None:
scores = [1] * len(boxes)
for box, score in zip(boxes, scores):
if score < drop_score:
continue
box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
return image
def get_rotate_crop_image(img, points):
"""
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
"""
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))
)
img_crop_height = int(
max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))
)
pts_std = np.float32(
[
[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height],
]
)
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M,
(img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC,
)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def get_minarea_rect_crop(img, points):
bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_a, index_b, index_c, index_d = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_a = 0
index_d = 1
else:
index_a = 1
index_d = 0
if points[3][1] > points[2][1]:
index_b = 2
index_c = 3
else:
index_b = 3
index_c = 2
box = [points[index_a], points[index_b], points[index_c], points[index_d]]
crop_img = get_rotate_crop_image(img, np.array(box))
return crop_img
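# Illustrative usage (not part of the original file): cropping one detected quad.
# get_rotate_crop_image expects a float32 array of four (x, y) corners ordered
# top-left, top-right, bottom-right, bottom-left; the box values are hypothetical.
#     quad = np.float32([[10, 10], [110, 12], [108, 42], [8, 40]])
#     crop = get_rotate_crop_image(img, quad)  # perspective-rectified text patch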
# def check_gpu(use_gpu):
# if use_gpu and (
# not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"
# ):
# use_gpu = False
# return use_gpu
def _check_image_file(path):
img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
return any([path.lower().endswith(e) for e in img_end])
def get_image_file_list(img_file, infer_list=None):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
logger_initialized = {}
@functools.lru_cache()
def get_logger(name="ppocr", log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger_initialized[name] = True
logger.propagate = False
return logger
def get_rotate_crop_image(img, points):
"""
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
"""
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))
)
img_crop_height = int(
max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))
)
pts_std = np.float32(
[
[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height],
]
)
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M,
(img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC,
)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def get_minarea_rect_crop(img, points):
bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_a, index_b, index_c, index_d = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_a = 0
index_d = 1
else:
index_a = 1
index_d = 0
if points[3][1] > points[2][1]:
index_b = 2
index_c = 3
else:
index_b = 3
index_c = 2
box = [points[index_a], points[index_b], points[index_c], points[index_d]]
crop_img = get_rotate_crop_image(img, np.array(box))
return crop_img
if __name__ == "__main__":
pass

View File

@@ -1,24 +0,0 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE
if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
tokenizer = TexTeller.get_tokenizer()
# Don't forget to configure your dataset path in loader.py
dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']
new_tokenizer = tokenizer.train_new_from_iterator(
text_iterator=dataset['latex_formula'],
# If you want to use a different vocab size, **change VOCAB_SIZE in globals.py**
vocab_size=VOCAB_SIZE,
)
# Save the new tokenizer for later training and inference
new_tokenizer.save_pretrained('./your_dir_name')
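# Illustrative follow-up (assumption, not part of the original script): the saved
# tokenizer can later be reloaded with the standard transformers API, using the same
# directory passed to save_pretrained above:
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained('./your_dir_name')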

View File

@@ -1 +0,0 @@
from .mix_inference import mix_inference

View File

@@ -1,261 +0,0 @@
import re
import heapq
import cv2
import time
import numpy as np
from collections import Counter
from typing import List
from PIL import Image
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all
MAXV = 999999999
def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
mask_img = img.copy()
for bbox in bboxes:
mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
return mask_img
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
if len(sorted_bboxes) == 0:
return []
bboxes = sorted_bboxes.copy()
guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
bboxes.append(guard)
res = []
prev = bboxes[0]
for curr in bboxes:
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
res.append(prev)
prev = curr
else:
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
return res
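# Minimal sketch (illustrative, not part of the original file): two horizontally
# overlapping boxes on the same row collapse into one. Inputs must already be sorted
# top-to-bottom, left-to-right, as produced by sorted() on Bbox instances.
#     row = sorted([Bbox(0, 0, 20, 50, label="text"), Bbox(40, 2, 20, 60, label="text")])
#     merged = bbox_merge(row)  # -> one Bbox at x=0 with w=100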
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
if latex_bboxes == []:
return ocr_bboxes
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
return ocr_bboxes
bboxes = sorted(ocr_bboxes + latex_bboxes)
# log results
for idx, bbox in enumerate(bboxes):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
assert len(bboxes) > 1
heapq.heapify(bboxes)
res = []
candidate = heapq.heappop(bboxes)
curr = heapq.heappop(bboxes)
idx = 0
while len(bboxes) > 0:
idx += 1
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x < curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text" and curr.label == "text":
candidate.w = curr.ur_point.x - candidate.p.x
curr = heapq.heappop(bboxes)
elif candidate.label != curr.label:
if candidate.label == "text":
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
curr.w = curr.ur_point.x - candidate.ur_point.x
curr.p.x = candidate.ur_point.x
heapq.heappush(bboxes, curr)
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x >= curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text":
assert curr.label != "text"
heapq.heappush(
bboxes,
Bbox(
curr.ur_point.x,
candidate.p.y,
candidate.h,
candidate.ur_point.x - curr.ur_point.x,
label="text",
confidence=candidate.confidence,
content=None,
),
)
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
assert curr.label == "text"
curr = heapq.heappop(bboxes)
else:
assert False
res.append(candidate)
res.append(curr)
# log results
for idx, bbox in enumerate(res):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
return res
def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
sliced_imgs = []
for bbox in ocr_bboxes:
x, y = int(bbox.p.x), int(bbox.p.y)
w, h = int(bbox.w), int(bbox.h)
sliced_img = img[y : y + h, x : x + w]
sliced_imgs.append(sliced_img)
return sliced_imgs
def mix_inference(
img_path: str,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
accelerator="cpu",
num_beams=1,
) -> str:
'''
Take a mixed image of formulas and text as input and return a str (in Markdown syntax)
'''
global img
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
end_time = time.time()
print(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
latex_bboxes = bbox_merge(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, rec_model = lang_ocr_models
start_time = time.time()
det_prediction, _ = det_model(masked_img)
end_time = time.time()
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0],
p[0][1],
p[3][1] - p[0][1],
p[1][0] - p[0][0],
label="text",
confidence=None,
content=None,
)
for p in det_prediction
]
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = rec_model(sliced_imgs)
end_time = time.time()
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0]
latex_imgs = []
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
start_time = time.time()
latex_rec_res = latex_rec_predict(
*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800
)
end_time = time.time()
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content)
if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated":
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
# Add the formula number back to the isolated formula
if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
curr.content = curr.content.strip()
if curr.content.startswith('(') and curr.content.endswith(')'):
curr.content = curr.content[1:-1]
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
# in case of multiple tags
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
else:
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
# change split environment into aligned
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
# remove extra spaces (keeping only one)
curr.content = re.sub(r' +', ' ', curr.content)
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
md += curr.content
prev = curr
return md.strip()
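# Illustrative call sketch (not part of the original file). The detection/OCR/LaTeX
# models and infer_config are loaded elsewhere in the repository; the names below are
# hypothetical placeholders for those already-loaded objects:
#     markdown_text = mix_inference(
#         "page.png", infer_config, latex_det_model,
#         (ocr_det_model, ocr_rec_model), latex_rec_models,
#         accelerator="cuda", num_beams=3,
#     )
#     print(markdown_text)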