Wrote the v3 training code (v3 adds natural-scene training augmentation)

This commit is contained in:
三洋三洋
2024-03-28 10:19:40 +00:00
parent fb2ab8230d
commit e8967dce0f
10 changed files with 130 additions and 302 deletions

.gitignore

@@ -6,4 +6,5 @@
**/.cache
**/tmp*
**/data
**/*cache
**/ckpt

requirements.txt

@@ -1,75 +1,13 @@
absl-py==2.0.0
accelerate==0.26.0
aiohttp==3.9.1
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
datasets==2.16.1
dill==0.3.7
filelock==3.13.1
frozenlist==1.4.1
fsspec==2023.10.0
google-auth==2.26.2
google-auth-oauthlib==1.2.0
grpcio==1.60.0
huggingface-hub==0.20.2
idna==3.6
Jinja2==3.1.2
Markdown==3.5.2
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
packaging==23.2
pandas==2.1.4
pillow==10.2.0
protobuf==4.23.4
psutil==5.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
safetensors==0.4.1
six==1.16.0
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tokenizers==0.15.0
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.36.2
triton==2.1.0
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.1.0
Werkzeug==3.0.1
xxhash==3.4.1
yarl==1.9.4
transformers
datasets
evaluate
streamlit
opencv-python
ray[serve]
accelerate
tensorboardX
nltk
python-multipart
pdf2image
augraphy

run.sh

@@ -1,7 +0,0 @@
#!/usr/bin/env bash
# Set the CUDA devices
export CUDA_VISIBLE_DEVICES=0,1,2,4
# Run the Python script and redirect its output to a log file
nohup python -m src.models.resizer.train.train > train_result_pred_height_v3.log 2>&1 &

Binary file not shown (image, 16 KiB).

train.py

@@ -8,14 +8,14 @@ from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrai
from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn, filter_fn
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG)
debug_mode = True
debug_mode = False
if debug_mode:
training_args.auto_find_batch_size = False
training_args.num_train_epochs = 2
@@ -88,16 +88,20 @@ if __name__ == '__main__':
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
# ================= debug =======================
foo = train_dataset[:3]
# foo = train_dataset[:50]
# bar = eval_dataset[:50]
# ================= debug =======================
enable_train = True
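The core v3 change in train.py is splitting the dataset before attaching transforms, so the train split gets the augmentation-heavy transform while the eval split gets the deterministic one. A minimal, self-contained sketch of that with_transform pattern follows; the toy functions are stand-ins for img_train_transform / img_inf_transform, which are not reproduced here.

# Sketch of the v3 split-then-transform pattern (toy stand-ins for
# img_train_transform / img_inf_transform).
from datasets import Dataset

def toy_train_transform(batch):
    # stand-in for img_train_transform: perturb pixel values
    batch["pixel_values"] = [[v + 1 for v in pv] for pv in batch["pixel_values"]]
    return batch

def toy_inf_transform(batch):
    # stand-in for img_inf_transform: deterministic, leaves the batch unchanged
    return batch

ds = Dataset.from_dict({"pixel_values": [[0, 1], [2, 3], [4, 5], [6, 7]]})
split = ds.train_test_split(test_size=0.25, seed=42)
train_ds = split["train"].with_transform(toy_train_transform)  # applied lazily on access
eval_ds = split["test"].with_transform(toy_inf_transform)
print(train_ds[0], eval_ds[0])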

training_args.py

@@ -16,7 +16,7 @@ CONFIG = {
#+ usually kept consistent with eval_steps
"logging_nan_inf_filter": False, # log steps where loss is NaN or inf
"num_train_epochs": 3, # total number of training epochs
"num_train_epochs": 4, # total number of training epochs
# "max_steps": 3, # maximum number of training steps; if set,
#+ num_train_epochs is ignored (usually used for debugging)
@@ -25,7 +25,7 @@ CONFIG = {
"per_device_train_batch_size": 3, # 每个GPU的batch size
"per_device_eval_batch_size": 6, # 每个GPU的evaluation batch size
# "auto_find_batch_size": True, # 自动搜索合适的batch size指数decay
"auto_find_batch_size": True, # 自动搜索合适的batch size指数decay
"auto_find_batch_size": False, # 自动搜索合适的batch size指数decay
"optim": "adamw_torch", # 还提供了很多AdamW的变体相较于经典的AdamW更加高效
#+当设置了optim后就不需要在Trainer中传入optimizer
@@ -41,8 +41,8 @@ CONFIG = {
"gradient_checkpointing": False, # 当为True时会在forward时适当丢弃一些中间量用于backward从而减轻显存压力但会增加forward的时间
"label_smoothing_factor": 0.0, # softlabel等于0时表示未开启
# "debug": "underflow_overflow", # 训练时检查溢出如果发生则会发出警告。该模式通常用于debug
"jit_mode_eval": True, # 是否在eval的时候使用PyTorch jit trace可以加速模型但模型必须是静态的否则会报错
"torch_compile": True, # 是否使用torch.compile来编译模型从而获得更好的训练和推理性能
"jit_mode_eval": False, # 是否在eval的时候使用PyTorch jit trace可以加速模型但模型必须是静态的否则会报错
"torch_compile": False, # 是否使用torch.compile来编译模型从而获得更好的训练和推理性能
#+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来
# "deepspeed": "your_json_path", # 使用deepspeed来训练需要指定ds_config.json的路径
#+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致如学习率batch size梯度累积步数等
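For context, CONFIG is unpacked directly into TrainingArguments. A minimal sketch, assuming a stand-in output_dir (the real config defines its own) and mirroring the v3 values above:

# Sketch: how CONFIG is consumed; output_dir is a placeholder.
from transformers import TrainingArguments

CONFIG = {
    "output_dir": "./ckpt",  # assumption: the real path is set in training_args.py
    "num_train_epochs": 4,
    "per_device_train_batch_size": 3,
    "per_device_eval_batch_size": 6,
    "auto_find_batch_size": False,
    "optim": "adamw_torch",
    "jit_mode_eval": False,
    "torch_compile": False,
}
training_args = TrainingArguments(**CONFIG)
print(training_args.num_train_epochs)  # 4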

functional.py

@@ -1,12 +1,8 @@
import torch
from functools import partial
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform
from ..model.TexTeller import TexTeller
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
@@ -44,41 +40,20 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
return batch
def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = inference_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
)
if __name__ == '__main__':
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train'].select(range(20))
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
tokenized_formula = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names)
tokenized_formula = tokenized_formula.to_dict()
# tokenized_formula['pixel_values'] = dataset['image']
# tokenized_formula = dataset.from_dict(tokenized_formula)
tokenized_dataset = tokenized_formula.with_transform(img_transform_fn)
dataset_dict = tokenized_dataset[:]
dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())]
batch = collate_fn_with_tokenizer(dataset_list)
from ..model.TexTeller import TexTeller
model = TexTeller()
out = model(**batch)
pause = 1
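filter_fn gates samples on both image size and tokenized formula length. A self-contained toy of the size check only; the MIN_HEIGHT / MIN_WIDTH values below are assumptions, the real ones live in the project's globals module:

# Toy version of filter_fn's size check (token-length check omitted).
from PIL import Image

MIN_HEIGHT, MIN_WIDTH = 12, 12  # assumption: real values come from ...globals

def shape_ok(sample):
    # mirrors: sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
    return sample["image"].height > MIN_HEIGHT and sample["image"].width > MIN_WIDTH

print(shape_ok({"image": Image.new("RGB", (64, 32))}))  # True
print(shape_ok({"image": Image.new("RGB", (8, 8))}))    # False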


@@ -7,9 +7,8 @@ def ocr_augmentation_pipeline():
]
ink_phase = [
# 6ms
InkColorSwap(
ink_swap_color="random",
ink_swap_color="lhy_custom",
ink_swap_sequence_number_range=(5, 10),
ink_swap_min_width_range=(2, 3),
ink_swap_max_width_range=(100, 120),
@@ -17,87 +16,78 @@ def ocr_augmentation_pipeline():
ink_swap_max_height_range=(100, 120),
ink_swap_min_area_range=(10, 20),
ink_swap_max_area_range=(400, 500),
p=0.1
p=0.2
),
# 10ms
Dithering(
dither=random.choice(["ordered", "floyd-steinberg"]),
order=(3, 5),
p=0.05
LinesDegradation(
line_roi=(0.0, 0.0, 1.0, 1.0),
line_gradient_range=(32, 255),
line_gradient_direction=(0, 2),
line_split_probability=(0.2, 0.4),
line_replacement_value=(250, 255),
line_min_length=(30, 40),
line_long_to_short_ratio=(5, 7),
line_replacement_probability=(0.4, 0.5),
line_replacement_thickness=(1, 3),
p=0.2
),
# 10ms
InkBleed(
intensity_range=(0.1, 0.2),
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
severity=(0.4, 0.6),
p=0.2,
# ============================
OneOf(
[
Dithering(
dither="floyd-steinberg",
order=(3, 5),
),
InkBleed(
intensity_range=(0.1, 0.2),
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
severity=(0.4, 0.6),
),
],
p=0.2
),
# 40ms
# ============================
# ============================
InkShifter(
text_shift_scale_range=(18, 27),
text_shift_factor_range=(1, 4),
text_fade_range=(0, 2),
blur_kernel_size=(5, 5),
blur_sigma=0,
noise_type="random",
p=0.1
noise_type="perlin",
p=0.2
),
# 90ms
# Letterpress(
# n_samples=(100, 400),
# n_clusters=(200, 400),
# std_range=(500, 3000),
# value_range=(150, 224),
# value_threshold_range=(96, 128),
# blur=1,
# p=0.1
# ),
# ============================
]
paper_phase = [
# 50ms
# OneOf(
# [
# ColorPaper(
# hue_range=(0, 255),
# saturation_range=(10, 40),
# ),
# PatternGenerator(
# imgx=random.randint(256, 512),
# imgy=random.randint(256, 512),
# n_rotation_range=(10, 15),
# color="random",
# alpha_range=(0.25, 0.5),
# ),
# NoiseTexturize(
# sigma_range=(3, 10),
# turbulence_range=(2, 5),
# texture_width_range=(300, 500),
# texture_height_range=(300, 500),
# ),
# ],
# p=0.05
# ),
# 10ms
BrightnessTexturize(
NoiseTexturize( # tested
sigma_range=(3, 10),
turbulence_range=(2, 5),
texture_width_range=(300, 500),
texture_height_range=(300, 500),
p=0.2
),
BrightnessTexturize( # tested
texturize_range=(0.9, 0.99),
deviation=0.03,
p=0.1
p=0.2
)
]
post_phase = [
# 13ms
ColorShift(
ColorShift( # tested
color_shift_offset_x_range=(3, 5),
color_shift_offset_y_range=(3, 5),
color_shift_iterations=(2, 3),
color_shift_brightness_range=(0.9, 1.1),
color_shift_gaussian_kernel_range=(3, 3),
p=0.05
p=0.2
),
# 13ms
DirtyDrum(
DirtyDrum( # tested
line_width_range=(1, 6),
line_concentration=random.uniform(0.05, 0.15),
direction=random.randint(0, 2),
@@ -105,9 +95,10 @@ def ocr_augmentation_pipeline():
noise_value=(64, 224),
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
sigmaX=0,
p=0.05,
p=0.2
),
# 10ms
# =====================================
OneOf(
[
LightingGradient(
@@ -128,121 +119,23 @@ def ocr_augmentation_pipeline():
gamma_range=(0.9, 1.1),
),
],
p=0.05
p=0.2
),
# 6ms
Jpeg(
quality_range=(25, 95),
p=0.1
),
# 12ms
Markup(
num_lines_range=(2, 7),
markup_length_range=(0.5, 1),
markup_thickness_range=(1, 2),
markup_type=random.choice(["strikethrough", "crossed", "highlight", "underline"]),
markup_color="random",
single_word_mode=False,
repetitions=1,
p=0.05
),
# 65ms
# OneOf(
# [
# BadPhotoCopy(
# noise_mask=None,
# noise_type=-1,
# noise_side="random",
# noise_iteration=(1, 2),
# noise_size=(1, 3),
# noise_value=(128, 196),
# noise_sparsity=(0.3, 0.6),
# noise_concentration=(0.1, 0.6),
# blur_noise=random.choice([True, False]),
# blur_noise_kernel=random.choice([(3, 3), (5, 5), (7, 7)]),
# wave_pattern=random.choice([True, False]),
# edge_effect=random.choice([True, False]),
# ),
# ShadowCast(
# shadow_side="random",
# shadow_vertices_range=(1, 20),
# shadow_width_range=(0.3, 0.8),
# shadow_height_range=(0.3, 0.8),
# shadow_color=(0, 0, 0),
# shadow_opacity_range=(0.2, 0.9),
# shadow_iterations_range=(1, 2),
# shadow_blur_kernel_range=(101, 301),
# ),
# LowLightNoise(
# num_photons_range=(50, 100),
# alpha_range=(0.7, 1.0),
# beta_range=(10, 30),
# gamma_range=(1, 1.8),
# bias_range=(20, 40),
# dark_current_value=1.0,
# exposure_time=0.2,
# gain=0.1,
# ),
# ],
# p=0.05,
# ),
# 10ms
# =====================================
# =====================================
OneOf(
[
NoisyLines(
noisy_lines_direction="random",
noisy_lines_location="random",
noisy_lines_number_range=(5, 20),
noisy_lines_color=(0, 0, 0),
noisy_lines_thickness_range=(1, 2),
noisy_lines_random_noise_intensity_range=(0.01, 0.1),
noisy_lines_length_interval_range=(0, 100),
noisy_lines_gaussian_kernel_value_range=(3, 5),
noisy_lines_overlay_method="ink_to_paper",
SubtleNoise(
subtle_range=random.randint(5, 10),
),
BindingsAndFasteners(
overlay_types="darken",
foreground=None,
effect_type="random",
width_range="random",
height_range="random",
angle_range=(-30, 30),
ntimes=(2, 6),
nscales=(0.9, 1.0),
edge="random",
edge_offset=(10, 50),
use_figshare_library=0,
Jpeg(
quality_range=(85, 95),
),
],
p=0.05,
),
# 20ms
OneOf(
[
PageBorder(
page_border_width_height="random",
page_border_color=(0, 0, 0),
page_border_background_color=(0, 0, 0),
page_numbers="random",
page_rotation_angle_range=(-3, 3),
curve_frequency=(2, 8),
curve_height=(2, 4),
curve_length_one_side=(50, 100),
same_page_border=random.choice([0, 1]),
),
Folding(
fold_x=None,
fold_deviation=(0, 0),
fold_count=random.randint(2, 8),
fold_noise=0.01,
fold_angle_range=(-360, 360),
gradient_width=(0.1, 0.2),
gradient_height=(0.01, 0.02),
backdrop_color=(0, 0, 0),
),
],
p=0.05
p=0.2
),
# =====================================
]
pipeline = AugraphyPipeline(
@@ -250,7 +143,7 @@ def ocr_augmentation_pipeline():
paper_phase=paper_phase,
post_phase=post_phase,
pre_phase=pre_phase,
log=False,
log=False
)
return pipeline
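Augraphy pipelines are callable on an HxWx3 uint8 array and return the augmented image, so the result can be sanity-checked on a synthetic page. A sketch, assuming ocr_augmentation_pipeline above is importable (note the custom ink_swap_color value may require the author's patched augraphy):

# Sketch: exercising the pipeline on a synthetic blank page.
import numpy as np

pipeline = ocr_augmentation_pipeline()
page = np.full((256, 512, 3), 255, dtype=np.uint8)  # blank white page
augmented = pipeline(page)                          # pipelines are callable
print(type(augmented), getattr(augmented, "shape", None))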

transforms.py

@@ -4,7 +4,7 @@ import numpy as np
import cv2
from torchvision.transforms import v2
from typing import List, Union
from typing import List
from PIL import Image
from ...globals import (
@@ -77,7 +77,6 @@ def trim_white_border(image: np.ndarray):
return trimmed_image
# BUGGY CODE
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
@@ -131,9 +130,38 @@ def random_resize(
]
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
# Get the center of the image to define the point of rotation
image_center = tuple(np.array(image.shape[1::-1]) / 2)
# Generate a random angle within the specified range
angle = random.randint(min_angle, max_angle)
# Get the rotation matrix for rotating the image around its center
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
# Determine the size of the rotated image
cos = np.abs(rotation_mat[0, 0])
sin = np.abs(rotation_mat[0, 1])
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
# Adjust the rotation matrix to take into account translation
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
# Rotate the image with the specified border color (white in this case)
rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255))
return rotated_image
def ocr_aug(image: np.ndarray) -> np.ndarray:
# random rotation with 20% probability
if random.random() < 0.2:
image = rotate(image, -5, 5)
# add a white border
image = add_white_border(image, max_size=35).permute(1, 2, 0).numpy()
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
# data augmentation
image = train_pipeline(image)
return image
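rotate() expands the output canvas using the rotation matrix's cos/sin terms so no content is clipped, filling the new border with white. A quick check, assuming the function as defined above:

# Quick check: the output canvas grows to fit the rotated bounds.
import numpy as np

img = np.full((100, 200, 3), 255, dtype=np.uint8)  # white 200x100 canvas
out = rotate(img, -5, 5)                            # random angle in [-5, 5] degrees
print(img.shape, "->", out.shape)                   # canvas grows unless the angle is 0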
@@ -149,10 +177,7 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
# trim white borders
images = [trim_white_border(image) for image in images]
# add a white border
# images = [add_white_border(image, max_size=35) for image in images]
# data augmentation
# images = [train_pipeline(image.permute(1, 2, 0).numpy()) for image in images]
# OCR augmentation
images = [ocr_aug(image) for image in images]
# general transform pipeline
@@ -165,10 +190,11 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
assert OCR_IMG_CHANNELS == 1, "Only grayscale images are supported for now"
assert OCR_FIX_SIZE == True, "Only fixed-size images are supported for now"
images = [np.array(img.convert('RGB')) for img in images]
# trim white borders
images = [trim_white_border(image) for image in images]
# general transform pipeline
images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image]
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size
images = padding(images, OCR_IMG_SIZE)
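The switch from one batched call to a per-image list comprehension reflects that the trimmed images have different sizes and cannot be stacked; applying the pipeline per sample keeps them independent until padding. A hedged sketch with a stand-in v2 pipeline:

# Sketch: applying a torchvision v2 pipeline image-by-image (sizes differ).
import torch
from PIL import Image
from torchvision.transforms import v2

toy_pipeline = v2.Compose([v2.PILToTensor(), v2.ToDtype(torch.float32, scale=True)])
imgs = [Image.new("L", (64, 32)), Image.new("L", (48, 48))]
tensors = [toy_pipeline(img) for img in imgs]  # keep as a list; shapes differ
print([tuple(t.shape) for t in tensors])       # [(1, 32, 64), (1, 48, 48)]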
@@ -190,8 +216,6 @@ if __name__ == '__main__':
]
imgs_path = [str(img_path) for img_path in imgs_path]
imgs = convert2rgb(imgs_path)
# res = train_transform(imgs)
# res = [v2.functional.to_pil_image(img) for img in res]
res = random_resize(imgs, 0.5, 1.5)
pause = 1