[refactor] Init
texteller/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .texteller import TexTeller
|
||||
|
||||
__all__ = ['TexTeller']
|
||||
@@ -1,89 +0,0 @@
|
||||
import os
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Point:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = int(x)
|
||||
self.y = int(y)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Point(x={self.x}, y={self.y})"
|
||||
|
||||
|
||||
class Bbox:
|
||||
THRESHOLD = 0.4
|
||||
|
||||
def __init__(self, x, y, h, w, label: str = None, confidence: float = 0, content: str = None):
|
||||
self.p = Point(x, y)
|
||||
self.h = int(h)
|
||||
self.w = int(w)
|
||||
self.label = label
|
||||
self.confidence = confidence
|
||||
self.content = content
|
||||
|
||||
@property
|
||||
def ul_point(self) -> Point:
|
||||
return self.p
|
||||
|
||||
@property
|
||||
def ur_point(self) -> Point:
|
||||
return Point(self.p.x + self.w, self.p.y)
|
||||
|
||||
@property
|
||||
def ll_point(self) -> Point:
|
||||
return Point(self.p.x, self.p.y + self.h)
|
||||
|
||||
@property
|
||||
def lr_point(self) -> Point:
|
||||
return Point(self.p.x + self.w, self.p.y + self.h)
|
||||
|
||||
def same_row(self, other) -> bool:
|
||||
if (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y) or (
|
||||
self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y
|
||||
):
|
||||
return True
|
||||
if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
|
||||
return False
|
||||
return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THRESHOLD
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
'''
|
||||
from top to bottom, from left to right
|
||||
'''
|
||||
if not self.same_row(other):
|
||||
return self.p.y < other.p.y
|
||||
else:
|
||||
return self.p.x < other.p.x
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}, label={self.label}, confidence={self.confidence}, content={self.content})"
|
||||
|
||||
|
||||
def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
|
||||
curr_work_dir = Path(os.getcwd())
|
||||
log_dir = curr_work_dir / "logs"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
drawer = ImageDraw.Draw(img)
|
||||
for bbox in bboxes:
|
||||
# Calculate the coordinates for the rectangle to be drawn
|
||||
left = bbox.p.x
|
||||
top = bbox.p.y
|
||||
right = bbox.p.x + bbox.w
|
||||
bottom = bbox.p.y + bbox.h
|
||||
|
||||
# Draw the rectangle on the image
|
||||
drawer.rectangle([left, top, right, bottom], outline="green", width=1)
|
||||
|
||||
# Optionally, add text label if it exists
|
||||
if bbox.label:
|
||||
drawer.text((left, top), bbox.label, fill="blue")
|
||||
|
||||
if bbox.content:
|
||||
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
|
||||
|
||||
# Save the image with drawn rectangles
|
||||
img.save(log_dir / name)
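A minimal usage sketch of the helpers above; the file name and box values are made up for illustration:

```python
from PIL import Image

# Two hypothetical detections on one page: Bbox(x, y, h, w, label, confidence)
boxes = [
    Bbox(120, 400, 60, 500, label="isolated", confidence=0.88),
    Bbox(100, 40, 30, 200, label="embedding", confidence=0.92),
]
boxes.sort()  # reading order via __lt__/same_row: top to bottom, then left to right

img = Image.open("page.png").convert("RGB")
draw_bboxes(img, boxes, name="page_annotated.png")  # written to ./logs/
```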
|
||||
@@ -1,226 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import yaml
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
from .preprocess import Compose
|
||||
from .Bbox import Bbox
|
||||
|
||||
|
||||
# Global dictionary
|
||||
SUPPORT_MODELS = {
|
||||
'YOLO',
|
||||
'PPYOLOE',
|
||||
'RCNN',
|
||||
'SSD',
|
||||
'Face',
|
||||
'FCOS',
|
||||
'SOLOv2',
|
||||
'TTFNet',
|
||||
'S2ANet',
|
||||
'JDE',
|
||||
'FairMOT',
|
||||
'DeepSORT',
|
||||
'GFL',
|
||||
'PicoDet',
|
||||
'CenterNet',
|
||||
'TOOD',
|
||||
'RetinaNet',
|
||||
'StrongBaseline',
|
||||
'STGCN',
|
||||
'YOLOX',
|
||||
'HRNet',
|
||||
'DETR',
|
||||
}
|
||||
|
||||
|
||||
class PredictConfig(object):
|
||||
"""set config of preprocess, postprocess and visualize
|
||||
Args:
|
||||
infer_config (str): path of infer_cfg.yml
|
||||
"""
|
||||
|
||||
def __init__(self, infer_config):
|
||||
# parsing Yaml config for Preprocess
|
||||
with open(infer_config) as f:
|
||||
yml_conf = yaml.safe_load(f)
|
||||
self.check_model(yml_conf)
|
||||
self.arch = yml_conf['arch']
|
||||
self.preprocess_infos = yml_conf['Preprocess']
|
||||
self.min_subgraph_size = yml_conf['min_subgraph_size']
|
||||
self.label_list = yml_conf['label_list']
|
||||
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
|
||||
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
|
||||
self.mask = yml_conf.get("mask", False)
|
||||
self.tracker = yml_conf.get("tracker", None)
|
||||
self.nms = yml_conf.get("NMS", None)
|
||||
self.fpn_stride = yml_conf.get("fpn_stride", None)
|
||||
|
||||
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
||||
self.colors = {
|
||||
label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
|
||||
}
|
||||
|
||||
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
|
||||
print('The RCNN export model is used for ONNX and it only supports batch_size = 1')
|
||||
self.print_config()
|
||||
|
||||
def check_model(self, yml_conf):
|
||||
"""
|
||||
Raises:
|
||||
ValueError: loaded model not in supported model type
|
||||
"""
|
||||
for support_model in SUPPORT_MODELS:
|
||||
if support_model in yml_conf['arch']:
|
||||
return True
|
||||
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))
|
||||
|
||||
def print_config(self):
|
||||
print('----------- Model Configuration -----------')
|
||||
print('%s: %s' % ('Model Arch', self.arch))
|
||||
print('%s: ' % ('Transform Order'))
|
||||
for op_info in self.preprocess_infos:
|
||||
print('--%s: %s' % ('transform op', op_info['type']))
|
||||
print('--------------------------------------------')
|
||||
|
||||
|
||||
def draw_bbox(image, outputs, infer_config):
|
||||
for output in outputs:
|
||||
cls_id, score, xmin, ymin, xmax, ymax = output
|
||||
if score > infer_config.draw_threshold:
|
||||
label = infer_config.label_list[int(cls_id)]
|
||||
color = infer_config.colors[label]
|
||||
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
|
||||
cv2.putText(
|
||||
image,
|
||||
"{}: {:.2f}".format(label, score),
|
||||
(int(xmin), int(ymin - 5)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
color,
|
||||
2,
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
def predict_image(imgsave_dir, infer_config, predictor, img_list):
|
||||
# load preprocess transforms
|
||||
transforms = Compose(infer_config.preprocess_infos)
|
||||
errImgList = []
|
||||
|
||||
# Check and create subimg_save_dir if not exist
|
||||
subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
|
||||
os.makedirs(subimg_save_dir, exist_ok=True)
|
||||
|
||||
first_image_skipped = False
|
||||
total_time = 0
|
||||
num_images = 0
|
||||
# predict image
|
||||
for img_path in tqdm(img_list):
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
print(f"Warning: Could not read image {img_path}. Skipping...")
|
||||
errImgList.append(img_path)
|
||||
continue
|
||||
|
||||
inputs = transforms(img_path)
|
||||
inputs_name = [var.name for var in predictor.get_inputs()]
|
||||
inputs = {k: inputs[k][None,] for k in inputs_name}
|
||||
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
outputs = predictor.run(output_names=None, input_feed=inputs)
|
||||
|
||||
# Stop timing
|
||||
end_time = time.time()
|
||||
inference_time = end_time - start_time
|
||||
if not first_image_skipped:
|
||||
first_image_skipped = True
|
||||
else:
|
||||
total_time += inference_time
|
||||
num_images += 1
|
||||
print(
|
||||
f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
|
||||
)
|
||||
|
||||
print("ONNXRuntime predict: ")
|
||||
if infer_config.arch in ["HRNet"]:
|
||||
print(np.array(outputs[0]))
|
||||
else:
|
||||
bboxes = np.array(outputs[0])
|
||||
for bbox in bboxes:
|
||||
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
|
||||
print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
|
||||
|
||||
# Save the subimages (crop from the original image)
|
||||
subimg_counter = 1
|
||||
for output in np.array(outputs[0]):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = output
|
||||
if score > infer_config.draw_threshold:
|
||||
label = infer_config.label_list[int(cls_id)]
|
||||
subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
|
||||
if subimg.size == 0:
|
||||
continue
|
||||
|
||||
subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
|
||||
subimg_path = os.path.join(subimg_save_dir, subimg_filename)
|
||||
cv2.imwrite(subimg_path, subimg)
|
||||
subimg_counter += 1
|
||||
|
||||
# Draw bounding boxes and save the image with bounding boxes
|
||||
img_with_mask = img.copy()
|
||||
for output in np.array(outputs[0]):
|
||||
cls_id, score, xmin, ymin, xmax, ymax = output
|
||||
if score > infer_config.draw_threshold:
|
||||
cv2.rectangle(
|
||||
img_with_mask,
|
||||
(int(xmin), int(ymin)),
|
||||
(int(xmax), int(ymax)),
|
||||
(255, 255, 255),
|
||||
-1,
|
||||
)  # cover the detected region with white
|
||||
|
||||
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
|
||||
|
||||
output_dir = imgsave_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
draw_box_dir = os.path.join(output_dir, 'draw_box')
|
||||
mask_white_dir = os.path.join(output_dir, 'mask_white')
|
||||
os.makedirs(draw_box_dir, exist_ok=True)
|
||||
os.makedirs(mask_white_dir, exist_ok=True)
|
||||
|
||||
output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
|
||||
output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
|
||||
cv2.imwrite(output_file_mask, img_with_mask)
|
||||
cv2.imwrite(output_file_bbox, img_with_bbox)
|
||||
|
||||
avg_time_per_image = total_time / num_images if num_images > 0 else 0
|
||||
print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
|
||||
print(f"Average time per image: {avg_time_per_image:.4f} seconds")
|
||||
print("ErrorImgs:")
|
||||
print(errImgList)
|
||||
|
||||
|
||||
def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
|
||||
transforms = Compose(infer_config.preprocess_infos)
|
||||
inputs = transforms(img_path)
|
||||
inputs_name = [var.name for var in predictor.get_inputs()]
|
||||
inputs = {k: inputs[k][None,] for k in inputs_name}
|
||||
|
||||
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
|
||||
res = []
|
||||
for output in outputs:
|
||||
cls_name = infer_config.label_list[int(output[0])]
|
||||
score = output[1]
|
||||
xmin = int(max(output[2], 0))
|
||||
ymin = int(max(output[3], 0))
|
||||
xmax = int(output[4])
|
||||
ymax = int(output[5])
|
||||
if score > infer_config.draw_threshold:
|
||||
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
|
||||
|
||||
return res
|
||||
@@ -1,27 +0,0 @@
|
||||
mode: paddle
|
||||
draw_threshold: 0.5
|
||||
metric: COCO
|
||||
use_dynamic_shape: false
|
||||
arch: DETR
|
||||
min_subgraph_size: 3
|
||||
Preprocess:
|
||||
- interp: 2
|
||||
keep_ratio: false
|
||||
target_size:
|
||||
- 1600
|
||||
- 1600
|
||||
type: Resize
|
||||
- mean:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
norm_type: none
|
||||
std:
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
type: NormalizeImage
|
||||
- type: Permute
|
||||
label_list:
|
||||
- isolated
|
||||
- embedding
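For context, a rough sketch of how a config like the one above is consumed by the detection code in the previous hunk; the ONNX file name is a placeholder, and the predictor is assumed to be an onnxruntime session (the API used by predict above):

```python
import onnxruntime

infer_config = PredictConfig("infer_cfg.yml")  # parses the YAML shown above
predictor = onnxruntime.InferenceSession(
    "detector.onnx", providers=["CPUExecutionProvider"]
)
bboxes = predict("page.png", predictor, infer_config)  # Bbox list above draw_threshold
```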
|
||||
@@ -1,485 +0,0 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import copy
|
||||
|
||||
|
||||
def decode_image(img_path):
|
||||
if isinstance(img_path, str):
|
||||
with open(img_path, 'rb') as f:
|
||||
im_read = f.read()
|
||||
data = np.frombuffer(im_read, dtype='uint8')
|
||||
else:
|
||||
assert isinstance(img_path, np.ndarray)
|
||||
data = img_path
|
||||
|
||||
im = cv2.imdecode(data, 1)  # decoded as BGR; converted to RGB below
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
img_info = {
|
||||
"im_shape": np.array(im.shape[:2], dtype=np.float32),
|
||||
"scale_factor": np.array([1.0, 1.0], dtype=np.float32),
|
||||
}
|
||||
return im, img_info
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""resize image by target_size and max_size
|
||||
Args:
|
||||
target_size (int): the target size of image
|
||||
keep_ratio (bool): whether keep_ratio or not, default true
|
||||
interp (int): method of resize
|
||||
"""
|
||||
|
||||
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
|
||||
if isinstance(target_size, int):
|
||||
target_size = [target_size, target_size]
|
||||
self.target_size = target_size
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interp = interp
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
|
||||
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
|
||||
im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
|
||||
return im, im_info
|
||||
|
||||
def generate_scale(self, im):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
Returns:
|
||||
im_scale_x: the resize ratio of X
|
||||
im_scale_y: the resize ratio of Y
|
||||
"""
|
||||
origin_shape = im.shape[:2]
|
||||
im_c = im.shape[2]
|
||||
if self.keep_ratio:
|
||||
im_size_min = np.min(origin_shape)
|
||||
im_size_max = np.max(origin_shape)
|
||||
target_size_min = np.min(self.target_size)
|
||||
target_size_max = np.max(self.target_size)
|
||||
im_scale = float(target_size_min) / float(im_size_min)
|
||||
if np.round(im_scale * im_size_max) > target_size_max:
|
||||
im_scale = float(target_size_max) / float(im_size_max)
|
||||
im_scale_x = im_scale
|
||||
im_scale_y = im_scale
|
||||
else:
|
||||
resize_h, resize_w = self.target_size
|
||||
im_scale_y = resize_h / float(origin_shape[0])
|
||||
im_scale_x = resize_w / float(origin_shape[1])
|
||||
return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""normalize image
|
||||
Args:
|
||||
mean (list): im - mean
|
||||
std (list): im / std
|
||||
is_scale (bool): whether need im / 255
|
||||
norm_type (str): type in ['mean_std', 'none']
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.is_scale = is_scale
|
||||
self.norm_type = norm_type
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.astype(np.float32, copy=False)
|
||||
if self.is_scale:
|
||||
scale = 1.0 / 255.0
|
||||
im *= scale
|
||||
|
||||
if self.norm_type == 'mean_std':
|
||||
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
|
||||
std = np.array(self.std)[np.newaxis, np.newaxis, :]
|
||||
im -= mean
|
||||
im /= std
|
||||
return im, im_info
|
||||
|
||||
|
||||
class Permute(object):
|
||||
"""permute image
|
||||
Args:
|
||||
to_bgr (bool): whether convert RGB to BGR
|
||||
channel_first (bool): whether convert HWC to CHW
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super(Permute, self).__init__()
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.transpose((2, 0, 1)).copy()
|
||||
return im, im_info
|
||||
|
||||
|
||||
class PadStride(object):
|
||||
"""Pad the image for models with FPN, replacing PadBatch(pad_to_stride) from the original config
|
||||
Args:
|
||||
stride (int): models with FPN require image shape % stride == 0
|
||||
"""
|
||||
|
||||
def __init__(self, stride=0):
|
||||
self.coarsest_stride = stride
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
coarsest_stride = self.coarsest_stride
|
||||
if coarsest_stride <= 0:
|
||||
return im, im_info
|
||||
im_c, im_h, im_w = im.shape
|
||||
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
|
||||
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
|
||||
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = im
|
||||
return padding_im, im_info
|
||||
|
||||
|
||||
class LetterBoxResize(object):
|
||||
def __init__(self, target_size):
|
||||
"""
|
||||
Resize image to target size, convert normalized xywh to pixel xyxy
|
||||
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
|
||||
Args:
|
||||
target_size (int|list): image target size.
|
||||
"""
|
||||
super(LetterBoxResize, self).__init__()
|
||||
if isinstance(target_size, int):
|
||||
target_size = [target_size, target_size]
|
||||
self.target_size = target_size
|
||||
|
||||
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
|
||||
# letterbox: resize a rectangular image to a padded rectangular
|
||||
shape = img.shape[:2] # [height, width]
|
||||
ratio_h = float(height) / shape[0]
|
||||
ratio_w = float(width) / shape[1]
|
||||
ratio = min(ratio_h, ratio_w)
|
||||
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # [width, height]
|
||||
padw = (width - new_shape[0]) / 2
|
||||
padh = (height - new_shape[1]) / 2
|
||||
top, bottom = round(padh - 0.1), round(padh + 0.1)
|
||||
left, right = round(padw - 0.1), round(padw + 0.1)
|
||||
|
||||
img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
|
||||
img = cv2.copyMakeBorder(
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
||||
) # padded rectangular
|
||||
return img, ratio, padw, padh
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
height, width = self.target_size
|
||||
h, w = im.shape[:2]
|
||||
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
|
||||
|
||||
new_shape = [round(h * ratio), round(w * ratio)]
|
||||
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
|
||||
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
|
||||
return im, im_info
|
||||
|
||||
|
||||
class Pad(object):
|
||||
def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
|
||||
"""
|
||||
Pad image to a specified size.
|
||||
Args:
|
||||
size (list[int]): image target size
|
||||
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
|
||||
"""
|
||||
super(Pad, self).__init__()
|
||||
if isinstance(size, int):
|
||||
size = [size, size]
|
||||
self.size = size
|
||||
self.fill_value = fill_value
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
im_h, im_w = im.shape[:2]
|
||||
h, w = self.size
|
||||
if h == im_h and w == im_w:
|
||||
im = im.astype(np.float32)
|
||||
return im, im_info
|
||||
|
||||
canvas = np.ones((h, w, 3), dtype=np.float32)
|
||||
canvas *= np.array(self.fill_value, dtype=np.float32)
|
||||
canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
|
||||
im = canvas
|
||||
return im, im_info
|
||||
|
||||
|
||||
def rotate_point(pt, angle_rad):
|
||||
"""Rotate a point by an angle.
|
||||
|
||||
Args:
|
||||
pt (list[float]): 2 dimensional point to be rotated
|
||||
angle_rad (float): rotation angle by radian
|
||||
|
||||
Returns:
|
||||
list[float]: Rotated point.
|
||||
"""
|
||||
assert len(pt) == 2
|
||||
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
||||
new_x = pt[0] * cs - pt[1] * sn
|
||||
new_y = pt[0] * sn + pt[1] * cs
|
||||
rotated_pt = [new_x, new_y]
|
||||
|
||||
return rotated_pt
|
||||
|
||||
|
||||
def _get_3rd_point(a, b):
|
||||
"""To calculate the affine matrix, three pairs of points are required. This
|
||||
function is used to get the 3rd point, given 2D points a & b.
|
||||
|
||||
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
||||
anticlockwise, using b as the rotation center.
|
||||
|
||||
Args:
|
||||
a (np.ndarray): point(x,y)
|
||||
b (np.ndarray): point(x,y)
|
||||
|
||||
Returns:
|
||||
np.ndarray: The 3rd point.
|
||||
"""
|
||||
assert len(a) == 2
|
||||
assert len(b) == 2
|
||||
direction = a - b
|
||||
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
|
||||
|
||||
return third_pt
|
||||
|
||||
|
||||
def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
|
||||
"""Get the affine transform matrix, given the center/scale/rot/output_size.
|
||||
|
||||
Args:
|
||||
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
||||
scale (np.ndarray[2, ]): Scale of the bounding box
|
||||
wrt [width, height].
|
||||
rot (float): Rotation angle (degree).
|
||||
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
|
||||
shift (0-100%): Shift translation ratio wrt the width/height.
|
||||
Default (0., 0.).
|
||||
inv (bool): Option to inverse the affine transform direction.
|
||||
(inv=False: src->dst or inv=True: dst->src)
|
||||
|
||||
Returns:
|
||||
np.ndarray: The transform matrix.
|
||||
"""
|
||||
assert len(center) == 2
|
||||
assert len(output_size) == 2
|
||||
assert len(shift) == 2
|
||||
if not isinstance(input_size, (np.ndarray, list)):
|
||||
input_size = np.array([input_size, input_size], dtype=np.float32)
|
||||
scale_tmp = input_size
|
||||
|
||||
shift = np.array(shift)
|
||||
src_w = scale_tmp[0]
|
||||
dst_w = output_size[0]
|
||||
dst_h = output_size[1]
|
||||
|
||||
rot_rad = np.pi * rot / 180
|
||||
src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
|
||||
dst_dir = np.array([0.0, dst_w * -0.5])
|
||||
|
||||
src = np.zeros((3, 2), dtype=np.float32)
|
||||
src[0, :] = center + scale_tmp * shift
|
||||
src[1, :] = center + src_dir + scale_tmp * shift
|
||||
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
||||
|
||||
dst = np.zeros((3, 2), dtype=np.float32)
|
||||
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
||||
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
||||
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
||||
|
||||
if inv:
|
||||
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
||||
else:
|
||||
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
||||
|
||||
return trans
|
||||
|
||||
|
||||
class WarpAffine(object):
|
||||
"""Warp affine the image"""
|
||||
|
||||
def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
|
||||
self.keep_res = keep_res
|
||||
self.pad = pad
|
||||
self.input_h = input_h
|
||||
self.input_w = input_w
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
||||
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if self.keep_res:
|
||||
input_h = (h | self.pad) + 1
|
||||
input_w = (w | self.pad) + 1
|
||||
s = np.array([input_w, input_h], dtype=np.float32)
|
||||
c = np.array([w // 2, h // 2], dtype=np.float32)
|
||||
|
||||
else:
|
||||
s = max(h, w) * 1.0
|
||||
input_h, input_w = self.input_h, self.input_w
|
||||
c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
|
||||
|
||||
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
|
||||
img = cv2.resize(img, (w, h))
|
||||
inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
|
||||
return inp, im_info
|
||||
|
||||
|
||||
# keypoint preprocess
|
||||
def get_warp_matrix(theta, size_input, size_dst, size_target):
|
||||
"""This code is based on
|
||||
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
|
||||
|
||||
Calculate the transformation matrix under the constraint of unbiased.
|
||||
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
|
||||
Data Processing for Human Pose Estimation (CVPR 2020).
|
||||
|
||||
Args:
|
||||
theta (float): Rotation angle in degrees.
|
||||
size_input (np.ndarray): Size of input image [w, h].
|
||||
size_dst (np.ndarray): Size of output image [w, h].
|
||||
size_target (np.ndarray): Size of ROI in input plane [w, h].
|
||||
|
||||
Returns:
|
||||
matrix (np.ndarray): A matrix for transformation.
|
||||
"""
|
||||
theta = np.deg2rad(theta)
|
||||
matrix = np.zeros((2, 3), dtype=np.float32)
|
||||
scale_x = size_dst[0] / size_target[0]
|
||||
scale_y = size_dst[1] / size_target[1]
|
||||
matrix[0, 0] = np.cos(theta) * scale_x
|
||||
matrix[0, 1] = -np.sin(theta) * scale_x
|
||||
matrix[0, 2] = scale_x * (
|
||||
-0.5 * size_input[0] * np.cos(theta)
|
||||
+ 0.5 * size_input[1] * np.sin(theta)
|
||||
+ 0.5 * size_target[0]
|
||||
)
|
||||
matrix[1, 0] = np.sin(theta) * scale_y
|
||||
matrix[1, 1] = np.cos(theta) * scale_y
|
||||
matrix[1, 2] = scale_y * (
|
||||
-0.5 * size_input[0] * np.sin(theta)
|
||||
- 0.5 * size_input[1] * np.cos(theta)
|
||||
+ 0.5 * size_target[1]
|
||||
)
|
||||
return matrix
|
||||
|
||||
|
||||
class TopDownEvalAffine(object):
|
||||
"""apply affine transform to image and coords
|
||||
|
||||
Args:
|
||||
trainsize (list): [w, h], the standard size used to train
|
||||
use_udp (bool): whether to use Unbiased Data Processing.
|
||||
records (dict): the dict containing the image and coords
|
||||
|
||||
Returns:
|
||||
records (dict): the image and coords after the transform is applied
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, trainsize, use_udp=False):
|
||||
self.trainsize = trainsize
|
||||
self.use_udp = use_udp
|
||||
|
||||
def __call__(self, image, im_info):
|
||||
rot = 0
|
||||
imshape = im_info['im_shape'][::-1]
|
||||
center = im_info['center'] if 'center' in im_info else imshape / 2.0
|
||||
scale = im_info['scale'] if 'scale' in im_info else imshape
|
||||
if self.use_udp:
|
||||
trans = get_warp_matrix(
|
||||
rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
|
||||
)
|
||||
image = cv2.warpAffine(
|
||||
image,
|
||||
trans,
|
||||
(int(self.trainsize[0]), int(self.trainsize[1])),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
)
|
||||
else:
|
||||
trans = get_affine_transform(center, scale, rot, self.trainsize)
|
||||
image = cv2.warpAffine(
|
||||
image,
|
||||
trans,
|
||||
(int(self.trainsize[0]), int(self.trainsize[1])),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
return image, im_info
|
||||
|
||||
|
||||
class Compose:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = []
|
||||
for op_info in transforms:
|
||||
new_op_info = op_info.copy()
|
||||
op_type = new_op_info.pop('type')
|
||||
self.transforms.append(eval(op_type)(**new_op_info))
|
||||
|
||||
def __call__(self, img_path):
|
||||
img, im_info = decode_image(img_path)
|
||||
for t in self.transforms:
|
||||
img, im_info = t(img, im_info)
|
||||
inputs = copy.deepcopy(im_info)
|
||||
inputs['image'] = img
|
||||
return inputs
|
||||
@@ -1,23 +0,0 @@
|
||||
# Formula image(grayscale) mean and variance
|
||||
IMAGE_MEAN = 0.9545467
|
||||
IMAGE_STD = 0.15394445
|
||||
|
||||
# Vocabulary size for TexTeller
|
||||
VOCAB_SIZE = 15000
|
||||
|
||||
# Fixed size for input image for TexTeller
|
||||
FIXED_IMG_SIZE = 448
|
||||
|
||||
# Image channel for TexTeller
|
||||
IMG_CHANNELS = 1 # grayscale image
|
||||
|
||||
# Maximum token sequence length (decoder position embeddings and generation limit)
|
||||
MAX_TOKEN_SIZE = 1024
|
||||
|
||||
# Scaling ratio for random resizing when training
|
||||
MAX_RESIZE_RATIO = 1.15
|
||||
MIN_RESIZE_RATIO = 0.75
|
||||
|
||||
# Minimum height and width for input image for TexTeller
|
||||
MIN_HEIGHT = 12
|
||||
MIN_WIDTH = 30
|
||||
@@ -1,43 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
|
||||
|
||||
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
|
||||
|
||||
|
||||
class TexTeller(VisionEncoderDecoderModel):
|
||||
REPO_NAME = 'OleehyO/TexTeller'
|
||||
|
||||
def __init__(self):
|
||||
config = VisionEncoderDecoderConfig.from_pretrained(
|
||||
Path(__file__).resolve().parent / "config.json"
|
||||
)
|
||||
config.encoder.image_size = FIXED_IMG_SIZE
|
||||
config.encoder.num_channels = IMG_CHANNELS
|
||||
config.decoder.vocab_size = VOCAB_SIZE
|
||||
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
|
||||
|
||||
super().__init__(config=config)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
|
||||
if model_path is None or model_path == 'default':
|
||||
if not use_onnx:
|
||||
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
|
||||
else:
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
|
||||
use_gpu = onnx_provider == 'cuda'
|
||||
return ORTModelForVision2Seq.from_pretrained(
|
||||
cls.REPO_NAME,
|
||||
provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
|
||||
)
|
||||
model_path = Path(model_path).resolve()
|
||||
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
|
||||
if tokenizer_path is None or tokenizer_path == 'default':
|
||||
return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
|
||||
tokenizer_path = Path(tokenizer_path).resolve()
|
||||
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
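A brief sketch of loading the published weights and tokenizer through the class above; the ONNX branch assumes the optimum package is installed:

```python
# Default: PyTorch weights pulled from the OleehyO/TexTeller Hub repo
model = TexTeller.from_pretrained()    # VisionEncoderDecoderModel
tokenizer = TexTeller.get_tokenizer()  # RobertaTokenizerFast with the 15000-token vocab

# Alternatively, the ONNX export running on GPU through onnxruntime
onnx_model = TexTeller.from_pretrained(use_onnx=True, onnx_provider='cuda')
```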
|
||||
@@ -1,168 +0,0 @@
|
||||
{
|
||||
"_name_or_path": "OleehyO/TexTeller",
|
||||
"architectures": [
|
||||
"VisionEncoderDecoderModel"
|
||||
],
|
||||
"decoder": {
|
||||
"_name_or_path": "",
|
||||
"activation_dropout": 0.0,
|
||||
"activation_function": "gelu",
|
||||
"add_cross_attention": true,
|
||||
"architectures": null,
|
||||
"attention_dropout": 0.0,
|
||||
"bad_words_ids": null,
|
||||
"begin_suppress_tokens": null,
|
||||
"bos_token_id": 0,
|
||||
"chunk_size_feed_forward": 0,
|
||||
"classifier_dropout": 0.0,
|
||||
"cross_attention_hidden_size": 768,
|
||||
"d_model": 1024,
|
||||
"decoder_attention_heads": 16,
|
||||
"decoder_ffn_dim": 4096,
|
||||
"decoder_layerdrop": 0.0,
|
||||
"decoder_layers": 12,
|
||||
"decoder_start_token_id": 2,
|
||||
"diversity_penalty": 0.0,
|
||||
"do_sample": false,
|
||||
"dropout": 0.1,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"eos_token_id": 2,
|
||||
"exponential_decay_length_penalty": null,
|
||||
"finetuning_task": null,
|
||||
"forced_bos_token_id": null,
|
||||
"forced_eos_token_id": null,
|
||||
"id2label": {
|
||||
"0": "LABEL_0",
|
||||
"1": "LABEL_1"
|
||||
},
|
||||
"init_std": 0.02,
|
||||
"is_decoder": true,
|
||||
"is_encoder_decoder": false,
|
||||
"label2id": {
|
||||
"LABEL_0": 0,
|
||||
"LABEL_1": 1
|
||||
},
|
||||
"layernorm_embedding": true,
|
||||
"length_penalty": 1.0,
|
||||
"max_length": 20,
|
||||
"max_position_embeddings": 1024,
|
||||
"min_length": 0,
|
||||
"model_type": "trocr",
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"pad_token_id": 1,
|
||||
"prefix": null,
|
||||
"problem_type": null,
|
||||
"pruned_heads": {},
|
||||
"remove_invalid_values": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"return_dict": true,
|
||||
"return_dict_in_generate": false,
|
||||
"scale_embedding": false,
|
||||
"sep_token_id": null,
|
||||
"suppress_tokens": null,
|
||||
"task_specific_params": null,
|
||||
"temperature": 1.0,
|
||||
"tf_legacy_loss": false,
|
||||
"tie_encoder_decoder": false,
|
||||
"tie_word_embeddings": true,
|
||||
"tokenizer_class": null,
|
||||
"top_k": 50,
|
||||
"top_p": 1.0,
|
||||
"torch_dtype": null,
|
||||
"torchscript": false,
|
||||
"typical_p": 1.0,
|
||||
"use_bfloat16": false,
|
||||
"use_cache": false,
|
||||
"use_learned_position_embeddings": true,
|
||||
"vocab_size": 15000
|
||||
},
|
||||
"encoder": {
|
||||
"_name_or_path": "",
|
||||
"add_cross_attention": false,
|
||||
"architectures": null,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"bad_words_ids": null,
|
||||
"begin_suppress_tokens": null,
|
||||
"bos_token_id": null,
|
||||
"chunk_size_feed_forward": 0,
|
||||
"cross_attention_hidden_size": null,
|
||||
"decoder_start_token_id": null,
|
||||
"diversity_penalty": 0.0,
|
||||
"do_sample": false,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_stride": 16,
|
||||
"eos_token_id": null,
|
||||
"exponential_decay_length_penalty": null,
|
||||
"finetuning_task": null,
|
||||
"forced_bos_token_id": null,
|
||||
"forced_eos_token_id": null,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_size": 768,
|
||||
"id2label": {
|
||||
"0": "LABEL_0",
|
||||
"1": "LABEL_1"
|
||||
},
|
||||
"image_size": 448,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"is_decoder": false,
|
||||
"is_encoder_decoder": false,
|
||||
"label2id": {
|
||||
"LABEL_0": 0,
|
||||
"LABEL_1": 1
|
||||
},
|
||||
"layer_norm_eps": 1e-12,
|
||||
"length_penalty": 1.0,
|
||||
"max_length": 20,
|
||||
"min_length": 0,
|
||||
"model_type": "vit",
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_attention_heads": 12,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_channels": 1,
|
||||
"num_hidden_layers": 12,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"pad_token_id": null,
|
||||
"patch_size": 16,
|
||||
"prefix": null,
|
||||
"problem_type": null,
|
||||
"pruned_heads": {},
|
||||
"qkv_bias": false,
|
||||
"remove_invalid_values": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"return_dict": true,
|
||||
"return_dict_in_generate": false,
|
||||
"sep_token_id": null,
|
||||
"suppress_tokens": null,
|
||||
"task_specific_params": null,
|
||||
"temperature": 1.0,
|
||||
"tf_legacy_loss": false,
|
||||
"tie_encoder_decoder": false,
|
||||
"tie_word_embeddings": true,
|
||||
"tokenizer_class": null,
|
||||
"top_k": 50,
|
||||
"top_p": 1.0,
|
||||
"torch_dtype": null,
|
||||
"torchscript": false,
|
||||
"typical_p": 1.0,
|
||||
"use_bfloat16": false
|
||||
},
|
||||
"is_encoder_decoder": true,
|
||||
"model_type": "vision-encoder-decoder",
|
||||
"tie_word_embeddings": false,
|
||||
"transformers_version": "4.41.2",
|
||||
"use_cache": true
|
||||
}
|
||||
(35 deleted image files; binary contents not shown in the diff)
@@ -1,35 +0,0 @@
|
||||
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
|
||||
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
|
||||
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
|
||||
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
|
||||
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
|
||||
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
|
||||
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
|
||||
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
|
||||
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
|
||||
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
|
||||
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
|
||||
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
|
||||
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
|
||||
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
|
||||
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
|
||||
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
|
||||
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
|
||||
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
|
||||
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
|
||||
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
|
||||
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
|
||||
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
|
||||
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
|
||||
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
|
||||
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
|
||||
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
|
||||
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
|
||||
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
|
||||
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
|
||||
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
|
||||
@@ -1,114 +0,0 @@
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
||||
from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import (
|
||||
tokenize_fn,
|
||||
collate_fn,
|
||||
img_train_transform,
|
||||
img_inf_transform,
|
||||
filter_fn,
|
||||
)
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||
training_args = TrainingArguments(**CONFIG)
|
||||
trainer = Trainer(
|
||||
model,
|
||||
training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collate_fn_with_tokenizer,
|
||||
)
|
||||
|
||||
trainer.train(resume_from_checkpoint=None)
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
eval_config = CONFIG.copy()
|
||||
eval_config['predict_with_generate'] = True
|
||||
generate_config = GenerationConfig(
|
||||
max_new_tokens=MAX_TOKEN_SIZE,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
)
|
||||
eval_config['generation_config'] = generate_config
|
||||
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model,
|
||||
seq2seq_config,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collate_fn,
|
||||
compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
|
||||
)
|
||||
|
||||
eval_res = trainer.evaluate()
|
||||
print(eval_res)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
script_dirpath = Path(__file__).resolve().parent
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
# dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
|
||||
dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
|
||||
dataset = dataset.filter(
|
||||
lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
|
||||
)
|
||||
dataset = dataset.shuffle(seed=42)
|
||||
dataset = dataset.flatten_indices()
|
||||
|
||||
tokenizer = TexTeller.get_tokenizer()
|
||||
# If you want to use your own tokenizer, modify the path below to point to it
|
||||
# tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
|
||||
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
|
||||
dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)
|
||||
|
||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||
tokenized_dataset = dataset.map(
|
||||
map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
|
||||
)
|
||||
|
||||
# Split dataset into train and eval, ratio 9:1
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
train_dataset = train_dataset.with_transform(img_train_transform)
|
||||
eval_dataset = eval_dataset.with_transform(img_inf_transform)
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
|
||||
# Train from scratch
|
||||
model = TexTeller()
|
||||
# or train from TexTeller pre-trained model: model = TexTeller.from_pretrained()
|
||||
|
||||
# If you want to train from a pre-trained model, modify the path to your pre-trained checkpoint
|
||||
# e.g.
|
||||
# model = TexTeller.from_pretrained(
|
||||
#     '/path/to/your/model_checkpoint'
|
||||
# )
|
||||
|
||||
enable_train = True
|
||||
enable_evaluate = False
|
||||
if enable_train:
|
||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||
if enable_evaluate and len(eval_dataset) > 0:
|
||||
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
||||
@@ -1,31 +0,0 @@
|
||||
CONFIG = {
|
||||
"seed": 42, # Random seed for reproducibility
|
||||
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
|
||||
"learning_rate": 5e-5, # Learning rate
|
||||
"num_train_epochs": 10, # Total number of training epochs
|
||||
"per_device_train_batch_size": 4, # Batch size per GPU for training
|
||||
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
|
||||
"output_dir": "train_result", # Output directory
|
||||
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
|
||||
"report_to": ["tensorboard"], # Report logs to TensorBoard
|
||||
"save_strategy": "steps", # Strategy to save checkpoints
|
||||
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
|
||||
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
|
||||
"logging_strategy": "steps", # Log every certain number of steps
|
||||
"logging_steps": 500, # Number of steps between each log
|
||||
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
|
||||
"optim": "adamw_torch", # Optimizer
|
||||
"lr_scheduler_type": "cosine", # Learning rate scheduler
|
||||
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
|
||||
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
|
||||
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
|
||||
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
|
||||
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
|
||||
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
|
||||
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
|
||||
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
|
||||
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
|
||||
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
|
||||
"eval_steps": 500, # If evaluation_strategy="step"
|
||||
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
from typing import List, Dict, Any
|
||||
from .transforms import train_transform, inference_transform
|
||||
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def left_move(x: torch.Tensor, pad_val):
|
||||
assert len(x.shape) == 2, 'x should be 2-dimensional'
|
||||
lefted_x = torch.ones_like(x)
|
||||
lefted_x[:, :-1] = x[:, 1:]
|
||||
lefted_x[:, -1] = pad_val
|
||||
return lefted_x
|
||||
|
||||
|
||||
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
|
||||
assert tokenizer is not None, 'tokenizer should not be None'
|
||||
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
|
||||
tokenized_formula['pixel_values'] = samples['image']
|
||||
return tokenized_formula
|
||||
|
||||
|
||||
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
|
||||
assert tokenizer is not None, 'tokenizer should not be None'
|
||||
pixel_values = [dic.pop('pixel_values') for dic in samples]
|
||||
|
||||
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
batch = clm_collator(samples)
|
||||
batch['pixel_values'] = pixel_values
|
||||
batch['decoder_input_ids'] = batch.pop('input_ids')
|
||||
batch['decoder_attention_mask'] = batch.pop('attention_mask')
|
||||
|
||||
# Shift labels one position to the left (next-token prediction targets)
|
||||
batch['labels'] = left_move(batch['labels'], -100)
|
||||
|
||||
# Stack the list of images into a single tensor of shape (B, C, H, W)
|
||||
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
|
||||
return batch
|
||||
|
||||
|
||||
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
processed_img = train_transform(samples['pixel_values'])
|
||||
samples['pixel_values'] = processed_img
|
||||
return samples
|
||||
|
||||
|
||||
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
processed_img = inference_transform(samples['pixel_values'])
|
||||
samples['pixel_values'] = processed_img
|
||||
return samples
|
||||
|
||||
|
||||
def filter_fn(sample, tokenizer=None) -> bool:
|
||||
return (
|
||||
sample['image'].height > MIN_HEIGHT
|
||||
and sample['image'].width > MIN_WIDTH
|
||||
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
|
||||
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
|
||||
processed_images = []
|
||||
for path in image_paths:
|
||||
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
print(f"Image at {path} could not be read.")
|
||||
continue
|
||||
if image.dtype == np.uint16:
|
||||
print(f'Converting {path} to 8-bit; the conversion may be lossy.')
|
||||
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
|
||||
|
||||
channels = 1 if len(image.shape) == 2 else image.shape[2]
|
||||
if channels == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||
elif channels == 1:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
elif channels == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
processed_images.append(image)
|
||||
|
||||
return processed_images
|
||||
@@ -1,116 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria
|
||||
from typing import List, Union
|
||||
|
||||
from .transforms import inference_transform
|
||||
from .helpers import convert2rgb
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ...globals import MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
|
||||
"""
|
||||
Stops generation efficiently if any n-gram repeats.
|
||||
|
||||
This criteria maintains a set of encountered n-grams.
|
||||
At each step, it checks if the *latest* n-gram is already in the set.
|
||||
If yes, it stops generation. If no, it adds the n-gram to the set.
|
||||
"""
|
||||
|
||||
def __init__(self, n: int):
|
||||
"""
|
||||
Args:
|
||||
n (int): The size of the n-gram to check for repetition.
|
||||
"""
|
||||
if n <= 0:
|
||||
raise ValueError("n-gram size 'n' must be positive.")
|
||||
self.n = n
|
||||
# Stores tuples of token IDs representing seen n-grams
|
||||
self.seen_ngrams = set()
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
|
||||
Prediction scores.
|
||||
|
||||
Return:
|
||||
`bool`: `True` if generation should stop, `False` otherwise.
|
||||
"""
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
# Need at least n tokens to form the first n-gram
|
||||
if seq_length < self.n:
|
||||
return False
|
||||
|
||||
# --- Efficient Check ---
|
||||
# Consider only the first sequence in the batch for simplicity
|
||||
if batch_size > 1:
|
||||
# If handling batch_size > 1, you'd need a list of sets, one per batch item.
|
||||
# Or decide on a stopping policy (e.g., stop if *any* sequence repeats).
|
||||
# For now, we'll focus on the first sequence.
|
||||
pass # No warning needed every step, maybe once in __init__ if needed.
|
||||
|
||||
sequence = input_ids[0] # Get the first sequence
|
||||
|
||||
# Get the latest n-gram (the one ending at the last token)
|
||||
last_ngram_tensor = sequence[-self.n :]
|
||||
# Convert to a hashable tuple for set storage and lookup
|
||||
last_ngram_tuple = tuple(last_ngram_tensor.tolist())
|
||||
|
||||
# Check if this n-gram has been seen before *at any prior step*
|
||||
if last_ngram_tuple in self.seen_ngrams:
|
||||
return True # Stop generation
|
||||
else:
|
||||
# It's a new n-gram, add it to the set and continue
|
||||
self.seen_ngrams.add(last_ngram_tuple)
|
||||
return False # Continue generation
|
||||
|
||||
|
||||
def inference(
|
||||
model: TexTeller,
|
||||
tokenizer: RobertaTokenizerFast,
|
||||
imgs: Union[List[str], List[np.ndarray]],
|
||||
accelerator: str = 'cpu',
|
||||
num_beams: int = 1,
|
||||
max_tokens=None,
|
||||
) -> List[str]:
|
||||
if imgs == []:
|
||||
return []
|
||||
if hasattr(model, 'eval'):
|
||||
# not an ONNX session, so switch the model to eval mode
|
||||
model.eval()
|
||||
if isinstance(imgs[0], str):
|
||||
imgs = convert2rgb(imgs)
|
||||
else:  # already a numpy array (RGB format)
|
||||
assert isinstance(imgs[0], np.ndarray)
|
||||
imgs = inference_transform(imgs)
|
||||
pixel_values = torch.stack(imgs)
|
||||
|
||||
if hasattr(model, 'eval'):
|
||||
# not an ONNX session, so move the weights to the target device
|
||||
model = model.to(accelerator)
|
||||
pixel_values = pixel_values.to(accelerator)
|
||||
|
||||
generate_config = GenerationConfig(
|
||||
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
|
||||
num_beams=num_beams,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
# no_repeat_ngram_size=10,
|
||||
)
|
||||
pred = model.generate(
|
||||
pixel_values.to(model.device),
|
||||
generation_config=generate_config,
|
||||
# stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)],
|
||||
)
|
||||
|
||||
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
|
||||
return res
|
||||
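# Usage sketch (illustrative): an end-to-end call of `inference`; the default-checkpoint
# loading calls and the image path below are assumptions.
def _example_inference(image_path: str = "formula.png") -> str:
    model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    return inference(model, tokenizer, [image_path], accelerator="cpu", num_beams=1)[0]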
@@ -1,698 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Python implementation of tex-fmt, a LaTeX formatter.
|
||||
Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt
|
||||
"""
|
||||
|
||||
import re
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Dict, Set
|
||||
|
||||
# Constants
|
||||
LINE_END = "\n"
|
||||
ITEM = "\\item"
|
||||
DOC_BEGIN = "\\begin{document}"
|
||||
DOC_END = "\\end{document}"
|
||||
ENV_BEGIN = "\\begin{"
|
||||
ENV_END = "\\end{"
|
||||
TEXT_LINE_START = ""
|
||||
COMMENT_LINE_START = "% "
|
||||
|
||||
# Opening and closing delimiters
|
||||
OPENS = ['{', '(', '[']
|
||||
CLOSES = ['}', ')', ']']
|
||||
|
||||
# Names of LaTeX verbatim environments
|
||||
VERBATIMS = ["verbatim", "Verbatim", "lstlisting", "minted", "comment"]
|
||||
VERBATIMS_BEGIN = [f"\\begin{{{v}}}" for v in VERBATIMS]
|
||||
VERBATIMS_END = [f"\\end{{{v}}}" for v in VERBATIMS]
|
||||
|
||||
# Regex patterns for sectioning commands
|
||||
SPLITTING = [
|
||||
r"\\begin\{",
|
||||
r"\\end\{",
|
||||
r"\\item(?:$|[^a-zA-Z])",
|
||||
r"\\(?:sub){0,2}section\*?\{",
|
||||
r"\\chapter\*?\{",
|
||||
r"\\part\*?\{",
|
||||
]
|
||||
|
||||
# Compiled regexes
|
||||
SPLITTING_STRING = f"({'|'.join(SPLITTING)})"
|
||||
RE_NEWLINES = re.compile(f"{LINE_END}{LINE_END}({LINE_END})+")
|
||||
RE_TRAIL = re.compile(f" +{LINE_END}")
|
||||
RE_SPLITTING = re.compile(SPLITTING_STRING)
|
||||
RE_SPLITTING_SHARED_LINE = re.compile(f"(?:\\S.*?)(?:{SPLITTING_STRING}.*)")
|
||||
RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P<prev>\\S.*?)(?P<env>{SPLITTING_STRING}.*)")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
"""Command line arguments and configuration."""
|
||||
|
||||
tabchar: str = " "
|
||||
tabsize: int = 4
|
||||
wrap: bool = False
|
||||
wraplen: int = 80
|
||||
wrapmin: int = 40
|
||||
lists: List[str] = None
|
||||
verbosity: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lists is None:
|
||||
self.lists = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ignore:
|
||||
"""Information on the ignored state of a line."""
|
||||
|
||||
actual: bool = False
|
||||
visual: bool = False
|
||||
|
||||
@classmethod
|
||||
def new(cls):
|
||||
return cls(False, False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Verbatim:
|
||||
"""Information on the verbatim state of a line."""
|
||||
|
||||
actual: int = 0
|
||||
visual: bool = False
|
||||
|
||||
@classmethod
|
||||
def new(cls):
|
||||
return cls(0, False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Indent:
|
||||
"""Information on the indentation state of a line."""
|
||||
|
||||
actual: int = 0
|
||||
visual: int = 0
|
||||
|
||||
@classmethod
|
||||
def new(cls):
|
||||
return cls(0, 0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Information on the current state during formatting."""
|
||||
|
||||
linum_old: int = 1
|
||||
linum_new: int = 1
|
||||
ignore: Ignore = None
|
||||
indent: Indent = None
|
||||
verbatim: Verbatim = None
|
||||
linum_last_zero_indent: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ignore is None:
|
||||
self.ignore = Ignore.new()
|
||||
if self.indent is None:
|
||||
self.indent = Indent.new()
|
||||
if self.verbatim is None:
|
||||
self.verbatim = Verbatim.new()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pattern:
|
||||
"""Record whether a line contains certain patterns."""
|
||||
|
||||
contains_env_begin: bool = False
|
||||
contains_env_end: bool = False
|
||||
contains_item: bool = False
|
||||
contains_splitting: bool = False
|
||||
contains_comment: bool = False
|
||||
|
||||
@classmethod
|
||||
def new(cls, s: str):
|
||||
"""Check if a string contains patterns."""
|
||||
if RE_SPLITTING.search(s):
|
||||
return cls(
|
||||
contains_env_begin=ENV_BEGIN in s,
|
||||
contains_env_end=ENV_END in s,
|
||||
contains_item=ITEM in s,
|
||||
contains_splitting=True,
|
||||
contains_comment='%' in s,
|
||||
)
|
||||
else:
|
||||
return cls(
|
||||
contains_env_begin=False,
|
||||
contains_env_end=False,
|
||||
contains_item=False,
|
||||
contains_splitting=False,
|
||||
contains_comment='%' in s,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Log:
|
||||
"""Log message."""
|
||||
|
||||
level: str
|
||||
file: str
|
||||
message: str
|
||||
linum_new: Optional[int] = None
|
||||
linum_old: Optional[int] = None
|
||||
line: Optional[str] = None
|
||||
|
||||
|
||||
def find_comment_index(line: str, pattern: Pattern) -> Optional[int]:
|
||||
"""Find the index of a comment in a line."""
|
||||
if not pattern.contains_comment:
|
||||
return None
|
||||
|
||||
in_command = False
|
||||
for i, c in enumerate(line):
|
||||
if c == '\\':
|
||||
in_command = True
|
||||
elif in_command and not c.isalpha():
|
||||
in_command = False
|
||||
elif c == '%' and not in_command:
|
||||
return i
|
||||
|
||||
return None
|
||||
|
||||
|
||||
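# Behaviour sketch (illustrative): a '%' that is escaped or terminates a LaTeX command
# is not treated as a comment.
#
#   find_comment_index(r"a \% b % c", Pattern.new(r"a \% b % c"))  # -> 7 (the bare '%')
#   find_comment_index(r"50\% off", Pattern.new(r"50\% off"))      # -> None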
def contains_ignore_skip(line: str) -> bool:
|
||||
"""Check if a line contains a skip directive."""
|
||||
return line.endswith("% tex-fmt: skip")
|
||||
|
||||
|
||||
def contains_ignore_begin(line: str) -> bool:
|
||||
"""Check if a line contains the start of an ignore block."""
|
||||
return line.endswith("% tex-fmt: off")
|
||||
|
||||
|
||||
def contains_ignore_end(line: str) -> bool:
|
||||
"""Check if a line contains the end of an ignore block."""
|
||||
return line.endswith("% tex-fmt: on")
|
||||
|
||||
|
||||
def get_ignore(line: str, state: State, logs: List[Log], file: str, warn: bool) -> Ignore:
|
||||
"""Determine whether a line should be ignored."""
|
||||
skip = contains_ignore_skip(line)
|
||||
begin = contains_ignore_begin(line)
|
||||
end = contains_ignore_end(line)
|
||||
|
||||
if skip:
|
||||
actual = state.ignore.actual
|
||||
visual = True
|
||||
elif begin:
|
||||
actual = True
|
||||
visual = True
|
||||
if warn and state.ignore.actual:
|
||||
logs.append(
|
||||
Log(
|
||||
level="WARN",
|
||||
file=file,
|
||||
message="Cannot begin ignore block:",
|
||||
linum_new=state.linum_new,
|
||||
linum_old=state.linum_old,
|
||||
line=line,
|
||||
)
|
||||
)
|
||||
elif end:
|
||||
actual = False
|
||||
visual = True
|
||||
if warn and not state.ignore.actual:
|
||||
logs.append(
|
||||
Log(
|
||||
level="WARN",
|
||||
file=file,
|
||||
message="No ignore block to end.",
|
||||
linum_new=state.linum_new,
|
||||
linum_old=state.linum_old,
|
||||
line=line,
|
||||
)
|
||||
)
|
||||
else:
|
||||
actual = state.ignore.actual
|
||||
visual = state.ignore.actual
|
||||
|
||||
return Ignore(actual=actual, visual=visual)
|
||||
|
||||
|
||||
def get_verbatim_diff(line: str, pattern: Pattern) -> int:
|
||||
"""Calculate total verbatim depth change."""
|
||||
if pattern.contains_env_begin and any(r in line for r in VERBATIMS_BEGIN):
|
||||
return 1
|
||||
elif pattern.contains_env_end and any(r in line for r in VERBATIMS_END):
|
||||
return -1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_verbatim(
|
||||
line: str, state: State, logs: List[Log], file: str, warn: bool, pattern: Pattern
|
||||
) -> Verbatim:
|
||||
"""Determine whether a line is in a verbatim environment."""
|
||||
diff = get_verbatim_diff(line, pattern)
|
||||
actual = state.verbatim.actual + diff
|
||||
visual = actual > 0 or state.verbatim.actual > 0
|
||||
|
||||
if warn and actual < 0:
|
||||
logs.append(
|
||||
Log(
|
||||
level="WARN",
|
||||
file=file,
|
||||
message="Verbatim count is negative.",
|
||||
linum_new=state.linum_new,
|
||||
linum_old=state.linum_old,
|
||||
line=line,
|
||||
)
|
||||
)
|
||||
|
||||
return Verbatim(actual=actual, visual=visual)
|
||||
|
||||
|
||||
def get_diff(line: str, pattern: Pattern, lists_begin: List[str], lists_end: List[str]) -> int:
|
||||
"""Calculate total indentation change due to the current line."""
|
||||
diff = 0
|
||||
|
||||
# Other environments get single indents
|
||||
if pattern.contains_env_begin and ENV_BEGIN in line:
|
||||
# Documents get no global indentation
|
||||
if DOC_BEGIN in line:
|
||||
return 0
|
||||
diff += 1
|
||||
diff += 1 if any(r in line for r in lists_begin) else 0
|
||||
elif pattern.contains_env_end and ENV_END in line:
|
||||
# Documents get no global indentation
|
||||
if DOC_END in line:
|
||||
return 0
|
||||
diff -= 1
|
||||
diff -= 1 if any(r in line for r in lists_end) else 0
|
||||
|
||||
# Indent for delimiters
|
||||
for c in line:
|
||||
if c in OPENS:
|
||||
diff += 1
|
||||
elif c in CLOSES:
|
||||
diff -= 1
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def get_back(line: str, pattern: Pattern, state: State, lists_end: List[str]) -> int:
|
||||
"""Calculate dedentation for the current line."""
|
||||
# Only need to dedent if indentation is present
|
||||
if state.indent.actual == 0:
|
||||
return 0
|
||||
|
||||
if pattern.contains_env_end and ENV_END in line:
|
||||
# Documents get no global indentation
|
||||
if DOC_END in line:
|
||||
return 0
|
||||
# List environments get double indents for indenting items
|
||||
for r in lists_end:
|
||||
if r in line:
|
||||
return 2
|
||||
return 1
|
||||
|
||||
# Items get dedented
|
||||
if pattern.contains_item and ITEM in line:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def get_indent(
|
||||
line: str,
|
||||
prev_indent: Indent,
|
||||
pattern: Pattern,
|
||||
state: State,
|
||||
lists_begin: List[str],
|
||||
lists_end: List[str],
|
||||
) -> Indent:
|
||||
"""Calculate the indent for a line."""
|
||||
diff = get_diff(line, pattern, lists_begin, lists_end)
|
||||
back = get_back(line, pattern, state, lists_end)
|
||||
|
||||
actual = prev_indent.actual + diff
|
||||
visual = max(0, prev_indent.actual - back)
|
||||
|
||||
return Indent(actual=actual, visual=visual)
|
||||
|
||||
|
||||
def calculate_indent(
|
||||
line: str,
|
||||
state: State,
|
||||
logs: List[Log],
|
||||
file: str,
|
||||
args: Args,
|
||||
pattern: Pattern,
|
||||
lists_begin: List[str],
|
||||
lists_end: List[str],
|
||||
) -> Indent:
|
||||
"""Calculate the indent for a line and update the state."""
|
||||
indent = get_indent(line, state.indent, pattern, state, lists_begin, lists_end)
|
||||
|
||||
# Update the state
|
||||
state.indent = indent
|
||||
|
||||
# Record the last line with zero indent
|
||||
if indent.visual == 0:
|
||||
state.linum_last_zero_indent = state.linum_new
|
||||
|
||||
return indent
|
||||
|
||||
|
||||
def apply_indent(line: str, indent: Indent, args: Args, indent_char: str) -> str:
|
||||
"""Apply indentation to a line."""
|
||||
if not line.strip():
|
||||
return ""
|
||||
|
||||
indent_str = indent_char * (indent.visual * args.tabsize)
|
||||
return indent_str + line.lstrip()
|
||||
|
||||
|
||||
def needs_wrap(line: str, indent_length: int, args: Args) -> bool:
|
||||
"""Check if a line needs wrapping."""
|
||||
return args.wrap and (len(line) + indent_length > args.wraplen)
|
||||
|
||||
|
||||
def find_wrap_point(line: str, indent_length: int, args: Args) -> Optional[int]:
|
||||
"""Find the best place to break a long line."""
|
||||
wrap_point = None
|
||||
after_char = False
|
||||
prev_char = None
|
||||
|
||||
line_width = 0
|
||||
wrap_boundary = args.wrapmin - indent_length
|
||||
|
||||
for i, c in enumerate(line):
|
||||
line_width += 1
|
||||
if line_width > wrap_boundary and wrap_point is not None:
|
||||
break
|
||||
if c == ' ' and prev_char != '\\':
|
||||
if after_char:
|
||||
wrap_point = i
|
||||
elif c != '%':
|
||||
after_char = True
|
||||
prev_char = c
|
||||
|
||||
return wrap_point
|
||||
|
||||
|
||||
def apply_wrap(
|
||||
line: str,
|
||||
indent_length: int,
|
||||
state: State,
|
||||
file: str,
|
||||
args: Args,
|
||||
logs: List[Log],
|
||||
pattern: Pattern,
|
||||
) -> Optional[List[str]]:
|
||||
"""Wrap a long line into a short prefix and a suffix."""
|
||||
if args.verbosity >= 3: # Trace level
|
||||
logs.append(
|
||||
Log(
|
||||
level="TRACE",
|
||||
file=file,
|
||||
message="Wrapping long line.",
|
||||
linum_new=state.linum_new,
|
||||
linum_old=state.linum_old,
|
||||
line=line,
|
||||
)
|
||||
)
|
||||
|
||||
wrap_point = find_wrap_point(line, indent_length, args)
|
||||
comment_index = find_comment_index(line, pattern)
|
||||
|
||||
if wrap_point is None or wrap_point > args.wraplen:
|
||||
logs.append(
|
||||
Log(
|
||||
level="WARN",
|
||||
file=file,
|
||||
message="Line cannot be wrapped.",
|
||||
linum_new=state.linum_new,
|
||||
linum_old=state.linum_old,
|
||||
line=line,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
this_line = line[:wrap_point]
|
||||
|
||||
if comment_index is not None and wrap_point > comment_index:
|
||||
next_line_start = COMMENT_LINE_START
|
||||
else:
|
||||
next_line_start = TEXT_LINE_START
|
||||
|
||||
next_line = line[wrap_point + 1 :]
|
||||
|
||||
return [this_line, next_line_start, next_line]
|
||||
|
||||
|
||||
def needs_split(line: str, pattern: Pattern) -> bool:
|
||||
"""Check if line contains content which should be split onto a new line."""
|
||||
# Check if we should format this line and if we've matched an environment
|
||||
contains_splittable_env = (
|
||||
pattern.contains_splitting and RE_SPLITTING_SHARED_LINE.search(line) is not None
|
||||
)
|
||||
|
||||
# If we're not ignoring and we've matched an environment...
|
||||
if contains_splittable_env:
|
||||
# Return True if the comment index is None (which implies the split point must be in text),
|
||||
# otherwise compare the index of the comment with the split point
|
||||
comment_index = find_comment_index(line, pattern)
|
||||
if comment_index is None:
|
||||
return True
|
||||
|
||||
match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
|
||||
if match and match.start(2) > comment_index:
|
||||
# If split point is past the comment index, don't split
|
||||
return False
|
||||
else:
|
||||
# Otherwise, split point is before comment and we do split
|
||||
return True
|
||||
else:
|
||||
# If ignoring or didn't match an environment, don't need a new line
|
||||
return False
|
||||
|
||||
|
||||
def split_line(line: str, state: State, file: str, args: Args, logs: List[Log]) -> Tuple[str, str]:
|
||||
"""Ensure lines are split correctly."""
|
||||
match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
|
||||
if not match:
|
||||
return line, ""
|
||||
|
||||
prev = match.group('prev')
|
||||
rest = match.group('env')
|
||||
|
||||
if args.verbosity >= 3: # Trace level
|
||||
logs.append(
|
||||
Log(
|
||||
level="TRACE",
|
||||
file=file,
|
||||
message="Placing environment on new line.",
|
||||
linum_new=state.linum_new,
|
||||
linum_old=state.linum_old,
|
||||
line=line,
|
||||
)
|
||||
)
|
||||
|
||||
return prev, rest
|
||||
|
||||
|
||||
def set_ignore_and_report(
|
||||
line: str, temp_state: State, logs: List[Log], file: str, pattern: Pattern
|
||||
) -> bool:
|
||||
"""Sets the ignore and verbatim flags in the given State based on line and returns whether line should be ignored."""
|
||||
temp_state.ignore = get_ignore(line, temp_state, logs, file, True)
|
||||
temp_state.verbatim = get_verbatim(line, temp_state, logs, file, True, pattern)
|
||||
|
||||
return temp_state.verbatim.visual or temp_state.ignore.visual
|
||||
|
||||
|
||||
def clean_text(text: str, args: Args) -> str:
|
||||
"""Cleans the given text by removing extra line breaks and trailing spaces."""
|
||||
# Remove extra newlines
|
||||
text = RE_NEWLINES.sub(f"{LINE_END}{LINE_END}", text)
|
||||
|
||||
# Remove tabs if they shouldn't be used
|
||||
if args.tabchar != '\t':
|
||||
text = text.replace('\t', ' ' * args.tabsize)
|
||||
|
||||
# Remove trailing spaces
|
||||
text = RE_TRAIL.sub(LINE_END, text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def remove_trailing_spaces(text: str) -> str:
|
||||
"""Remove trailing spaces from line endings."""
|
||||
return RE_TRAIL.sub(LINE_END, text)
|
||||
|
||||
|
||||
def remove_trailing_blank_lines(text: str) -> str:
|
||||
"""Remove trailing blank lines from file."""
|
||||
return text.rstrip() + LINE_END
|
||||
|
||||
|
||||
def indents_return_to_zero(state: State) -> bool:
|
||||
"""Check if indentation returns to zero at the end of the file."""
|
||||
return state.indent.actual == 0
|
||||
|
||||
|
||||
def format_latex(
|
||||
old_text: str, file: str = "input.tex", args: Optional[Args] = None
|
||||
) -> Tuple[str, List[Log]]:
|
||||
"""Central function to format a LaTeX string."""
|
||||
if args is None:
|
||||
args = Args()
|
||||
|
||||
logs = []
|
||||
logs.append(Log(level="INFO", file=file, message="Formatting started."))
|
||||
|
||||
# Clean the source file
|
||||
old_text = clean_text(old_text, args)
|
||||
old_lines = list(enumerate(old_text.splitlines(), 1))
|
||||
|
||||
# Initialize
|
||||
state = State()
|
||||
queue = []
|
||||
new_text = ""
|
||||
|
||||
# Select the character used for indentation
|
||||
indent_char = '\t' if args.tabchar == '\t' else ' '
|
||||
|
||||
# Get any extra environments to be indented as lists
|
||||
lists_begin = [f"\\begin{{{l}}}" for l in args.lists]
|
||||
lists_end = [f"\\end{{{l}}}" for l in args.lists]
|
||||
|
||||
while True:
|
||||
if queue:
|
||||
linum_old, line = queue.pop(0)
|
||||
|
||||
# Read the patterns present on this line
|
||||
pattern = Pattern.new(line)
|
||||
|
||||
# Temporary state for working on this line
|
||||
temp_state = State(
|
||||
linum_old=linum_old,
|
||||
linum_new=state.linum_new,
|
||||
ignore=Ignore(state.ignore.actual, state.ignore.visual),
|
||||
indent=Indent(state.indent.actual, state.indent.visual),
|
||||
verbatim=Verbatim(state.verbatim.actual, state.verbatim.visual),
|
||||
linum_last_zero_indent=state.linum_last_zero_indent,
|
||||
)
|
||||
|
||||
# If the line should not be ignored...
|
||||
if not set_ignore_and_report(line, temp_state, logs, file, pattern):
|
||||
# Check if the line should be split because of a pattern that should begin on a new line
|
||||
if needs_split(line, pattern):
|
||||
# Split the line into two...
|
||||
this_line, next_line = split_line(line, temp_state, file, args, logs)
|
||||
# ...and queue the second part for formatting
|
||||
if next_line:
|
||||
queue.insert(0, (linum_old, next_line))
|
||||
line = this_line
|
||||
|
||||
# Calculate the indent based on the current state and the patterns in the line
|
||||
indent = calculate_indent(
|
||||
line, temp_state, logs, file, args, pattern, lists_begin, lists_end
|
||||
)
|
||||
|
||||
indent_length = indent.visual * args.tabsize
|
||||
|
||||
# Wrap the line before applying the indent, and loop back if the line needed wrapping
|
||||
if needs_wrap(line.lstrip(), indent_length, args):
|
||||
wrapped_lines = apply_wrap(
|
||||
line.lstrip(), indent_length, temp_state, file, args, logs, pattern
|
||||
)
|
||||
if wrapped_lines:
|
||||
this_line, next_line_start, next_line = wrapped_lines
|
||||
queue.insert(0, (linum_old, next_line_start + next_line))
|
||||
queue.insert(0, (linum_old, this_line))
|
||||
continue
|
||||
|
||||
# Lastly, apply the indent if the line didn't need wrapping
|
||||
line = apply_indent(line, indent, args, indent_char)
|
||||
|
||||
# Add line to new text
|
||||
state = temp_state
|
||||
new_text += line + LINE_END
|
||||
state.linum_new += 1
|
||||
elif old_lines:
|
||||
linum_old, line = old_lines.pop(0)
|
||||
queue.append((linum_old, line))
|
||||
else:
|
||||
break
|
||||
|
||||
if not indents_return_to_zero(state):
|
||||
msg = f"Indent does not return to zero. Last non-indented line is line {state.linum_last_zero_indent}"
|
||||
logs.append(Log(level="WARN", file=file, message=msg))
|
||||
|
||||
new_text = remove_trailing_spaces(new_text)
|
||||
new_text = remove_trailing_blank_lines(new_text)
|
||||
logs.append(Log(level="INFO", file=file, message="Formatting complete."))
|
||||
|
||||
return new_text, logs
|
||||
|
||||
|
||||
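# Usage sketch (illustrative; the input string and Args values are arbitrary examples):
def _example_format() -> str:
    src = "\\begin{itemize}\n\\item one\n\\item two\n\\end{itemize}\n"
    formatted, _logs = format_latex(src, file="example.tex", args=Args(tabsize=2))
    return formatted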
def main():
|
||||
"""Command-line entry point."""
|
||||
parser = argparse.ArgumentParser(description="Format LaTeX files")
|
||||
parser.add_argument("file", help="LaTeX file to format")
|
||||
parser.add_argument(
|
||||
"--tabchar",
|
||||
choices=["space", "tab"],
|
||||
default="space",
|
||||
help="Character to use for indentation",
|
||||
)
|
||||
parser.add_argument("--tabsize", type=int, default=4, help="Number of spaces per indent level")
|
||||
parser.add_argument("--wrap", action="store_true", help="Enable line wrapping")
|
||||
parser.add_argument("--wraplen", type=int, default=80, help="Maximum line length")
|
||||
parser.add_argument(
|
||||
"--wrapmin", type=int, default=40, help="Minimum line length before wrapping"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lists", nargs="+", default=[], help="Additional environments to indent as lists"
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", action="count", default=0, help="Increase verbosity")
|
||||
parser.add_argument("--output", "-o", help="Output file (default: overwrite input)")
|
||||
|
||||
args_parsed = parser.parse_args()
|
||||
|
||||
# Convert command line args to our Args class
|
||||
args = Args(
|
||||
tabchar="\t" if args_parsed.tabchar == "tab" else " ",
|
||||
tabsize=args_parsed.tabsize,
|
||||
wrap=args_parsed.wrap,
|
||||
wraplen=args_parsed.wraplen,
|
||||
wrapmin=args_parsed.wrapmin,
|
||||
lists=args_parsed.lists,
|
||||
verbosity=args_parsed.verbose,
|
||||
)
|
||||
|
||||
# Read input file
|
||||
with open(args_parsed.file, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
# Format the text
|
||||
formatted_text, logs = format_latex(text, args_parsed.file, args)
|
||||
|
||||
# Print logs if verbose
|
||||
if args.verbosity > 0:
|
||||
for log in logs:
|
||||
if log.linum_new is not None:
|
||||
print(f"{log.level} {log.file}:{log.linum_new}:{log.linum_old}: {log.message}")
|
||||
else:
|
||||
print(f"{log.level} {log.file}: {log.message}")
|
||||
|
||||
# Write output
|
||||
output_file = args_parsed.output or args_parsed.file
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(formatted_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,25 +0,0 @@
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from transformers import EvalPrediction, RobertaTokenizer
|
||||
|
||||
|
||||
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
|
||||
cur_dir = Path(os.getcwd())
|
||||
os.chdir(Path(__file__).resolve().parent)
|
||||
metric = evaluate.load(
|
||||
'google_bleu'
|
||||
) # Will download the metric from huggingface if not already downloaded
|
||||
os.chdir(cur_dir)
|
||||
|
||||
logits, labels = eval_preds.predictions, eval_preds.label_ids
|
||||
preds = logits
|
||||
|
||||
labels = np.where(labels == -100, 1, labels)
|
||||
|
||||
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
return metric.compute(predictions=preds, references=labels)
|
||||
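# Usage sketch (illustrative; the token ids are dummy values and `tokenizer` is assumed
# to be the TexTeller Roberta tokenizer loaded elsewhere):
def _example_bleu(tokenizer: RobertaTokenizer) -> Dict:
    preds = np.array([[2, 15, 27, 3]])     # decoded (argmax) token ids
    labels = np.array([[2, 15, 27, 3]])    # -100 entries would be remapped to id 1
    return bleu_metric(EvalPrediction(predictions=preds, label_ids=labels), tokenizer)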
@@ -1,152 +0,0 @@
|
||||
from augraphy import *
|
||||
import random
|
||||
|
||||
|
||||
def ocr_augmentation_pipeline():
|
||||
pre_phase = []
|
||||
|
||||
ink_phase = [
|
||||
InkColorSwap(
|
||||
ink_swap_color="random",
|
||||
ink_swap_sequence_number_range=(5, 10),
|
||||
ink_swap_min_width_range=(2, 3),
|
||||
ink_swap_max_width_range=(100, 120),
|
||||
ink_swap_min_height_range=(2, 3),
|
||||
ink_swap_max_height_range=(100, 120),
|
||||
ink_swap_min_area_range=(10, 20),
|
||||
ink_swap_max_area_range=(400, 500),
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
LinesDegradation(
|
||||
line_roi=(0.0, 0.0, 1.0, 1.0),
|
||||
line_gradient_range=(32, 255),
|
||||
line_gradient_direction=(0, 2),
|
||||
line_split_probability=(0.2, 0.4),
|
||||
line_replacement_value=(250, 255),
|
||||
line_min_length=(30, 40),
|
||||
line_long_to_short_ratio=(5, 7),
|
||||
line_replacement_probability=(0.4, 0.5),
|
||||
line_replacement_thickness=(1, 3),
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
# ============================
|
||||
OneOf(
|
||||
[
|
||||
Dithering(
|
||||
dither="floyd-steinberg",
|
||||
order=(3, 5),
|
||||
),
|
||||
InkBleed(
|
||||
intensity_range=(0.1, 0.2),
|
||||
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
|
||||
severity=(0.4, 0.6),
|
||||
),
|
||||
],
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
# ============================
|
||||
# ============================
|
||||
InkShifter(
|
||||
text_shift_scale_range=(18, 27),
|
||||
text_shift_factor_range=(1, 4),
|
||||
text_fade_range=(0, 2),
|
||||
blur_kernel_size=(5, 5),
|
||||
blur_sigma=0,
|
||||
noise_type="perlin",
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
# ============================
|
||||
]
|
||||
|
||||
paper_phase = [
|
||||
NoiseTexturize( # tested
|
||||
sigma_range=(3, 10),
|
||||
turbulence_range=(2, 5),
|
||||
texture_width_range=(300, 500),
|
||||
texture_height_range=(300, 500),
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
BrightnessTexturize( # tested
|
||||
texturize_range=(0.9, 0.99),
|
||||
deviation=0.03,
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
]
|
||||
|
||||
post_phase = [
|
||||
ColorShift( # tested
|
||||
color_shift_offset_x_range=(3, 5),
|
||||
color_shift_offset_y_range=(3, 5),
|
||||
color_shift_iterations=(2, 3),
|
||||
color_shift_brightness_range=(0.9, 1.1),
|
||||
color_shift_gaussian_kernel_range=(3, 3),
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
DirtyDrum( # tested
|
||||
line_width_range=(1, 6),
|
||||
line_concentration=random.uniform(0.05, 0.15),
|
||||
direction=random.randint(0, 2),
|
||||
noise_intensity=random.uniform(0.6, 0.95),
|
||||
noise_value=(64, 224),
|
||||
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
|
||||
sigmaX=0,
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
# =====================================
|
||||
OneOf(
|
||||
[
|
||||
LightingGradient(
|
||||
light_position=None,
|
||||
direction=None,
|
||||
max_brightness=255,
|
||||
min_brightness=0,
|
||||
mode="gaussian",
|
||||
linear_decay_rate=None,
|
||||
transparency=None,
|
||||
),
|
||||
Brightness(
|
||||
brightness_range=(0.9, 1.1),
|
||||
min_brightness=0,
|
||||
min_brightness_value=(120, 150),
|
||||
),
|
||||
Gamma(
|
||||
gamma_range=(0.9, 1.1),
|
||||
),
|
||||
],
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
# =====================================
|
||||
# =====================================
|
||||
OneOf(
|
||||
[
|
||||
SubtleNoise(
|
||||
subtle_range=random.randint(5, 10),
|
||||
),
|
||||
Jpeg(
|
||||
quality_range=(70, 95),
|
||||
),
|
||||
],
|
||||
# p=0.2
|
||||
p=0.4,
|
||||
),
|
||||
# =====================================
|
||||
]
|
||||
|
||||
pipeline = AugraphyPipeline(
|
||||
ink_phase=ink_phase,
|
||||
paper_phase=paper_phase,
|
||||
post_phase=post_phase,
|
||||
pre_phase=pre_phase,
|
||||
log=False,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
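# Usage sketch (illustrative; 'sample.png' is an assumed local image file):
def _example_augment(image_path: str = "sample.png"):
    import cv2

    pipeline = ocr_augmentation_pipeline()
    image = cv2.imread(image_path)   # BGR uint8 array of shape (H, W, 3)
    return pipeline(image)           # augmented image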
@@ -1,194 +0,0 @@
|
||||
import re
|
||||
|
||||
from .latex_formatter import format_latex
|
||||
|
||||
|
||||
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
|
||||
result = ""
|
||||
i = 0
|
||||
n = len(input_str)
|
||||
|
||||
while i < n:
|
||||
if input_str[i : i + len(old_inst)] == old_inst:
|
||||
# check if the old_inst is followed by old_surr_l
|
||||
start = i + len(old_inst)
|
||||
else:
|
||||
result += input_str[i]
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if start < n and input_str[start] == old_surr_l:
|
||||
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
|
||||
count = 1
|
||||
j = start + 1
|
||||
escaped = False
|
||||
while j < n and count > 0:
|
||||
if input_str[j] == '\\' and not escaped:
|
||||
escaped = True
|
||||
j += 1
|
||||
continue
|
||||
if input_str[j] == old_surr_r and not escaped:
|
||||
count -= 1
|
||||
if count == 0:
|
||||
break
|
||||
elif input_str[j] == old_surr_l and not escaped:
|
||||
count += 1
|
||||
escaped = False
|
||||
j += 1
|
||||
|
||||
if count == 0:
|
||||
assert j < n
|
||||
assert input_str[start] == old_surr_l
|
||||
assert input_str[j] == old_surr_r
|
||||
inner_content = input_str[start + 1 : j]
|
||||
# Replace the content with new pattern
|
||||
result += new_inst + new_surr_l + inner_content + new_surr_r
|
||||
i = j + 1
|
||||
continue
|
||||
else:
|
||||
assert count >= 1
|
||||
assert j == n
|
||||
print("Warning: unbalanced surrogate pair in input string")
|
||||
result += new_inst + new_surr_l
|
||||
i = start + 1
|
||||
continue
|
||||
else:
|
||||
result += input_str[i:start]
|
||||
i = start
|
||||
|
||||
if old_inst != new_inst and (old_inst + old_surr_l) in result:
|
||||
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
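# Behaviour sketch (illustrative): `change` rewrites one instruction and its surrounding
# delimiters, e.g. converting \mathbf{...} to \bm{...}:
#
#   change(r"\mathbf{x} + y", r"\mathbf", r"\bm", "{", "}", "{", "}")  # -> r"\bm{x} + y"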
def find_substring_positions(string, substring):
|
||||
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
|
||||
return positions
|
||||
|
||||
|
||||
def rm_dollar_surr(content):
|
||||
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
|
||||
matches = pattern.findall(content)
|
||||
|
||||
for match in matches:
|
||||
if not re.match(r'\\[a-zA-Z]+', match):
|
||||
new_match = match.strip('$')
|
||||
content = content.replace(match, ' ' + new_match + ' ')
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
|
||||
pos = find_substring_positions(input_str, old_inst + old_surr_l)
|
||||
res = list(input_str)
|
||||
for p in pos[::-1]:
|
||||
res[p:] = list(
|
||||
change(
|
||||
''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
|
||||
)
|
||||
)
|
||||
res = ''.join(res)
|
||||
return res
|
||||
|
||||
|
||||
def to_katex(formula: str) -> str:
|
||||
res = formula
|
||||
# remove mbox surrounding
|
||||
res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
|
||||
res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
|
||||
# remove hbox surrounding
|
||||
res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
|
||||
res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
|
||||
# remove raise surrounding
|
||||
res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
|
||||
# remove makebox
|
||||
res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
|
||||
res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
|
||||
# remove vbox surrounding, scalebox surrounding
|
||||
res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
|
||||
res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
|
||||
res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
|
||||
res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
|
||||
res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')
|
||||
|
||||
origin_instructions = [
|
||||
r'\Huge',
|
||||
r'\huge',
|
||||
r'\LARGE',
|
||||
r'\Large',
|
||||
r'\large',
|
||||
r'\normalsize',
|
||||
r'\small',
|
||||
r'\footnotesize',
|
||||
r'\tiny',
|
||||
]
|
||||
for old_ins, new_ins in zip(origin_instructions, origin_instructions):
|
||||
res = change_all(res, old_ins, new_ins, r'$', r'$', '{', '}')
|
||||
res = change_all(res, r'\mathbf', r'\bm', r'{', r'}', r'{', r'}')
|
||||
res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
|
||||
res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
|
||||
res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
|
||||
res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
|
||||
res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
|
||||
res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
|
||||
res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')
|
||||
|
||||
# remove bold command
|
||||
res = change_all(res, r'\bm', r' ', r'{', r'}', r'', r'')
|
||||
|
||||
origin_instructions = [
|
||||
r'\left',
|
||||
r'\middle',
|
||||
r'\right',
|
||||
r'\big',
|
||||
r'\Big',
|
||||
r'\bigg',
|
||||
r'\Bigg',
|
||||
r'\bigl',
|
||||
r'\Bigl',
|
||||
r'\biggl',
|
||||
r'\Biggl',
|
||||
r'\bigm',
|
||||
r'\Bigm',
|
||||
r'\biggm',
|
||||
r'\Biggm',
|
||||
r'\bigr',
|
||||
r'\Bigr',
|
||||
r'\biggr',
|
||||
r'\Biggr',
|
||||
]
|
||||
for origin_ins in origin_instructions:
|
||||
res = change_all(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
|
||||
|
||||
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
|
||||
|
||||
if res.endswith(r'\newline'):
|
||||
res = res[:-8]
|
||||
|
||||
# remove multiple spaces
|
||||
res = re.sub(r'(\\,){1,}', ' ', res)
|
||||
res = re.sub(r'(\\!){1,}', ' ', res)
|
||||
res = re.sub(r'(\\;){1,}', ' ', res)
|
||||
res = re.sub(r'(\\:){1,}', ' ', res)
|
||||
res = re.sub(r'\\vspace\{.*?}', '', res)
|
||||
|
||||
# merge consecutive text
|
||||
def merge_texts(match):
|
||||
texts = match.group(0)
|
||||
merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
|
||||
return f'\\text{{{merged_content}}}'
|
||||
|
||||
res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)
|
||||
|
||||
res = res.replace(r'\bf ', '')
|
||||
res = rm_dollar_surr(res)
|
||||
|
||||
# remove extra spaces (keeping only one)
|
||||
res = re.sub(r' +', ' ', res)
|
||||
|
||||
# format latex
|
||||
res = res.strip()
|
||||
res, logs = format_latex(res)
|
||||
|
||||
return res
|
||||
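# Usage sketch (illustrative): `to_katex` normalizes raw model output into KaTeX-friendly
# LaTeX, stripping wrappers such as \mbox and bold commands before formatting.
#
#   to_katex(r"\mathbf{A} x = \mbox{rhs}")
#   # -> roughly "A x = rhs" (bold and \mbox wrappers removed, spacing normalized)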
@@ -1,177 +0,0 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
from collections import Counter
|
||||
|
||||
from ...globals import (
|
||||
IMG_CHANNELS,
|
||||
FIXED_IMG_SIZE,
|
||||
IMAGE_MEAN,
|
||||
IMAGE_STD,
|
||||
MAX_RESIZE_RATIO,
|
||||
MIN_RESIZE_RATIO,
|
||||
)
|
||||
from .ocr_aug import ocr_augmentation_pipeline
|
||||
|
||||
# train_pipeline = default_augraphy_pipeline(scan_only=True)
|
||||
train_pipeline = ocr_augmentation_pipeline()
|
||||
|
||||
general_transform_pipeline = v2.Compose(
|
||||
[
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.uint8, scale=True),  # optional, most inputs are already uint8 at this point
|
||||
v2.Grayscale(),
|
||||
v2.Resize(
|
||||
size=FIXED_IMG_SIZE - 1,
|
||||
interpolation=v2.InterpolationMode.BICUBIC,
|
||||
max_size=FIXED_IMG_SIZE,
|
||||
antialias=True,
|
||||
),
|
||||
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
|
||||
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
|
||||
# v2.ToPILImage()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def trim_white_border(image: np.ndarray):
|
||||
if len(image.shape) != 3 or image.shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
if image.dtype != np.uint8:
|
||||
raise ValueError(f"Image should stored in uint8")
|
||||
|
||||
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
|
||||
bg_color = Counter(corners).most_common(1)[0][0]
|
||||
bg_color_np = np.array(bg_color, dtype=np.uint8)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
|
||||
|
||||
diff = cv2.absdiff(image, bg)
|
||||
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
threshold = 15
|
||||
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
|
||||
|
||||
x, y, w, h = cv2.boundingRect(diff)
|
||||
|
||||
trimmed_image = image[y : y + h, x : x + w]
|
||||
|
||||
return trimmed_image
|
||||
|
||||
|
||||
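# Behaviour sketch (illustrative): the background colour is taken from the image corners
# and everything within the threshold of it is cropped away.
def _example_trim() -> np.ndarray:
    canvas = np.full((100, 100, 3), 255, dtype=np.uint8)   # white canvas
    canvas[40:60, 30:70] = 0                               # dark "formula" region
    return trim_white_border(canvas)                       # -> array of shape (20, 40, 3)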
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
|
||||
randi = [random.randint(0, max_size) for _ in range(4)]
|
||||
pad_height_size = randi[1] + randi[3]
|
||||
pad_width_size = randi[0] + randi[2]
|
||||
if pad_height_size + image.shape[0] < 30:
|
||||
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
|
||||
randi[1] += compensate_height
|
||||
randi[3] += compensate_height
|
||||
if pad_width_size + image.shape[1] < 30:
|
||||
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
|
||||
randi[0] += compensate_width
|
||||
randi[2] += compensate_width
|
||||
return v2.functional.pad(
|
||||
torch.from_numpy(image).permute(2, 0, 1),
|
||||
padding=randi,
|
||||
padding_mode='constant',
|
||||
fill=(255, 255, 255),
|
||||
)
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
|
||||
images = [
|
||||
v2.functional.pad(
|
||||
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
|
||||
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
|
||||
return [
|
||||
cv2.resize(
|
||||
img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
|
||||
)  # Lanczos interpolation for anti-aliasing
|
||||
for img, r in zip(images, ratios)
|
||||
]
|
||||
|
||||
|
||||
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
|
||||
# Get the center of the image to define the point of rotation
|
||||
image_center = tuple(np.array(image.shape[1::-1]) / 2)
|
||||
|
||||
# Generate a random angle within the specified range
|
||||
angle = random.randint(min_angle, max_angle)
|
||||
|
||||
# Get the rotation matrix for rotating the image around its center
|
||||
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
|
||||
|
||||
# Determine the size of the rotated image
|
||||
cos = np.abs(rotation_mat[0, 0])
|
||||
sin = np.abs(rotation_mat[0, 1])
|
||||
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
|
||||
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
|
||||
|
||||
# Adjust the rotation matrix to take into account translation
|
||||
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
|
||||
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
|
||||
|
||||
# Rotate the image with the specified border color (white in this case)
|
||||
rotated_image = cv2.warpAffine(
|
||||
image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
|
||||
)
|
||||
|
||||
return rotated_image
|
||||
|
||||
|
||||
def ocr_aug(image: np.ndarray) -> np.ndarray:
|
||||
if random.random() < 0.2:
|
||||
image = rotate(image, -5, 5)
|
||||
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
|
||||
image = train_pipeline(image)
|
||||
return image
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"
|
||||
|
||||
images = [np.array(img.convert('RGB')) for img in images]
|
||||
# random resize first
|
||||
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||
images = [trim_white_border(image) for image in images]
|
||||
|
||||
# OCR augmentation
|
||||
images = [ocr_aug(image) for image in images]
|
||||
|
||||
# general transform pipeline
|
||||
images = [general_transform_pipeline(image) for image in images]
|
||||
# padding to fixed size
|
||||
images = padding(images, FIXED_IMG_SIZE)
|
||||
return images
|
||||
|
||||
|
||||
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
|
||||
assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"
|
||||
images = [
|
||||
np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
|
||||
]
|
||||
images = [trim_white_border(image) for image in images]
|
||||
# general transform pipeline
|
||||
images = [general_transform_pipeline(image) for image in images]  # apply the general transform pipeline
|
||||
# padding to fixed size
|
||||
images = padding(images, FIXED_IMG_SIZE)
|
||||
|
||||
return images
|
||||
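# Usage sketch (illustrative; 'formula.png' is an assumed local image file):
def _example_pixel_values() -> torch.Tensor:
    img = Image.open("formula.png")
    tensors = inference_transform([img])   # list of (1, FIXED_IMG_SIZE, FIXED_IMG_SIZE) tensors
    return torch.stack(tensors)            # batched pixel values for the model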
48
texteller/models/texteller.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
|
||||
|
||||
from texteller.constants import (
|
||||
FIXED_IMG_SIZE,
|
||||
IMG_CHANNELS,
|
||||
MAX_TOKEN_SIZE,
|
||||
VOCAB_SIZE,
|
||||
)
|
||||
from texteller.globals import Globals
|
||||
from texteller.types import TexTellerModel
|
||||
from texteller.utils import cuda_available
|
||||
|
||||
|
||||
class TexTeller(VisionEncoderDecoderModel):
|
||||
def __init__(self):
|
||||
config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name)
|
||||
config.encoder.image_size = FIXED_IMG_SIZE
|
||||
config.encoder.num_channels = IMG_CHANNELS
|
||||
config.decoder.vocab_size = VOCAB_SIZE
|
||||
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
|
||||
|
||||
super().__init__(config=config)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_dir: str | None = None, use_onnx=False) -> TexTellerModel:
|
||||
if model_dir is None or model_dir == Globals().repo_name:
|
||||
if not use_onnx:
|
||||
return VisionEncoderDecoderModel.from_pretrained(Globals().repo_name)
|
||||
else:
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
|
||||
return ORTModelForVision2Seq.from_pretrained(
|
||||
Globals().repo_name,
|
||||
provider="CUDAExecutionProvider"
|
||||
if cuda_available()
|
||||
else "CPUExecutionProvider",
|
||||
)
|
||||
model_dir = Path(model_dir).resolve()
|
||||
return VisionEncoderDecoderModel.from_pretrained(str(model_dir))
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, tokenizer_dir: str = None) -> RobertaTokenizerFast:
|
||||
if tokenizer_dir is None or tokenizer_dir == Globals().repo_name:
|
||||
return RobertaTokenizerFast.from_pretrained(Globals().repo_name)
|
||||
tokenizer_dir = Path(tokenizer_dir).resolve()
|
||||
return RobertaTokenizerFast.from_pretrained(str(tokenizer_dir))
|
||||
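# Usage sketch (illustrative; pulling the default checkpoint requires network access):
#
#   model = TexTeller.from_pretrained()                       # or pass a local model dir
#   tokenizer = TexTeller.get_tokenizer()
#   onnx_model = TexTeller.from_pretrained(use_onnx=True)     # optional ONNX Runtime path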
@@ -1,212 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
"""Convert between text-label and text-index"""
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False):
|
||||
cur_path = os.getcwd()
|
||||
scriptDir = Path(__file__).resolve().parent
|
||||
os.chdir(scriptDir)
|
||||
character_dict_path = str(Path(scriptDir / "ppocr_keys_v1.txt"))
|
||||
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.reverse = False
|
||||
self.character_str = []
|
||||
|
||||
if character_dict_path is None:
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
if "arabic" in character_dict_path:
|
||||
self.reverse = True
|
||||
|
||||
dict_character = self.add_special_char(dict_character)
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
os.chdir(cur_path)
|
||||
|
||||
def pred_reverse(self, pred):
|
||||
pred_re = []
|
||||
c_current = ""
|
||||
for c in pred:
|
||||
if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
|
||||
if c_current != "":
|
||||
pred_re.append(c_current)
|
||||
pred_re.append(c)
|
||||
c_current = ""
|
||||
else:
|
||||
c_current += c
|
||||
if c_current != "":
|
||||
pred_re.append(c_current)
|
||||
|
||||
return "".join(pred_re[::-1])
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
return dict_character
|
||||
|
||||
def get_word_info(self, text, selection):
|
||||
"""
|
||||
Group the decoded characters and record the corresponding decoded positions.
|
||||
|
||||
Args:
|
||||
text: the decoded text
|
||||
selection: the bool array that identifies which columns of features are decoded as non-separated characters
|
||||
Returns:
|
||||
word_list: list of the grouped words
|
||||
word_col_list: list of decoding positions corresponding to each character in the grouped word
|
||||
state_list: list of markers identifying the type of each grouped word; two types are used:
|
||||
- 'cn': continuous Chinese characters (e.g., 你好啊)
|
||||
- 'en&num': continuous English characters (e.g., hello), numbers (e.g., 123, 1.123), or a mix of them connected by '-' (e.g., VGG-16)
|
||||
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
|
||||
"""
|
||||
state = None
|
||||
word_content = []
|
||||
word_col_content = []
|
||||
word_list = []
|
||||
word_col_list = []
|
||||
state_list = []
|
||||
valid_col = np.where(selection)[0]
|
||||
|
||||
for c_i, char in enumerate(text):
|
||||
if "\u4e00" <= char <= "\u9fff":
|
||||
c_state = "cn"
|
||||
elif bool(re.search("[a-zA-Z0-9]", char)):
|
||||
c_state = "en&num"
|
||||
else:
|
||||
c_state = "splitter"
|
||||
|
||||
if (
|
||||
char == "."
|
||||
and state == "en&num"
|
||||
and c_i + 1 < len(text)
|
||||
and bool(re.search("[0-9]", text[c_i + 1]))
|
||||
):  # grouping floating-point numbers
|
||||
c_state = "en&num"
|
||||
if (
|
||||
char == "-" and state == "en&num"
|
||||
):  # grouping words joined with '-', such as 'state-of-the-art'
|
||||
c_state = "en&num"
|
||||
|
||||
if state is None:
|
||||
state = c_state
|
||||
|
||||
if state != c_state:
|
||||
if len(word_content) != 0:
|
||||
word_list.append(word_content)
|
||||
word_col_list.append(word_col_content)
|
||||
state_list.append(state)
|
||||
word_content = []
|
||||
word_col_content = []
|
||||
state = c_state
|
||||
|
||||
if state != "splitter":
|
||||
word_content.append(char)
|
||||
word_col_content.append(valid_col[c_i])
|
||||
|
||||
if len(word_content) != 0:
|
||||
word_list.append(word_content)
|
||||
word_col_list.append(word_col_content)
|
||||
state_list.append(state)
|
||||
|
||||
return word_list, word_col_list, state_list
|
||||
|
||||
def decode(
|
||||
self,
|
||||
text_index,
|
||||
text_prob=None,
|
||||
is_remove_duplicate=False,
|
||||
return_word_box=False,
|
||||
):
|
||||
"""convert text-index into text-label."""
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
||||
if is_remove_duplicate:
|
||||
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index[batch_idx] != ignored_token
|
||||
|
||||
char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]]
|
||||
if text_prob is not None:
|
||||
conf_list = text_prob[batch_idx][selection]
|
||||
else:
|
||||
conf_list = [1] * len(selection)
|
||||
if len(conf_list) == 0:
|
||||
conf_list = [0]
|
||||
|
||||
text = "".join(char_list)
|
||||
|
||||
if self.reverse:  # for Arabic recognition
|
||||
text = self.pred_reverse(text)
|
||||
|
||||
if return_word_box:
|
||||
word_list, word_col_list, state_list = self.get_word_info(text, selection)
|
||||
result_list.append(
|
||||
(
|
||||
text,
|
||||
np.mean(conf_list).tolist(),
|
||||
[
|
||||
len(text_index[batch_idx]),
|
||||
word_list,
|
||||
word_col_list,
|
||||
state_list,
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
result_list.append((text, np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [0] # for ctc blank
|
||||
|
||||
|
||||
class CTCLabelDecode(BaseRecLabelDecode):
|
||||
"""Convert between text-label and text-index"""
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
|
||||
super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
|
||||
if isinstance(preds, (tuple, list)):
|
||||
preds = preds[-1]
|
||||
assert isinstance(preds, np.ndarray)
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(
|
||||
preds_idx,
|
||||
preds_prob,
|
||||
is_remove_duplicate=True,
|
||||
return_word_box=return_word_box,
|
||||
)
|
||||
if return_word_box:
|
||||
for rec_idx, rec in enumerate(text):
|
||||
wh_ratio = kwargs["wh_ratio_list"][rec_idx]
|
||||
max_wh_ratio = kwargs["max_wh_ratio"]
|
||||
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ["blank"] + dict_character
|
||||
return dict_character
|
||||
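# Usage sketch (illustrative; `preds` stands for the raw CTC head output of a PaddleOCR
# recognition model with shape (batch, time_steps, num_classes)):
def _example_ctc_decode(preds: np.ndarray):
    decoder = CTCLabelDecode(use_space_char=True)
    return decoder(preds)   # -> [(text, mean_confidence), ...]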
@@ -1,221 +0,0 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
|
||||
class DBPostProcess(object):
|
||||
"""
|
||||
The post process for Differentiable Binarization (DB).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
thresh=0.3,
|
||||
box_thresh=0.7,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
box_type="quad",
|
||||
**kwargs,
|
||||
):
|
||||
self.thresh = thresh
|
||||
self.box_thresh = box_thresh
|
||||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.score_mode = score_mode
|
||||
self.box_type = box_type
|
||||
assert score_mode in [
|
||||
"slow",
|
||||
"fast",
|
||||
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
|
||||
|
||||
self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
|
||||
|
||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
"""
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
"""
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
|
||||
contours, _ = cv2.findContours(
|
||||
(bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
for contour in contours[: self.max_candidates]:
|
||||
epsilon = 0.002 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape((-1, 2))
|
||||
if points.shape[0] < 4:
|
||||
continue
|
||||
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
if points.shape[0] > 2:
|
||||
box = self.unclip(points, self.unclip_ratio)
|
||||
if len(box) > 1:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
box = box.reshape(-1, 2)
|
||||
|
||||
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
|
||||
box = np.array(box)
|
||||
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.tolist())
|
||||
scores.append(score)
|
||||
return boxes, scores
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
"""
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
"""
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
outs = cv2.findContours(
|
||||
(bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
if len(outs) == 3:
|
||||
img, contours, _ = outs[0], outs[1], outs[2]
|
||||
elif len(outs) == 2:
|
||||
contours, _ = outs[0], outs[1]
|
||||
|
||||
num_contours = min(len(contours), self.max_candidates)
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
for index in range(num_contours):
|
||||
contour = contours[index]
|
||||
points, sside = self.get_mini_boxes(contour)
|
||||
if sside < self.min_size:
|
||||
continue
|
||||
points = np.array(points)
|
||||
if self.score_mode == "fast":
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
else:
|
||||
score = self.box_score_slow(pred, contour)
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
|
||||
box, sside = self.get_mini_boxes(box)
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
box = np.array(box)
|
||||
|
||||
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.astype("int32"))
|
||||
scores.append(score)
|
||||
return np.array(boxes, dtype="int32"), scores
|
||||
|
||||
def unclip(self, box, unclip_ratio):
|
||||
poly = Polygon(box)
|
||||
distance = poly.area * unclip_ratio / poly.length
|
||||
offset = pyclipper.PyclipperOffset()
|
||||
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
expanded = np.array(offset.Execute(distance))
|
||||
return expanded
|
||||
|
||||
def get_mini_boxes(self, contour):
|
||||
bounding_box = cv2.minAreaRect(contour)
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
||||
if points[1][1] > points[0][1]:
|
||||
index_1 = 0
|
||||
index_4 = 1
|
||||
else:
|
||||
index_1 = 1
|
||||
index_4 = 0
|
||||
if points[3][1] > points[2][1]:
|
||||
index_2 = 2
|
||||
index_3 = 3
|
||||
else:
|
||||
index_2 = 3
|
||||
index_3 = 2
|
||||
|
||||
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
|
||||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
"""
|
||||
box_score_fast: use the bbox mean score as the box score
|
||||
"""
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
||||
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
|
||||
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
|
||||
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
box[:, 0] = box[:, 0] - xmin
|
||||
box[:, 1] = box[:, 1] - ymin
|
||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
|
||||
|
||||
def box_score_slow(self, bitmap, contour):
|
||||
"""
|
||||
box_score_slow: use the polygon mean score as the box score
|
||||
"""
|
||||
h, w = bitmap.shape[:2]
|
||||
contour = contour.copy()
|
||||
contour = np.reshape(contour, (-1, 2))
|
||||
|
||||
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
|
||||
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
|
||||
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
|
||||
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
|
||||
contour[:, 0] = contour[:, 0] - xmin
|
||||
contour[:, 1] = contour[:, 1] - ymin
|
||||
|
||||
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict["maps"]
|
||||
assert isinstance(pred, np.ndarray)
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
boxes_batch = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
||||
if self.dilation_kernel is not None:
|
||||
mask = cv2.dilate(
|
||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||
self.dilation_kernel,
|
||||
)
|
||||
else:
|
||||
mask = segmentation[batch_index]
|
||||
if self.box_type == "poly":
|
||||
boxes, scores = self.polygons_from_bitmap(pred[batch_index], mask, src_w, src_h)
|
||||
elif self.box_type == "quad":
|
||||
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
|
||||
else:
|
||||
raise ValueError("box_type can only be one of ['quad', 'poly']")
|
||||
|
||||
boxes_batch.append({"points": boxes})
|
||||
return boxes_batch
|
||||
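# Hedged usage sketch (not part of the original file): the constructor keyword
# arguments below mirror what TextDetector passes to DBPostProcess in this
# repo; the random probability map just makes the example self-contained.
if __name__ == "__main__":
    import numpy as np

    _post = DBPostProcess(
        thresh=0.3,
        box_thresh=0.6,
        max_candidates=1000,
        unclip_ratio=1.5,
        use_dilation=False,
        score_mode="fast",
        box_type="quad",
    )
    _prob_maps = np.random.rand(1, 1, 640, 640).astype("float32")  # (N, 1, H, W)
    _shape_list = np.array([[640.0, 640.0, 1.0, 1.0]])  # src_h, src_w, ratio_h, ratio_w
    _result = _post({"maps": _prob_maps}, _shape_list)
    print(len(_result[0]["points"]), "boxes found in the dummy map")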
@@ -1,186 +0,0 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
import sys
|
||||
|
||||
|
||||
class DetResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
self.keep_ratio = False
|
||||
if "image_shape" in kwargs:
|
||||
self.image_shape = kwargs["image_shape"]
|
||||
self.resize_type = 1
|
||||
if "keep_ratio" in kwargs:
|
||||
self.keep_ratio = kwargs["keep_ratio"]
|
||||
elif "limit_side_len" in kwargs:
|
||||
self.limit_side_len = kwargs["limit_side_len"]
|
||||
self.limit_type = kwargs.get("limit_type", "min")
|
||||
elif "resize_long" in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get("resize_long", 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = "min"
|
||||
|
||||
def __call__(self, data):
|
||||
img = data["image"]
|
||||
src_h, src_w, _ = img.shape
|
||||
if sum([src_h, src_w]) < 64:
|
||||
img = self.image_padding(img)
|
||||
|
||||
if self.resize_type == 0:
|
||||
# img, shape = self.resize_image_type0(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
# img, shape = self.resize_image_type1(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
data["image"] = img
|
||||
data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def image_padding(self, im, value=0):
|
||||
h, w, c = im.shape
|
||||
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
|
||||
im_pad[:h, :w, :] = im
|
||||
return im_pad
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
if self.keep_ratio is True:
|
||||
resize_w = ori_w * resize_h / ori_h
|
||||
N = math.ceil(resize_w / 32)
|
||||
resize_w = N * 32
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
# return img, np.array([ori_h, ori_w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
resize the image to a size that is a multiple of 32, as required by the network
|
||||
args:
|
||||
img(array): array with shape [h, w, c]
|
||||
return(tuple):
|
||||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == "max":
|
||||
if max(h, w) > limit_side_len:
|
||||
if h > w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.0
|
||||
elif self.limit_type == "min":
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.0
|
||||
elif self.limit_type == "resize_long":
|
||||
ratio = float(limit_side_len) / max(h, w)
|
||||
else:
|
||||
raise Exception("not support limit type, image ")
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
||||
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
return None, (None, None)
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
except: # noqa: E722
|
||||
print(img.shape, resize_w, resize_h)
|
||||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return img, [ratio_h, ratio_w]
|
||||
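# Worked example of the rounding above (illustrative numbers, not from the
# original file): with limit_type="max" and limit_side_len=960, a 720x1280
# image gets ratio = 960 / 1280 = 0.75, so resize_h = round(540 / 32) * 32 = 544
# and resize_w = round(960 / 32) * 32 = 960; ratio_h = 544 / 720 and
# ratio_w = 960 / 1280 are what the post-processing later uses to map boxes back.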
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""normalize image such as substract mean, divide std"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype("float32")
|
||||
self.std = np.array(std).reshape(shape).astype("float32")
|
||||
|
||||
def __call__(self, data):
|
||||
img = data["image"]
|
||||
from PIL import Image
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
"""convert hwc image to chw image"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
img = data["image"]
|
||||
from PIL import Image
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
data["image"] = img.transpose((2, 0, 1))
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
def __call__(self, data):
|
||||
data_list = []
|
||||
for key in self.keep_keys:
|
||||
data_list.append(data[key])
|
||||
return data_list
|
||||
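# Hedged usage sketch (not part of the original file): these ops are meant to
# be chained in exactly this order, which is what `transform` in the detector
# module does; the input size is an arbitrary example.
if __name__ == "__main__":
    import numpy as np

    _ops = [
        DetResizeForTest(limit_side_len=960, limit_type="max"),
        NormalizeImage(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            scale=1.0 / 255.0,
            order="hwc",
        ),
        ToCHWImage(),
        KeepKeys(keep_keys=["image", "shape"]),
    ]
    _data = {"image": np.zeros((720, 1280, 3), dtype=np.uint8)}
    for _op in _ops:
        _data = _op(_data)
    _img, _shape = _data
    print(_img.shape, _shape)  # (3, 544, 960) and [src_h, src_w, ratio_h, ratio_w]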
@@ -1,286 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# import tools.infer.utility as utility
|
||||
import utility
|
||||
from DBPostProcess import DBPostProcess
|
||||
from operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
|
||||
from utility import get_logger
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
"""transform"""
|
||||
if ops is None:
|
||||
ops = []
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextDetector(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.det_algorithm = args.det_algorithm
|
||||
self.use_onnx = args.use_onnx
|
||||
postprocess_params = {}
|
||||
assert self.det_algorithm == "DB"
|
||||
postprocess_params["name"] = "DBPostProcess"
|
||||
postprocess_params["thresh"] = args.det_db_thresh
|
||||
postprocess_params["box_thresh"] = args.det_db_box_thresh
|
||||
postprocess_params["max_candidates"] = 1000
|
||||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||
postprocess_params["use_dilation"] = args.use_dilation
|
||||
postprocess_params["score_mode"] = args.det_db_score_mode
|
||||
postprocess_params["box_type"] = args.det_box_type
|
||||
|
||||
self.preprocess_op = [
|
||||
DetResizeForTest(
|
||||
limit_side_len=args.det_limit_side_len, limit_type=args.det_limit_type
|
||||
),
|
||||
NormalizeImage(
|
||||
std=[0.229, 0.224, 0.225],
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
scale=1.0 / 255.0,
|
||||
order="hwc",
|
||||
),
|
||||
ToCHWImage(),
|
||||
KeepKeys(keep_keys=["image", "shape"]),
|
||||
]
|
||||
self.postprocess_op = DBPostProcess(**postprocess_params)
|
||||
(
|
||||
self.predictor,
|
||||
self.input_tensor,
|
||||
self.output_tensors,
|
||||
self.config,
|
||||
) = utility.create_predictor(args, "det", logger)
|
||||
|
||||
assert self.use_onnx
|
||||
if self.use_onnx:
|
||||
img_h, img_w = self.input_tensor.shape[2:]
|
||||
if isinstance(img_h, str) or isinstance(img_w, str):
|
||||
pass
|
||||
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
|
||||
self.preprocess_op[0] = DetResizeForTest(image_shape=[img_h, img_w])
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
rect[0] = pts[np.argmin(s)]
|
||||
rect[2] = pts[np.argmax(s)]
|
||||
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
|
||||
diff = np.diff(np.array(tmp), axis=1)
|
||||
rect[1] = tmp[np.argmin(diff)]
|
||||
rect[3] = tmp[np.argmax(diff)]
|
||||
return rect
|
||||
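# Worked example (illustrative, not part of the original file): for the quad
# [[80, 10], [10, 12], [12, 50], [82, 48]] the coordinate sums pick
# top-left = [10, 12] (smallest sum) and bottom-right = [82, 48] (largest sum),
# and the y - x differences of the two remaining points pick
# top-right = [80, 10] (smallest diff) and bottom-left = [12, 50] (largest diff).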
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
||||
def filter_tag_det_res(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if type(box) is list:
|
||||
box = np.array(box)
|
||||
box = self.order_points_clockwise(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
rect_width = int(np.linalg.norm(box[0] - box[1]))
|
||||
rect_height = int(np.linalg.norm(box[0] - box[3]))
|
||||
if rect_width <= 3 or rect_height <= 3:
|
||||
continue
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if type(box) is list:
|
||||
box = np.array(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def predict(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {"image": img}
|
||||
|
||||
st = time.time()
|
||||
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.start()
|
||||
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
if self.use_onnx:
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = img
|
||||
outputs = self.predictor.run(self.output_tensors, input_dict)
|
||||
else:
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
|
||||
preds = {}
|
||||
if self.det_algorithm == "EAST":
|
||||
preds["f_geo"] = outputs[0]
|
||||
preds["f_score"] = outputs[1]
|
||||
elif self.det_algorithm == "SAST":
|
||||
preds["f_border"] = outputs[0]
|
||||
preds["f_score"] = outputs[1]
|
||||
preds["f_tco"] = outputs[2]
|
||||
preds["f_tvo"] = outputs[3]
|
||||
elif self.det_algorithm in ["DB", "PSE", "DB++"]:
|
||||
preds["maps"] = outputs[0]
|
||||
elif self.det_algorithm == "FCE":
|
||||
for i, output in enumerate(outputs):
|
||||
preds["level_{}".format(i)] = output
|
||||
elif self.det_algorithm == "CT":
|
||||
preds["maps"] = outputs[0]
|
||||
preds["score"] = outputs[1]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
dt_boxes = post_result[0]["points"]
|
||||
|
||||
if self.args.det_box_type == "poly":
|
||||
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
|
||||
else:
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
et = time.time()
|
||||
return dt_boxes, et - st
|
||||
|
||||
def __call__(self, img):
|
||||
# For images like posters where one side is much longer than the other,
|
||||
# split the image recursively and process the pieces with overlap to improve detection quality.
|
||||
MIN_BOUND_DISTANCE = 50
|
||||
dt_boxes = np.zeros((0, 4, 2), dtype=np.float32)
|
||||
elapse = 0
|
||||
if img.shape[0] / img.shape[1] > 2 and img.shape[0] > self.args.det_limit_side_len:
|
||||
start_h = 0
|
||||
end_h = 0
|
||||
while end_h <= img.shape[0]:
|
||||
end_h = start_h + img.shape[1] * 3 // 4
|
||||
subimg = img[start_h:end_h, :]
|
||||
if len(subimg) == 0:
|
||||
break
|
||||
sub_dt_boxes, sub_elapse = self.predict(subimg)
|
||||
offset = start_h
|
||||
# To prevent text blocks from being cut off, roll back a certain buffer area.
|
||||
if (
|
||||
len(sub_dt_boxes) == 0
|
||||
or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE
|
||||
):
|
||||
start_h = end_h
|
||||
else:
|
||||
sorted_indices = np.argsort(sub_dt_boxes[:, 2, 1])
|
||||
sub_dt_boxes = sub_dt_boxes[sorted_indices]
|
||||
bottom_line = (
|
||||
0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 2, 1]))
|
||||
)
|
||||
if bottom_line > 0:
|
||||
start_h += bottom_line
|
||||
sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 2, 1] <= bottom_line]
|
||||
else:
|
||||
start_h = end_h
|
||||
if len(sub_dt_boxes) > 0:
|
||||
if dt_boxes.shape[0] == 0:
|
||||
dt_boxes = sub_dt_boxes + np.array([0, offset], dtype=np.float32)
|
||||
else:
|
||||
dt_boxes = np.append(
|
||||
dt_boxes,
|
||||
sub_dt_boxes + np.array([0, offset], dtype=np.float32),
|
||||
axis=0,
|
||||
)
|
||||
elapse += sub_elapse
|
||||
elif img.shape[1] / img.shape[0] > 3 and img.shape[1] > self.args.det_limit_side_len * 3:
|
||||
start_w = 0
|
||||
end_w = 0
|
||||
while end_w <= img.shape[1]:
|
||||
end_w = start_w + img.shape[0] * 3 // 4
|
||||
subimg = img[:, start_w:end_w]
|
||||
if len(subimg) == 0:
|
||||
break
|
||||
sub_dt_boxes, sub_elapse = self.predict(subimg)
|
||||
offset = start_w
|
||||
if (
|
||||
len(sub_dt_boxes) == 0
|
||||
or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes]) > MIN_BOUND_DISTANCE
|
||||
):
|
||||
start_w = end_w
|
||||
else:
|
||||
sorted_indices = np.argsort(sub_dt_boxes[:, 2, 0])
|
||||
sub_dt_boxes = sub_dt_boxes[sorted_indices]
|
||||
right_line = (
|
||||
0 if len(sub_dt_boxes) <= 1 else int(np.max(sub_dt_boxes[:-1, 1, 0]))
|
||||
)
|
||||
if right_line > 0:
|
||||
start_w += right_line
|
||||
sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 1, 0] <= right_line]
|
||||
else:
|
||||
start_w = end_w
|
||||
if len(sub_dt_boxes) > 0:
|
||||
if dt_boxes.shape[0] == 0:
|
||||
dt_boxes = sub_dt_boxes + np.array([offset, 0], dtype=np.float32)
|
||||
else:
|
||||
dt_boxes = np.append(
|
||||
dt_boxes,
|
||||
sub_dt_boxes + np.array([offset, 0], dtype=np.float32),
|
||||
axis=0,
|
||||
)
|
||||
elapse += sub_elapse
|
||||
else:
|
||||
dt_boxes, elapse = self.predict(img)
|
||||
return dt_boxes, elapse
|
||||
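# Hedged usage sketch (not part of the original file): this refactor is
# ONNX-only (see the asserts above); the model and image paths below are
# placeholders, not files shipped with the repo.
if __name__ == "__main__":
    _args = utility.parse_args()
    _args.use_onnx = True
    _args.use_gpu = False
    _args.det_model_dir = "path/to/det_model.onnx"  # placeholder path
    _detector = TextDetector(_args)
    _img = cv2.imread("sample.jpg")  # BGR image, placeholder path
    _dt_boxes, _elapse = _detector(_img)
    print(len(_dt_boxes), "text boxes detected in", _elapse, "seconds")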
@@ -1,379 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
from PIL import Image
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
||||
import utility
|
||||
from utility import get_logger
|
||||
|
||||
from CTCLabelDecode import CTCLabelDecode
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextRecognizer(object):
|
||||
def __init__(self, args):
|
||||
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.rec_algorithm = args.rec_algorithm
|
||||
self.postprocess_op = CTCLabelDecode(
|
||||
character_dict_path=args.rec_char_dict_path, use_space_char=args.use_space_char
|
||||
)
|
||||
(
|
||||
self.predictor,
|
||||
self.input_tensor,
|
||||
self.output_tensors,
|
||||
self.config,
|
||||
) = utility.create_predictor(args, "rec", logger)
|
||||
self.benchmark = args.benchmark
|
||||
self.use_onnx = args.use_onnx
|
||||
self.return_word_box = args.return_word_box
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
if self.rec_algorithm == "NRTR" or self.rec_algorithm == "ViTSTR":
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# return padding_im
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
if self.rec_algorithm == "ViTSTR":
|
||||
img = image_pil.resize([imgW, imgH], Image.BICUBIC)
|
||||
else:
|
||||
img = image_pil.resize([imgW, imgH], Image.Resampling.LANCZOS)
|
||||
img = np.array(img)
|
||||
norm_img = np.expand_dims(img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
if self.rec_algorithm == "ViTSTR":
|
||||
norm_img = norm_img.astype(np.float32) / 255.0
|
||||
else:
|
||||
norm_img = norm_img.astype(np.float32) / 128.0 - 1.0
|
||||
return norm_img
|
||||
elif self.rec_algorithm == "RFL":
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
|
||||
resized_image = resized_image.astype("float32")
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
assert imgC == img.shape[2]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
if self.use_onnx:
|
||||
w = self.input_tensor.shape[3:][0]
|
||||
if isinstance(w, str):
|
||||
pass
|
||||
elif w is not None and w > 0:
|
||||
imgW = w
|
||||
h, w = img.shape[:2]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
if self.rec_algorithm == "RARE":
|
||||
if resized_w > self.rec_image_shape[2]:
|
||||
resized_w = self.rec_image_shape[2]
|
||||
imgW = self.rec_image_shape[2]
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype("float32")
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
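# Worked example (illustrative, not from the original file): with
# rec_image_shape = "3, 48, 320" the base max_wh_ratio is 320 / 48 ≈ 6.67; a
# 30x400 crop raises it to 400 / 30 ≈ 13.33, so imgW becomes
# int(48 * 400 / 30) = 640 and narrower crops in the same batch are
# right-padded with zeros up to that width before stacking.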
|
||||
def resize_norm_img_vl(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
img = img[:, :, ::-1] # bgr2rgb
|
||||
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype("float32")
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0 : img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
|
||||
gsrm_word_pos = (
|
||||
np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype("int64")
|
||||
)
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length]
|
||||
)
|
||||
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype(
|
||||
"float32"
|
||||
) * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length]
|
||||
)
|
||||
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype(
|
||||
"float32"
|
||||
) * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [
|
||||
encoder_word_pos,
|
||||
gsrm_word_pos,
|
||||
gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2,
|
||||
]
|
||||
|
||||
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
|
||||
[
|
||||
encoder_word_pos,
|
||||
gsrm_word_pos,
|
||||
gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2,
|
||||
] = self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
|
||||
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||
|
||||
return (
|
||||
norm_img,
|
||||
encoder_word_pos,
|
||||
gsrm_word_pos,
|
||||
gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2,
|
||||
)
|
||||
|
||||
def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype("float32")
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def resize_norm_img_spin(self, img):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# return padding_im
|
||||
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
|
||||
img = np.array(img, np.float32)
|
||||
img = np.expand_dims(img, -1)
|
||||
img = img.transpose((2, 0, 1))
|
||||
mean = [127.5]
|
||||
std = [127.5]
|
||||
mean = np.array(mean, dtype=np.float32)
|
||||
std = np.array(std, dtype=np.float32)
|
||||
mean = np.float32(mean.reshape(1, -1))
|
||||
stdinv = 1 / np.float32(std.reshape(1, -1))
|
||||
img -= mean
|
||||
img *= stdinv
|
||||
return img
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype("float32")
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_cppd_padding(
|
||||
self, img, image_shape, padding=True, interpolation=cv2.INTER_LINEAR
|
||||
):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
if not padding:
|
||||
resized_image = cv2.resize(img, (imgW, imgH), interpolation=interpolation)
|
||||
resized_w = imgW
|
||||
else:
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype("float32")
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_abinet(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype("float32")
|
||||
resized_image = resized_image / 255.0
|
||||
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
resized_image = resized_image.astype("float32")
|
||||
|
||||
return resized_image
|
||||
|
||||
def norm_img_can(self, img, image_shape):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
|
||||
|
||||
if self.inverse:
|
||||
img = 255 - img
|
||||
|
||||
if self.rec_image_shape[0] == 1:
|
||||
h, w = img.shape
|
||||
_, imgH, imgW = self.rec_image_shape
|
||||
if h < imgH or w < imgW:
|
||||
padding_h = max(imgH - h, 0)
|
||||
padding_w = max(imgW - w, 0)
|
||||
img_padded = np.pad(
|
||||
img,
|
||||
((0, padding_h), (0, padding_w)),
|
||||
"constant",
|
||||
constant_values=(255),
|
||||
)
|
||||
img = img_padded
|
||||
|
||||
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
|
||||
img = img.astype("float32")
|
||||
|
||||
return img
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||
# Sorting can speed up the recognition process
|
||||
indices = np.argsort(np.array(width_list))
|
||||
rec_res = [["", 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
st = time.time()
|
||||
if self.benchmark:
|
||||
self.autolog.times.start()
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||
max_wh_ratio = imgW / imgH
|
||||
wh_ratio_list = []
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
wh_ratio_list.append(wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
|
||||
assert self.use_onnx
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = norm_img_batch
|
||||
outputs = self.predictor.run(self.output_tensors, input_dict)
|
||||
preds = outputs[0]
|
||||
rec_result = self.postprocess_op(
|
||||
preds,
|
||||
return_word_box=self.return_word_box,
|
||||
wh_ratio_list=wh_ratio_list,
|
||||
max_wh_ratio=max_wh_ratio,
|
||||
)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
if self.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
return rec_res, time.time() - st
|
||||
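# Hedged usage sketch (not part of the original file): ONNX is assumed here
# (see the assert in __call__); model, dict and image paths are placeholders.
if __name__ == "__main__":
    _args = utility.parse_args()
    _args.use_onnx = True
    _args.use_gpu = False
    _args.rec_model_dir = "path/to/rec_model.onnx"  # placeholder path
    _args.rec_char_dict_path = "./ppocr_keys_v1.txt"  # default character dict
    _recognizer = TextRecognizer(_args)
    _crops = [cv2.imread("word_crop.jpg")]  # list of BGR text-line crops
    _rec_res, _elapse = _recognizer(_crops)
    print(_rec_res)  # [(text, confidence), ...]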
@@ -1,689 +0,0 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import functools
|
||||
import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import math
|
||||
import random
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "yes", "t", "y", "1")
|
||||
|
||||
|
||||
def str2int_tuple(v):
|
||||
return tuple([int(i.strip()) for i in v.split(",")])
|
||||
|
||||
|
||||
def init_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# params for prediction engine
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--use_xpu", type=str2bool, default=False)
|
||||
parser.add_argument("--use_npu", type=str2bool, default=False)
|
||||
parser.add_argument("--use_mlu", type=str2bool, default=False)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--min_subgraph_size", type=int, default=15)
|
||||
parser.add_argument("--precision", type=str, default="fp32")
|
||||
parser.add_argument("--gpu_mem", type=int, default=500)
|
||||
parser.add_argument("--gpu_id", type=int, default=0)
|
||||
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--page_num", type=int, default=0)
|
||||
parser.add_argument("--det_algorithm", type=str, default="DB")
|
||||
parser.add_argument("--det_model_dir", type=str)
|
||||
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||
parser.add_argument("--det_limit_type", type=str, default="max")
|
||||
parser.add_argument("--det_box_type", type=str, default="quad")
|
||||
|
||||
# DB params
|
||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
|
||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||
parser.add_argument("--use_dilation", type=str2bool, default=False)
|
||||
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
||||
|
||||
# EAST params
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||
|
||||
# SAST params
|
||||
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
|
||||
|
||||
# PSE params
|
||||
parser.add_argument("--det_pse_thresh", type=float, default=0)
|
||||
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
|
||||
parser.add_argument("--det_pse_min_area", type=float, default=16)
|
||||
parser.add_argument("--det_pse_scale", type=int, default=1)
|
||||
|
||||
# FCE params
|
||||
parser.add_argument("--scales", type=list, default=[8, 16, 32])
|
||||
parser.add_argument("--alpha", type=float, default=1.0)
|
||||
parser.add_argument("--beta", type=float, default=1.0)
|
||||
parser.add_argument("--fourier_degree", type=int, default=5)
|
||||
|
||||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet")
|
||||
parser.add_argument("--rec_model_dir", type=str)
|
||||
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
|
||||
parser.add_argument("--rec_batch_num", type=int, default=6)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument("--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
||||
parser.add_argument("--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
|
||||
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||
|
||||
# params for e2e
|
||||
parser.add_argument("--e2e_algorithm", type=str, default="PGNet")
|
||||
parser.add_argument("--e2e_model_dir", type=str)
|
||||
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
|
||||
parser.add_argument("--e2e_limit_type", type=str, default="max")
|
||||
|
||||
# PGNet params
|
||||
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
|
||||
parser.add_argument("--e2e_pgnet_valid_set", type=str, default="totaltext")
|
||||
parser.add_argument("--e2e_pgnet_mode", type=str, default="fast")
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
parser.add_argument("--cls_model_dir", type=str)
|
||||
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||
parser.add_argument("--label_list", type=list, default=["0", "180"])
|
||||
parser.add_argument("--cls_batch_num", type=int, default=6)
|
||||
parser.add_argument("--cls_thresh", type=float, default=0.9)
|
||||
|
||||
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||
parser.add_argument("--cpu_threads", type=int, default=10)
|
||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||
parser.add_argument("--warmup", type=str2bool, default=False)
|
||||
|
||||
# SR params
|
||||
parser.add_argument("--sr_model_dir", type=str)
|
||||
parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
|
||||
parser.add_argument("--sr_batch_num", type=int, default=1)
|
||||
|
||||
# params for output and visualization
|
||||
parser.add_argument("--draw_img_save_dir", type=str, default="./inference_results")
|
||||
parser.add_argument("--save_crop_res", type=str2bool, default=False)
|
||||
parser.add_argument("--crop_res_save_dir", type=str, default="./output")
|
||||
|
||||
# multi-process
|
||||
parser.add_argument("--use_mp", type=str2bool, default=False)
|
||||
parser.add_argument("--total_process_num", type=int, default=1)
|
||||
parser.add_argument("--process_id", type=int, default=0)
|
||||
|
||||
parser.add_argument("--benchmark", type=str2bool, default=False)
|
||||
parser.add_argument("--save_log_path", type=str, default="./log_output/")
|
||||
|
||||
parser.add_argument("--show_log", type=str2bool, default=True)
|
||||
parser.add_argument("--use_onnx", type=str2bool, default=False)
|
||||
|
||||
# extended function
|
||||
parser.add_argument(
|
||||
"--return_word_box",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = init_args()
|
||||
return parser.parse_args([])
|
||||
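# Hedged usage note (not part of the original file): parse_args() deliberately
# parses an empty list, so every value comes from the defaults above and is
# meant to be overridden programmatically, e.g.
#   args = parse_args()
#   args.use_onnx = True
#   args.det_db_box_thresh = 0.5                    # tune the DB box threshold
#   args.det_model_dir = "path/to/det_model.onnx"   # placeholder path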
|
||||
|
||||
def create_predictor(args, mode, logger):
|
||||
if mode == "det":
|
||||
model_dir = args.det_model_dir
|
||||
elif mode == "cls":
|
||||
model_dir = args.cls_model_dir
|
||||
elif mode == "rec":
|
||||
model_dir = args.rec_model_dir
|
||||
elif mode == "table":
|
||||
model_dir = args.table_model_dir
|
||||
elif mode == "ser":
|
||||
model_dir = args.ser_model_dir
|
||||
elif mode == "re":
|
||||
model_dir = args.re_model_dir
|
||||
elif mode == "sr":
|
||||
model_dir = args.sr_model_dir
|
||||
elif mode == "layout":
|
||||
model_dir = args.layout_model_dir
|
||||
else:
|
||||
model_dir = args.e2e_model_dir
|
||||
|
||||
if model_dir is None:
|
||||
logger.info("not find {} model file path {}".format(mode, model_dir))
|
||||
sys.exit(0)
|
||||
assert args.use_onnx
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
model_file_path = model_dir
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(model_file_path))
|
||||
if args.use_gpu:
|
||||
sess = ort.InferenceSession(model_file_path, providers=["CUDAExecutionProvider"])
|
||||
else:
|
||||
sess = ort.InferenceSession(model_file_path)
|
||||
return sess, sess.get_inputs()[0], None, None
|
||||
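# Hedged sketch (not part of the original file): with use_onnx the returned
# "predictor" is an onnxruntime.InferenceSession and the second value is its
# first input metadata, so inference is a plain session.run call, e.g.
#   sess, input_meta, _, _ = create_predictor(args, "det", logger)
#   outputs = sess.run(None, {input_meta.name: batched_nchw_image})
# where batched_nchw_image is a float32 array shaped (N, 3, H, W).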
|
||||
|
||||
def get_output_tensors(args, mode, predictor):
|
||||
output_names = predictor.get_output_names()
|
||||
output_tensors = []
|
||||
if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet", "SVTR_HGNet"]:
|
||||
output_name = "softmax_0.tmp_0"
|
||||
if output_name in output_names:
|
||||
return [predictor.get_output_handle(output_name)]
|
||||
else:
|
||||
for output_name in output_names:
|
||||
output_tensor = predictor.get_output_handle(output_name)
|
||||
output_tensors.append(output_tensor)
|
||||
else:
|
||||
for output_name in output_names:
|
||||
output_tensor = predictor.get_output_handle(output_name)
|
||||
output_tensors.append(output_tensor)
|
||||
return output_tensors
|
||||
|
||||
|
||||
def draw_e2e_res(dt_boxes, strs, img_path):
|
||||
src_im = cv2.imread(img_path)
|
||||
for box, str in zip(dt_boxes, strs):
|
||||
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
cv2.putText(
|
||||
src_im,
|
||||
str,
|
||||
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
|
||||
fontFace=cv2.FONT_HERSHEY_COMPLEX,
|
||||
fontScale=0.7,
|
||||
color=(0, 255, 0),
|
||||
thickness=1,
|
||||
)
|
||||
return src_im
|
||||
|
||||
|
||||
def draw_text_det_res(dt_boxes, img):
|
||||
for box in dt_boxes:
|
||||
box = np.array(box).astype(np.int32).reshape(-1, 2)
|
||||
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
|
||||
return img
|
||||
|
||||
|
||||
def resize_img(img, input_size=600):
|
||||
"""
|
||||
resize img and limit the longest side of the image to input_size
|
||||
"""
|
||||
img = np.array(img)
|
||||
im_shape = img.shape
|
||||
im_size_max = np.max(im_shape[0:2])
|
||||
im_scale = float(input_size) / float(im_size_max)
|
||||
img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
|
||||
return img
|
||||
|
||||
|
||||
def draw_ocr(
|
||||
image,
|
||||
boxes,
|
||||
txts=None,
|
||||
scores=None,
|
||||
drop_score=0.5,
|
||||
font_path="./doc/fonts/simfang.ttf",
|
||||
):
|
||||
"""
|
||||
Visualize the results of OCR detection and recognition
|
||||
args:
|
||||
image(Image|array): RGB image
|
||||
boxes(list): boxes with shape(N, 4, 2)
|
||||
txts(list): the texts
|
||||
scores(list): corresponding score of each txt
|
||||
drop_score(float): only scores greater than drop_threshold will be visualized
|
||||
font_path: the path of font which is used to draw text
|
||||
return(array):
|
||||
the visualized img
|
||||
"""
|
||||
if scores is None:
|
||||
scores = [1] * len(boxes)
|
||||
box_num = len(boxes)
|
||||
for i in range(box_num):
|
||||
if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])):
|
||||
continue
|
||||
box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
|
||||
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
|
||||
if txts is not None:
|
||||
img = np.array(resize_img(image, input_size=600))
|
||||
txt_img = text_visual(
|
||||
txts,
|
||||
scores,
|
||||
img_h=img.shape[0],
|
||||
img_w=600,
|
||||
threshold=drop_score,
|
||||
font_path=font_path,
|
||||
)
|
||||
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
|
||||
return img
|
||||
return image
|
||||
|
||||
|
||||
def draw_ocr_box_txt(
|
||||
image,
|
||||
boxes,
|
||||
txts=None,
|
||||
scores=None,
|
||||
drop_score=0.5,
|
||||
font_path="./doc/fonts/simfang.ttf",
|
||||
):
|
||||
h, w = image.height, image.width
|
||||
img_left = image.copy()
|
||||
img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
|
||||
random.seed(0)
|
||||
|
||||
draw_left = ImageDraw.Draw(img_left)
|
||||
if txts is None or len(txts) != len(boxes):
|
||||
txts = [None] * len(boxes)
|
||||
for idx, (box, txt) in enumerate(zip(boxes, txts)):
|
||||
if scores is not None and scores[idx] < drop_score:
|
||||
continue
|
||||
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
draw_left.polygon(box, fill=color)
|
||||
img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
|
||||
pts = np.array(box, np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(img_right_text, [pts], True, color, 1)
|
||||
img_right = cv2.bitwise_and(img_right, img_right_text)
|
||||
img_left = Image.blend(image, img_left, 0.5)
|
||||
img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
|
||||
img_show.paste(img_left, (0, 0, w, h))
|
||||
img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
|
||||
return np.array(img_show)
|
||||
|
||||
|
||||
def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
|
||||
box_height = int(math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2))
|
||||
box_width = int(math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2))
|
||||
|
||||
if box_height > 2 * box_width and box_height > 30:
|
||||
img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
|
||||
draw_text = ImageDraw.Draw(img_text)
|
||||
if txt:
|
||||
font = create_font(txt, (box_height, box_width), font_path)
|
||||
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
|
||||
img_text = img_text.transpose(Image.ROTATE_270)
|
||||
else:
|
||||
img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
|
||||
draw_text = ImageDraw.Draw(img_text)
|
||||
if txt:
|
||||
font = create_font(txt, (box_width, box_height), font_path)
|
||||
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
|
||||
|
||||
pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
|
||||
pts2 = np.array(box, dtype=np.float32)
|
||||
M = cv2.getPerspectiveTransform(pts1, pts2)
|
||||
|
||||
img_text = np.array(img_text, dtype=np.uint8)
|
||||
img_right_text = cv2.warpPerspective(
|
||||
img_text,
|
||||
M,
|
||||
img_size,
|
||||
flags=cv2.INTER_NEAREST,
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=(255, 255, 255),
|
||||
)
|
||||
return img_right_text
|
||||
|
||||
|
||||
def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
|
||||
font_size = int(sz[1] * 0.99)
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
if int(PIL.__version__.split(".")[0]) < 10:
|
||||
length = font.getsize(txt)[0]
|
||||
else:
|
||||
length = font.getlength(txt)
|
||||
|
||||
if length > sz[0]:
|
||||
font_size = int(font_size * sz[0] / length)
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
return font
|
||||
|
||||
|
||||
def str_count(s):
|
||||
"""
|
||||
Count the display width of a string in units of Chinese characters;
|
||||
a single English character or a single digit
|
||||
counts as half the width of a Chinese character.
|
||||
args:
|
||||
s(string): the input of string
|
||||
return(int):
|
||||
the weighted character count of the string
|
||||
"""
|
||||
import string
|
||||
|
||||
count_zh = count_pu = 0
|
||||
s_len = len(s)
|
||||
en_dg_count = 0
|
||||
for c in s:
|
||||
if c in string.ascii_letters or c.isdigit() or c.isspace():
|
||||
en_dg_count += 1
|
||||
elif c.isalpha():
|
||||
count_zh += 1
|
||||
else:
|
||||
count_pu += 1
|
||||
return s_len - math.ceil(en_dg_count / 2)
|
||||
|
||||
|
||||
def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.0, font_path="./doc/simfang.ttf"):
|
||||
"""
|
||||
create new blank img and draw txt on it
|
||||
args:
|
||||
texts(list): the text will be draw
|
||||
scores(list|None): corresponding score of each txt
|
||||
img_h(int): the height of blank img
|
||||
img_w(int): the width of blank img
|
||||
font_path: the path of font which is used to draw text
|
||||
return(array): the rendered text image
|
||||
"""
|
||||
if scores is not None:
|
||||
assert len(texts) == len(scores), "The number of txts and corresponding scores must match"
|
||||
|
||||
def create_blank_img():
|
||||
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
|
||||
blank_img[:, img_w - 1 :] = 0
|
||||
blank_img = Image.fromarray(blank_img).convert("RGB")
|
||||
draw_txt = ImageDraw.Draw(blank_img)
|
||||
return blank_img, draw_txt
|
||||
|
||||
blank_img, draw_txt = create_blank_img()
|
||||
|
||||
font_size = 20
|
||||
txt_color = (0, 0, 0)
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
|
||||
gap = font_size + 5
|
||||
txt_img_list = []
|
||||
count, index = 1, 0
|
||||
for idx, txt in enumerate(texts):
|
||||
index += 1
|
||||
if scores[idx] < threshold or math.isnan(scores[idx]):
|
||||
index -= 1
|
||||
continue
|
||||
first_line = True
|
||||
while str_count(txt) >= img_w // font_size - 4:
|
||||
tmp = txt
|
||||
txt = tmp[: img_w // font_size - 4]
|
||||
if first_line:
|
||||
new_txt = str(index) + ": " + txt
|
||||
first_line = False
|
||||
else:
|
||||
new_txt = " " + txt
|
||||
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
|
||||
txt = tmp[img_w // font_size - 4 :]
|
||||
if count >= img_h // gap - 1:
|
||||
txt_img_list.append(np.array(blank_img))
|
||||
blank_img, draw_txt = create_blank_img()
|
||||
count = 0
|
||||
count += 1
|
||||
if first_line:
|
||||
new_txt = str(index) + ": " + txt + " " + "%.3f" % (scores[idx])
|
||||
else:
|
||||
new_txt = " " + txt + " " + "%.3f" % (scores[idx])
|
||||
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
|
||||
# whether add new blank img or not
|
||||
if count >= img_h // gap - 1 and idx + 1 < len(texts):
|
||||
txt_img_list.append(np.array(blank_img))
|
||||
blank_img, draw_txt = create_blank_img()
|
||||
count = 0
|
||||
count += 1
|
||||
txt_img_list.append(np.array(blank_img))
|
||||
if len(txt_img_list) == 1:
|
||||
blank_img = np.array(txt_img_list[0])
|
||||
else:
|
||||
blank_img = np.concatenate(txt_img_list, axis=1)
|
||||
return np.array(blank_img)
|
||||
|
||||
|
||||
def base64_to_cv2(b64str):
|
||||
import base64
|
||||
|
||||
data = base64.b64decode(b64str.encode("utf8"))
|
||||
data = np.frombuffer(data, np.uint8)
|
||||
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
||||
return data
|
||||
|
||||
|
||||
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
|
||||
if scores is None:
|
||||
scores = [1] * len(boxes)
|
||||
for box, score in zip(boxes, scores):
|
||||
if score < drop_score:
|
||||
continue
|
||||
box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
|
||||
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
|
||||
return image
|
||||
|
||||
|
||||
def get_rotate_crop_image(img, points):
|
||||
"""
|
||||
img_height, img_width = img.shape[0:2]
|
||||
left = int(np.min(points[:, 0]))
|
||||
right = int(np.max(points[:, 0]))
|
||||
top = int(np.min(points[:, 1]))
|
||||
bottom = int(np.max(points[:, 1]))
|
||||
img_crop = img[top:bottom, left:right, :].copy()
|
||||
points[:, 0] = points[:, 0] - left
|
||||
points[:, 1] = points[:, 1] - top
|
||||
"""
|
||||
assert len(points) == 4, "shape of points must be 4*2"
|
||||
img_crop_width = int(
|
||||
max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))
|
||||
)
|
||||
img_crop_height = int(
|
||||
max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))
|
||||
)
|
||||
pts_std = np.float32(
|
||||
[
|
||||
[0, 0],
|
||||
[img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height],
|
||||
]
|
||||
)
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img,
|
||||
M,
|
||||
(img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC,
|
||||
)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
dst_img = np.rot90(dst_img)
|
||||
return dst_img
|
||||
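# Hedged illustration (not part of the original file): the quadrilateral is
# warped to an axis-aligned crop, and crops much taller than they are wide get
# rotated so the text ends up roughly horizontal. Coordinates are invented.
if __name__ == "__main__":
    _canvas = np.full((200, 300, 3), 255, dtype=np.uint8)
    _quad = np.float32([[40, 50], [220, 60], [215, 110], [35, 100]])
    _crop = get_rotate_crop_image(_canvas, _quad)
    print(_crop.shape)  # approximately (box_height, box_width, 3)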
|
||||
|
||||
def get_minarea_rect_crop(img, points):
|
||||
bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
index_a, index_b, index_c, index_d = 0, 1, 2, 3
|
||||
if points[1][1] > points[0][1]:
|
||||
index_a = 0
|
||||
index_d = 1
|
||||
else:
|
||||
index_a = 1
|
||||
index_d = 0
|
||||
if points[3][1] > points[2][1]:
|
||||
index_b = 2
|
||||
index_c = 3
|
||||
else:
|
||||
index_b = 3
|
||||
index_c = 2
|
||||
|
||||
box = [points[index_a], points[index_b], points[index_c], points[index_d]]
|
||||
crop_img = get_rotate_crop_image(img, np.array(box))
|
||||
return crop_img
|
||||
|
||||
|
||||
# def check_gpu(use_gpu):
|
||||
# if use_gpu and (
|
||||
# not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"
|
||||
# ):
|
||||
# use_gpu = False
|
||||
# return use_gpu
|
||||
|
||||
|
||||
def _check_image_file(path):
|
||||
img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}
|
||||
return any([path.lower().endswith(e) for e in img_end])
|
||||
|
||||
|
||||
def get_image_file_list(img_file, infer_list=None):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
if os.path.isfile(img_file) and _check_image_file(img_file):
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and _check_image_file(file_path):
|
||||
imgs_lists.append(file_path)
|
||||
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
imgs_lists = sorted(imgs_lists)
|
||||
return imgs_lists
|
||||
|
||||
|
||||
logger_initialized = {}
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_logger(name="ppocr", log_file=None, log_level=logging.DEBUG):
|
||||
"""Initialize and get a logger by name.
|
||||
If the logger has not been initialized, this method will initialize the
|
||||
logger by adding one or two handlers, otherwise the initialized logger will
|
||||
be directly returned. During initialization, a StreamHandler will always be
|
||||
added. If `log_file` is specified a FileHandler will also be added.
|
||||
Args:
|
||||
name (str): Logger name.
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the logger.
|
||||
log_level (int): The logger level. Note that only the process of
|
||||
rank 0 is affected, and other processes will set the level to
|
||||
"Error" thus be silent most of the time.
|
||||
Returns:
|
||||
logging.Logger: The expected logger.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
if name in logger_initialized:
|
||||
return logger
|
||||
for logger_name in logger_initialized:
|
||||
if name.startswith(logger_name):
|
||||
return logger
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
|
||||
)
|
||||
|
||||
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
logger_initialized[name] = True
|
||||
logger.propagate = False
|
||||
return logger
|
||||
|
||||
|
||||
def get_rotate_crop_image(img, points):
|
||||
"""
|
||||
img_height, img_width = img.shape[0:2]
|
||||
left = int(np.min(points[:, 0]))
|
||||
right = int(np.max(points[:, 0]))
|
||||
top = int(np.min(points[:, 1]))
|
||||
bottom = int(np.max(points[:, 1]))
|
||||
img_crop = img[top:bottom, left:right, :].copy()
|
||||
points[:, 0] = points[:, 0] - left
|
||||
points[:, 1] = points[:, 1] - top
|
||||
"""
|
||||
assert len(points) == 4, "shape of points must be 4*2"
|
||||
img_crop_width = int(
|
||||
max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))
|
||||
)
|
||||
img_crop_height = int(
|
||||
max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))
|
||||
)
|
||||
pts_std = np.float32(
|
||||
[
|
||||
[0, 0],
|
||||
[img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height],
|
||||
]
|
||||
)
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img,
|
||||
M,
|
||||
(img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC,
|
||||
)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
dst_img = np.rot90(dst_img)
|
||||
return dst_img
|
||||
|
||||
|
||||
def get_minarea_rect_crop(img, points):
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    index_a, index_b, index_c, index_d = 0, 1, 2, 3
    if points[1][1] > points[0][1]:
        index_a = 0
        index_d = 1
    else:
        index_a = 1
        index_d = 0
    if points[3][1] > points[2][1]:
        index_b = 2
        index_c = 3
    else:
        index_b = 3
        index_c = 2

    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
    crop_img = get_rotate_crop_image(img, np.array(box))
    return crop_img


if __name__ == "__main__":
    pass
@@ -1,24 +0,0 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    tokenizer = TexTeller.get_tokenizer()

    # Don't forget to configure your dataset path in loader.py
    dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']

    new_tokenizer = tokenizer.train_new_from_iterator(
        text_iterator=dataset['latex_formula'],
        # If you want to use a different vocab size, **change VOCAB_SIZE in globals.py**
        vocab_size=VOCAB_SIZE,
    )

    # Save the new tokenizer for later training and inference
    new_tokenizer.save_pretrained('./your_dir_name')
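# --- Illustrative usage sketch (not part of the original file) ---
# Any tokenizer saved with save_pretrained can be reloaded from that
# directory via Hugging Face transformers; './your_dir_name' is the same
# placeholder used above, and the formula string is made up.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained('./your_dir_name')
#     ids = tokenizer(r'\frac{a}{b}')['input_ids']
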
@@ -1 +0,0 @@
from .mix_inference import mix_inference
@@ -1,261 +0,0 @@
import re
import heapq
import cv2
import time
import numpy as np

from collections import Counter
from typing import List
from PIL import Image

from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes

from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all

MAXV = 999999999


def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
    mask_img = img.copy()
    for bbox in bboxes:
        mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
    return mask_img

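# --- Illustrative usage sketch (not part of the original file) ---
# mask_img paints every detected formula region with a flat background
# colour so the downstream text-OCR detector never sees the formulas.
# The box and colour below are assumptions for illustration.
#
#     boxes = [Bbox(10, 20, 30, 100, label="embedding")]
#     white = np.array([255, 255, 255], dtype=np.uint8)
#     clean = mask_img(page_img, boxes, white)   # page_img: HxWx3 uint8 ndarray
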
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
    if len(sorted_bboxes) == 0:
        return []
    bboxes = sorted_bboxes.copy()
    guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
    bboxes.append(guard)
    res = []
    prev = bboxes[0]
    for curr in bboxes:
        if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
            res.append(prev)
            prev = curr
        else:
            prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
    return res

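# --- Illustrative usage sketch (not part of the original file) ---
# bbox_merge scans a sorted list and fuses horizontally overlapping boxes
# that sit on the same row; the trailing guard box only terminates the scan.
# The coordinates below are made up for illustration.
#
#     a = Bbox(0, 0, 20, 50, label="text")     # spans x = 0..50
#     b = Bbox(40, 2, 20, 30, label="text")    # spans x = 40..70, overlaps a
#     c = Bbox(120, 1, 20, 30, label="text")   # disjoint box on the same row
#     merged = bbox_merge(sorted([a, b, c]))
#     # -> two boxes: the first now spans x = 0..70, the second is c unchanged
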
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes

    bboxes = sorted(ocr_bboxes + latex_bboxes)

    # log results
    for idx, bbox in enumerate(bboxes):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), bboxes, name="before_split_conflict.png")

    assert len(bboxes) > 1

    heapq.heapify(bboxes)
    res = []
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0
    while len(bboxes) > 0:
        idx += 1
        assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)

        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)

        elif candidate.ur_point.x >= curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")

            if candidate.label == "text":
                assert curr.label != "text"
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None,
                    ),
                )
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False
    res.append(candidate)
    res.append(curr)

    # log results
    for idx, bbox in enumerate(res):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), res, name="after_split_conflict.png")

    return res

def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
    sliced_imgs = []
    for bbox in ocr_bboxes:
        x, y = int(bbox.p.x), int(bbox.p.y)
        w, h = int(bbox.w), int(bbox.h)
        sliced_img = img[y : y + h, x : x + w]
        sliced_imgs.append(sliced_img)
    return sliced_imgs

def mix_inference(
    img_path: str,
    infer_config,
    latex_det_model,
    lang_ocr_models,
    latex_rec_models,
    accelerator="cpu",
    num_beams=1,
) -> str:
    '''
    Take a mixed image of text and formulas and return a string in Markdown syntax.
    '''
    global img
    img = cv2.imread(img_path)
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.time()
    latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
    end_time = time.time()
    print(f"latex_det_model time: {end_time - start_time:.2f}s")
    latex_bboxes = sorted(latex_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
    latex_bboxes = bbox_merge(latex_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
    masked_img = mask_img(img, latex_bboxes, bg_color)

    det_model, rec_model = lang_ocr_models
    start_time = time.time()
    det_prediction, _ = det_model(masked_img)
    end_time = time.time()
    print(f"ocr_det_model time: {end_time - start_time:.2f}s")
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]
    # log results
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")

    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
    start_time = time.time()
    rec_predictions, _ = rec_model(sliced_imgs)
    end_time = time.time()
    print(f"ocr_rec_model time: {end_time - start_time:.2f}s")

    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content[0]

    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
    start_time = time.time()
    latex_rec_res = latex_rec_predict(
        *latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800
    )
    end_time = time.time()
    print(f"latex_rec_model time: {end_time - start_time:.2f}s")

    for bbox, content in zip(latex_bboxes, latex_rec_res):
        bbox.content = to_katex(content)
        if bbox.label == "embedding":
            bbox.content = " $" + bbox.content + "$ "
        elif bbox.label == "isolated":
            bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'

    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if bboxes == []:
        return ""

    md = ""
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            if curr.content.startswith('(') and curr.content.endswith(')'):
                curr.content = curr.content[1:-1]

            if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
                # in case of multiple tags
                md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
            else:
                md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
            continue

        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # remove bold/italic effects from inline formulas
            curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')

            # change split environment into aligned
            curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
            curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')

            # remove extra spaces (keeping only one)
            curr.content = re.sub(r' +', ' ', curr.content)
            assert curr.content.startswith(' $') and curr.content.endswith('$ ')
            curr.content = ' $' + curr.content[2:-2].strip() + '$ '
        md += curr.content
        prev = curr
    return md.strip()
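# --- Illustrative usage sketch (not part of the original file) ---
# mix_inference ties the pieces together: detect formula regions, mask them
# out, run text OCR on the rest, run the LaTeX recognizer on the formula
# crops, and stitch everything back into Markdown. How the individual models
# are constructed depends on the rest of the repository; the loading calls
# below are placeholders, not the project's actual API.
#
#     latex_det_model, infer_config = load_latex_detector()             # placeholder
#     lang_ocr_models = (load_text_detector(), load_text_recognizer())  # placeholder
#     latex_rec_models = load_latex_recognizer()                        # placeholder
#     md = mix_inference(
#         "page.png",
#         infer_config,
#         latex_det_model,
#         lang_ocr_models,
#         latex_rec_models,
#         accelerator="cuda",
#         num_beams=3,
#     )
#     print(md)   # Markdown with inline $...$ and display $$...$$ formulas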