写好了v3版本的训练代码(v3版本加入了自然场景训练增强)

This commit is contained in:
三洋三洋
2024-03-28 10:19:40 +00:00
parent fb2ab8230d
commit e8967dce0f
10 changed files with 130 additions and 302 deletions

1
.gitignore vendored
View File

@@ -6,4 +6,5 @@
**/.cache **/.cache
**/tmp* **/tmp*
**/data **/data
**/*cache
**/ckpt **/ckpt

View File

View File

@@ -1,75 +1,13 @@
absl-py==2.0.0 transformers
accelerate==0.26.0 datasets
aiohttp==3.9.1 evaluate
aiosignal==1.3.1 streamlit
async-timeout==4.0.3 opencv-python
attrs==23.2.0 ray[serve]
cachetools==5.3.2 accelerate
certifi==2023.11.17 tensorboardX
charset-normalizer==3.3.2 nltk
datasets==2.16.1 python-multipart
dill==0.3.7
filelock==3.13.1 pdf2image
frozenlist==1.4.1 augraphy
fsspec==2023.10.0
google-auth==2.26.2
google-auth-oauthlib==1.2.0
grpcio==1.60.0
huggingface-hub==0.20.2
idna==3.6
Jinja2==3.1.2
Markdown==3.5.2
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
packaging==23.2
pandas==2.1.4
pillow==10.2.0
protobuf==4.23.4
psutil==5.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
safetensors==0.4.1
six==1.16.0
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tokenizers==0.15.0
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.36.2
triton==2.1.0
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.1.0
Werkzeug==3.0.1
xxhash==3.4.1
yarl==1.9.4

7
run.sh
View File

@@ -1,7 +0,0 @@
#!/usr/bin/env bash
# 设置 CUDA 设备
export CUDA_VISIBLE_DEVICES=0,1,2,4
# 运行 Python 脚本并将输出重定向到日志文件
nohup python -m src.models.resizer.train.train > train_result_pred_height_v3.log 2>&1 &

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

View File

@@ -8,14 +8,14 @@ from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrai
from .training_args import CONFIG from .training_args import CONFIG
from ..model.TexTeller import TexTeller from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn, filter_fn from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
from ..utils.metrics import bleu_metric from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG) training_args = TrainingArguments(**CONFIG)
debug_mode = True debug_mode = False
if debug_mode: if debug_mode:
training_args.auto_find_batch_size = False training_args.auto_find_batch_size = False
training_args.num_train_epochs = 2 training_args.num_train_epochs = 2
@@ -88,16 +88,20 @@ if __name__ == '__main__':
map_fn = partial(tokenize_fn, tokenizer=tokenizer) map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True) tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42) split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller() # model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt') model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
# ================= debug ======================= # ================= debug =======================
foo = train_dataset[:3] # foo = train_dataset[:50]
# bar = eval_dataset[:50]
# ================= debug ======================= # ================= debug =======================
enable_train = True enable_train = True

View File

@@ -16,7 +16,7 @@ CONFIG = {
#+通常与eval_steps一致 #+通常与eval_steps一致
"logging_nan_inf_filter": False, # 对loss=nan或inf进行记录 "logging_nan_inf_filter": False, # 对loss=nan或inf进行记录
"num_train_epochs": 3, # 总的训练轮数 "num_train_epochs": 4, # 总的训练轮数
# "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数, # "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数,
#+那么num_train_epochs将被忽略通常用于调试 #+那么num_train_epochs将被忽略通常用于调试
@@ -25,7 +25,7 @@ CONFIG = {
"per_device_train_batch_size": 3, # 每个GPU的batch size "per_device_train_batch_size": 3, # 每个GPU的batch size
"per_device_eval_batch_size": 6, # 每个GPU的evaluation batch size "per_device_eval_batch_size": 6, # 每个GPU的evaluation batch size
# "auto_find_batch_size": True, # 自动搜索合适的batch size指数decay # "auto_find_batch_size": True, # 自动搜索合适的batch size指数decay
"auto_find_batch_size": True, # 自动搜索合适的batch size指数decay "auto_find_batch_size": False, # 自动搜索合适的batch size指数decay
"optim": "adamw_torch", # 还提供了很多AdamW的变体相较于经典的AdamW更加高效 "optim": "adamw_torch", # 还提供了很多AdamW的变体相较于经典的AdamW更加高效
#+当设置了optim后就不需要在Trainer中传入optimizer #+当设置了optim后就不需要在Trainer中传入optimizer
@@ -41,8 +41,8 @@ CONFIG = {
"gradient_checkpointing": False, # 当为True时会在forward时适当丢弃一些中间量用于backward从而减轻显存压力但会增加forward的时间 "gradient_checkpointing": False, # 当为True时会在forward时适当丢弃一些中间量用于backward从而减轻显存压力但会增加forward的时间
"label_smoothing_factor": 0.0, # softlabel等于0时表示未开启 "label_smoothing_factor": 0.0, # softlabel等于0时表示未开启
# "debug": "underflow_overflow", # 训练时检查溢出如果发生则会发出警告。该模式通常用于debug # "debug": "underflow_overflow", # 训练时检查溢出如果发生则会发出警告。该模式通常用于debug
"jit_mode_eval": True, # 是否在eval的时候使用PyTorch jit trace可以加速模型但模型必须是静态的否则会报错 "jit_mode_eval": False, # 是否在eval的时候使用PyTorch jit trace可以加速模型但模型必须是静态的否则会报错
"torch_compile": True, # 是否使用torch.compile来编译模型从而获得更好的训练和推理性能 "torch_compile": False, # 是否使用torch.compile来编译模型从而获得更好的训练和推理性能
#+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来 #+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来
# "deepspeed": "your_json_path", # 使用deepspeed来训练需要指定ds_config.json的路径 # "deepspeed": "your_json_path", # 使用deepspeed来训练需要指定ds_config.json的路径
#+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致如学习率batch size梯度累积步数等 #+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致如学习率batch size梯度累积步数等

View File

@@ -1,12 +1,8 @@
import torch import torch
from functools import partial
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any from typing import List, Dict, Any
from .transforms import train_transform from .transforms import train_transform, inference_transform
from ..model.TexTeller import TexTeller
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
@@ -44,41 +40,20 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
return batch return batch
def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values']) processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img samples['pixel_values'] = processed_img
return samples return samples
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Apply the inference-time image pipeline to a batch of samples.

    Evaluation counterpart of ``img_train_transform``: the images stored under
    ``'pixel_values'`` are passed through ``inference_transform`` instead of the
    training transform, and the result is written back under the same key.

    Args:
        samples: Column-oriented batch; must contain a ``'pixel_values'`` entry.

    Returns:
        The same ``samples`` dict (mutated in place) with ``'pixel_values'``
        replaced by the transformed images.
    """
    samples['pixel_values'] = inference_transform(samples['pixel_values'])
    return samples
def filter_fn(sample, tokenizer=None) -> bool: def filter_fn(sample, tokenizer=None) -> bool:
return ( return (
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10 and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
) )
if __name__ == '__main__':
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train'].select(range(20))
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
tokenized_formula = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names)
tokenized_formula = tokenized_formula.to_dict()
# tokenized_formula['pixel_values'] = dataset['image']
# tokenized_formula = dataset.from_dict(tokenized_formula)
tokenized_dataset = tokenized_formula.with_transform(img_transform_fn)
dataset_dict = tokenized_dataset[:]
dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())]
batch = collate_fn_with_tokenizer(dataset_list)
from ..model.TexTeller import TexTeller
model = TexTeller()
out = model(**batch)
pause = 1

View File

@@ -7,9 +7,8 @@ def ocr_augmentation_pipeline():
] ]
ink_phase = [ ink_phase = [
# 6ms
InkColorSwap( InkColorSwap(
ink_swap_color="random", ink_swap_color="lhy_custom",
ink_swap_sequence_number_range=(5, 10), ink_swap_sequence_number_range=(5, 10),
ink_swap_min_width_range=(2, 3), ink_swap_min_width_range=(2, 3),
ink_swap_max_width_range=(100, 120), ink_swap_max_width_range=(100, 120),
@@ -17,87 +16,78 @@ def ocr_augmentation_pipeline():
ink_swap_max_height_range=(100, 120), ink_swap_max_height_range=(100, 120),
ink_swap_min_area_range=(10, 20), ink_swap_min_area_range=(10, 20),
ink_swap_max_area_range=(400, 500), ink_swap_max_area_range=(400, 500),
p=0.1 p=0.2
), ),
# 10ms LinesDegradation(
Dithering( line_roi=(0.0, 0.0, 1.0, 1.0),
dither=random.choice(["ordered", "floyd-steinberg"]), line_gradient_range=(32, 255),
order=(3, 5), line_gradient_direction=(0, 2),
p=0.05 line_split_probability=(0.2, 0.4),
line_replacement_value=(250, 255),
line_min_length=(30, 40),
line_long_to_short_ratio=(5, 7),
line_replacement_probability=(0.4, 0.5),
line_replacement_thickness=(1, 3),
p=0.2
), ),
# 10ms
InkBleed( # ============================
intensity_range=(0.1, 0.2), OneOf(
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]), [
severity=(0.4, 0.6), Dithering(
p=0.2, dither="floyd-steinberg",
order=(3, 5),
),
InkBleed(
intensity_range=(0.1, 0.2),
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
severity=(0.4, 0.6),
),
],
p=0.2
), ),
# 40ms # ============================
# ============================
InkShifter( InkShifter(
text_shift_scale_range=(18, 27), text_shift_scale_range=(18, 27),
text_shift_factor_range=(1, 4), text_shift_factor_range=(1, 4),
text_fade_range=(0, 2), text_fade_range=(0, 2),
blur_kernel_size=(5, 5), blur_kernel_size=(5, 5),
blur_sigma=0, blur_sigma=0,
noise_type="random", noise_type="perlin",
p=0.1 p=0.2
), ),
# 90ms # ============================
# Letterpress(
# n_samples=(100, 400),
# n_clusters=(200, 400),
# std_range=(500, 3000),
# value_range=(150, 224),
# value_threshold_range=(96, 128),
# blur=1,
# p=0.1
# ),
] ]
paper_phase = [ paper_phase = [
# 50ms NoiseTexturize( # tested
# OneOf( sigma_range=(3, 10),
# [ turbulence_range=(2, 5),
# ColorPaper( texture_width_range=(300, 500),
# hue_range=(0, 255), texture_height_range=(300, 500),
# saturation_range=(10, 40), p=0.2
# ), ),
# PatternGenerator( BrightnessTexturize( # tested
# imgx=random.randint(256, 512),
# imgy=random.randint(256, 512),
# n_rotation_range=(10, 15),
# color="random",
# alpha_range=(0.25, 0.5),
# ),
# NoiseTexturize(
# sigma_range=(3, 10),
# turbulence_range=(2, 5),
# texture_width_range=(300, 500),
# texture_height_range=(300, 500),
# ),
# ],
# p=0.05
# ),
# 10ms
BrightnessTexturize(
texturize_range=(0.9, 0.99), texturize_range=(0.9, 0.99),
deviation=0.03, deviation=0.03,
p=0.1 p=0.2
) )
] ]
post_phase = [ post_phase = [
# 13ms ColorShift( # tested
ColorShift(
color_shift_offset_x_range=(3, 5), color_shift_offset_x_range=(3, 5),
color_shift_offset_y_range=(3, 5), color_shift_offset_y_range=(3, 5),
color_shift_iterations=(2, 3), color_shift_iterations=(2, 3),
color_shift_brightness_range=(0.9, 1.1), color_shift_brightness_range=(0.9, 1.1),
color_shift_gaussian_kernel_range=(3, 3), color_shift_gaussian_kernel_range=(3, 3),
p=0.05 p=0.2
), ),
# 13ms
DirtyDrum( DirtyDrum( # tested
line_width_range=(1, 6), line_width_range=(1, 6),
line_concentration=random.uniform(0.05, 0.15), line_concentration=random.uniform(0.05, 0.15),
direction=random.randint(0, 2), direction=random.randint(0, 2),
@@ -105,9 +95,10 @@ def ocr_augmentation_pipeline():
noise_value=(64, 224), noise_value=(64, 224),
ksize=random.choice([(3, 3), (5, 5), (7, 7)]), ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
sigmaX=0, sigmaX=0,
p=0.05, p=0.2
), ),
# 10ms
# =====================================
OneOf( OneOf(
[ [
LightingGradient( LightingGradient(
@@ -128,121 +119,23 @@ def ocr_augmentation_pipeline():
gamma_range=(0.9, 1.1), gamma_range=(0.9, 1.1),
), ),
], ],
p=0.05 p=0.2
), ),
# 6ms # =====================================
Jpeg(
quality_range=(25, 95), # =====================================
p=0.1
),
# 12ms
Markup(
num_lines_range=(2, 7),
markup_length_range=(0.5, 1),
markup_thickness_range=(1, 2),
markup_type=random.choice(["strikethrough", "crossed", "highlight", "underline"]),
markup_color="random",
single_word_mode=False,
repetitions=1,
p=0.05
),
# 65ms
# OneOf(
# [
# BadPhotoCopy(
# noise_mask=None,
# noise_type=-1,
# noise_side="random",
# noise_iteration=(1, 2),
# noise_size=(1, 3),
# noise_value=(128, 196),
# noise_sparsity=(0.3, 0.6),
# noise_concentration=(0.1, 0.6),
# blur_noise=random.choice([True, False]),
# blur_noise_kernel=random.choice([(3, 3), (5, 5), (7, 7)]),
# wave_pattern=random.choice([True, False]),
# edge_effect=random.choice([True, False]),
# ),
# ShadowCast(
# shadow_side="random",
# shadow_vertices_range=(1, 20),
# shadow_width_range=(0.3, 0.8),
# shadow_height_range=(0.3, 0.8),
# shadow_color=(0, 0, 0),
# shadow_opacity_range=(0.2, 0.9),
# shadow_iterations_range=(1, 2),
# shadow_blur_kernel_range=(101, 301),
# ),
# LowLightNoise(
# num_photons_range=(50, 100),
# alpha_range=(0.7, 1.0),
# beta_range=(10, 30),
# gamma_range=(1, 1.8),
# bias_range=(20, 40),
# dark_current_value=1.0,
# exposure_time=0.2,
# gain=0.1,
# ),
# ],
# p=0.05,
# ),
# 10ms
OneOf( OneOf(
[ [
NoisyLines( SubtleNoise(
noisy_lines_direction="random", subtle_range=random.randint(5, 10),
noisy_lines_location="random",
noisy_lines_number_range=(5, 20),
noisy_lines_color=(0, 0, 0),
noisy_lines_thickness_range=(1, 2),
noisy_lines_random_noise_intensity_range=(0.01, 0.1),
noisy_lines_length_interval_range=(0, 100),
noisy_lines_gaussian_kernel_value_range=(3, 5),
noisy_lines_overlay_method="ink_to_paper",
), ),
BindingsAndFasteners( Jpeg(
overlay_types="darken", quality_range=(85, 95),
foreground=None,
effect_type="random",
width_range="random",
height_range="random",
angle_range=(-30, 30),
ntimes=(2, 6),
nscales=(0.9, 1.0),
edge="random",
edge_offset=(10, 50),
use_figshare_library=0,
), ),
], ],
p=0.05, p=0.2
),
# 20ms
OneOf(
[
PageBorder(
page_border_width_height="random",
page_border_color=(0, 0, 0),
page_border_background_color=(0, 0, 0),
page_numbers="random",
page_rotation_angle_range=(-3, 3),
curve_frequency=(2, 8),
curve_height=(2, 4),
curve_length_one_side=(50, 100),
same_page_border=random.choice([0, 1]),
),
Folding(
fold_x=None,
fold_deviation=(0, 0),
fold_count=random.randint(2, 8),
fold_noise=0.01,
fold_angle_range=(-360, 360),
gradient_width=(0.1, 0.2),
gradient_height=(0.01, 0.02),
backdrop_color=(0, 0, 0),
),
],
p=0.05
), ),
# =====================================
] ]
pipeline = AugraphyPipeline( pipeline = AugraphyPipeline(
@@ -250,7 +143,7 @@ def ocr_augmentation_pipeline():
paper_phase=paper_phase, paper_phase=paper_phase,
post_phase=post_phase, post_phase=post_phase,
pre_phase=pre_phase, pre_phase=pre_phase,
log=False, log=False
) )
return pipeline return pipeline

View File

@@ -4,7 +4,7 @@ import numpy as np
import cv2 import cv2
from torchvision.transforms import v2 from torchvision.transforms import v2
from typing import List, Union from typing import List
from PIL import Image from PIL import Image
from ...globals import ( from ...globals import (
@@ -77,7 +77,6 @@ def trim_white_border(image: np.ndarray):
return trimmed_image return trimmed_image
# BUGY CODE
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray: def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)] randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3] pad_height_size = randi[1] + randi[3]
@@ -131,9 +130,38 @@ def random_resize(
] ]
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    """Rotate ``image`` about its centre by a random whole-degree angle.

    The angle is drawn with ``random.randint(min_angle, max_angle)``, so both
    bounds are inclusive. The output canvas is enlarged to the bounding box of
    the rotated image so no content is cut off, and the newly exposed area is
    filled with white.

    Args:
        image: Input image; ``shape[0]`` is used as height, ``shape[1]`` as width.
        min_angle: Lower bound of the rotation angle in degrees.
        max_angle: Upper bound of the rotation angle in degrees.

    Returns:
        The rotated image on the enlarged canvas.
    """
    height, width = image.shape[0], image.shape[1]
    # Rotation pivot as (x, y), i.e. (width / 2, height / 2).
    center = (width / 2, height / 2)

    angle = random.randint(min_angle, max_angle)
    matrix = cv2.getRotationMatrix2D(center, angle, 1.0)

    # With scale 1.0 the first row of the matrix holds cos and sin of the angle;
    # their magnitudes give the bounding box of the rotated rectangle.
    abs_cos = np.abs(matrix[0, 0])
    abs_sin = np.abs(matrix[0, 1])
    out_width = int((height * abs_sin) + (width * abs_cos))
    out_height = int((height * abs_cos) + (width * abs_sin))

    # Shift the transform so the old centre lands on the new canvas centre.
    matrix[0, 2] += (out_width / 2) - center[0]
    matrix[1, 2] += (out_height / 2) - center[1]

    # Pixels not covered by the source image are painted white.
    return cv2.warpAffine(
        image, matrix, (out_width, out_height), borderValue=(255, 255, 255)
    )
def ocr_aug(image: np.ndarray) -> np.ndarray: def ocr_aug(image: np.ndarray) -> np.ndarray:
# 20%的概率进行随机旋转
if random.random() < 0.2:
image = rotate(image, -5, 5)
# 增加白边 # 增加白边
image = add_white_border(image, max_size=35).permute(1, 2, 0).numpy() image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
# 数据增强 # 数据增强
image = train_pipeline(image) image = train_pipeline(image)
return image return image
@@ -149,10 +177,7 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
# 裁剪掉白边 # 裁剪掉白边
images = [trim_white_border(image) for image in images] images = [trim_white_border(image) for image in images]
# 增加白边 # OCR augmentation
# images = [add_white_border(image, max_size=35) for image in images]
# 数据增强
# images = [train_pipeline(image.permute(1, 2, 0).numpy()) for image in images]
images = [ocr_aug(image) for image in images] images = [ocr_aug(image) for image in images]
# general transform pipeline # general transform pipeline
@@ -165,10 +190,11 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now" assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
assert OCR_FIX_SIZE == True, "Only support fixed size images for now" assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
images = [np.array(img.convert('RGB')) for img in images]
# 裁剪掉白边 # 裁剪掉白边
images = [trim_white_border(image) for image in images] images = [trim_white_border(image) for image in images]
# general transform pipeline # general transform pipeline
images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image] images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size # padding to fixed size
images = padding(images, OCR_IMG_SIZE) images = padding(images, OCR_IMG_SIZE)
@@ -190,8 +216,6 @@ if __name__ == '__main__':
] ]
imgs_path = [str(img_path) for img_path in imgs_path] imgs_path = [str(img_path) for img_path in imgs_path]
imgs = convert2rgb(imgs_path) imgs = convert2rgb(imgs_path)
# res = train_transform(imgs)
# res = [v2.functional.to_pil_image(img) for img in res]
res = random_resize(imgs, 0.5, 1.5) res = random_resize(imgs, 0.5, 1.5)
pause = 1 pause = 1