diff --git a/.gitignore b/.gitignore index 789f903..f44176c 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ **/.cache **/tmp* **/data +**/*cache **/ckpt \ No newline at end of file diff --git a/README.md b/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/requirements.txt b/requirements.txt index 791fa24..84041b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,75 +1,13 @@ -absl-py==2.0.0 -accelerate==0.26.0 -aiohttp==3.9.1 -aiosignal==1.3.1 -async-timeout==4.0.3 -attrs==23.2.0 -cachetools==5.3.2 -certifi==2023.11.17 -charset-normalizer==3.3.2 -datasets==2.16.1 -dill==0.3.7 -filelock==3.13.1 -frozenlist==1.4.1 -fsspec==2023.10.0 -google-auth==2.26.2 -google-auth-oauthlib==1.2.0 -grpcio==1.60.0 -huggingface-hub==0.20.2 -idna==3.6 -Jinja2==3.1.2 -Markdown==3.5.2 -MarkupSafe==2.1.3 -mpmath==1.3.0 -multidict==6.0.4 -multiprocess==0.70.15 -networkx==3.2.1 -numpy==1.26.3 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.18.1 -nvidia-nvjitlink-cu12==12.3.101 -nvidia-nvtx-cu12==12.1.105 -oauthlib==3.2.2 -packaging==23.2 -pandas==2.1.4 -pillow==10.2.0 -protobuf==4.23.4 -psutil==5.9.7 -pyarrow==14.0.2 -pyarrow-hotfix==0.6 -pyasn1==0.5.1 -pyasn1-modules==0.3.0 -python-dateutil==2.8.2 -pytz==2023.3.post1 -PyYAML==6.0.1 -regex==2023.12.25 -requests==2.31.0 -requests-oauthlib==1.3.1 -rsa==4.9 -safetensors==0.4.1 -six==1.16.0 -sympy==1.12 -tensorboard==2.15.1 -tensorboard-data-server==0.7.2 -tensorboardX==2.6.2.2 -tokenizers==0.15.0 -torch==2.1.2 -torchaudio==2.1.2 -torchvision==0.16.2 -tqdm==4.66.1 -transformers==4.36.2 -triton==2.1.0 -typing_extensions==4.9.0 -tzdata==2023.4 -urllib3==2.1.0 -Werkzeug==3.0.1 -xxhash==3.4.1 -yarl==1.9.4 +transformers +datasets +evaluate +streamlit +opencv-python +ray[serve] +accelerate +tensorboardX +nltk +python-multipart + +pdf2image +augraphy \ No newline at end of file diff --git a/run.sh b/run.sh deleted file mode 100755 index 64aa4a5..0000000 --- a/run.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -# 设置 CUDA 设备 -export CUDA_VISIBLE_DEVICES=0,1,2,4 - -# 运行 Python 脚本并将输出重定向到日志文件 -nohup python -m src.models.resizer.train.train > train_result_pred_height_v3.log 2>&1 & diff --git a/src/models/ocr_model/train/foo.png b/src/models/ocr_model/train/foo.png deleted file mode 100644 index 61cf525..0000000 Binary files a/src/models/ocr_model/train/foo.png and /dev/null differ diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 3693705..49cdbd2 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -8,14 +8,14 @@ from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrai from .training_args import CONFIG from ..model.TexTeller import TexTeller -from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn, filter_fn +from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn from ..utils.metrics import bleu_metric from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): training_args = TrainingArguments(**CONFIG) - debug_mode = True + debug_mode = False if debug_mode: training_args.auto_find_batch_size = False 
        training_args.num_train_epochs = 2
@@ -88,16 +88,20 @@ if __name__ == '__main__':
     map_fn = partial(tokenize_fn, tokenizer=tokenizer)
     tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
-    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
 
     split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
+
+    train_dataset = train_dataset.with_transform(img_train_transform)
+    eval_dataset = eval_dataset.with_transform(img_inf_transform)
+
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
 
     # model = TexTeller()
     model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
 
     # ================= debug =======================
-    foo = train_dataset[:3]
+    # foo = train_dataset[:50]
+    # bar = eval_dataset[:50]
     # ================= debug =======================
 
     enable_train = True
diff --git a/src/models/ocr_model/train/training_args.py b/src/models/ocr_model/train/training_args.py
index 042ad0b..24fe328 100644
--- a/src/models/ocr_model/train/training_args.py
+++ b/src/models/ocr_model/train/training_args.py
@@ -16,7 +16,7 @@ CONFIG = {
     #+ usually kept consistent with eval_steps
 
     "logging_nan_inf_filter": False,   # still log steps whose loss is nan or inf
-    "num_train_epochs": 3,             # total number of training epochs
+    "num_train_epochs": 4,             # total number of training epochs
     # "max_steps": 3,                  # maximum number of training steps; if this is set,
     #+ num_train_epochs will be ignored (usually used for debugging)
 
@@ -25,7 +25,7 @@ CONFIG = {
     "per_device_train_batch_size": 3,  # batch size per GPU
     "per_device_eval_batch_size": 6,   # evaluation batch size per GPU
     # "auto_find_batch_size": True,    # automatically search for a workable batch size (exponential decay)
-    "auto_find_batch_size": True,      # automatically search for a workable batch size (exponential decay)
+    "auto_find_batch_size": False,     # automatically search for a workable batch size (exponential decay)
 
     "optim": "adamw_torch",            # many AdamW variants are also available (more efficient than the classic AdamW)
     #+ once optim is set, no optimizer needs to be passed to the Trainer
@@ -41,8 +41,8 @@ CONFIG = {
     "gradient_checkpointing": False,   # when True, some intermediates needed for backward are dropped during forward to relieve GPU memory pressure (at the cost of a slower forward pass)
     "label_smoothing_factor": 0.0,     # soft labels; 0 means disabled
     # "debug": "underflow_overflow",   # check for overflow during training and warn if it occurs (usually used for debugging)
-    "jit_mode_eval": True,             # whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static or it will error)
-    "torch_compile": True,             # whether to compile the model with torch.compile (for better training and inference performance)
+    "jit_mode_eval": False,            # whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static or it will error)
+    "torch_compile": False,            # whether to compile the model with torch.compile (for better training and inference performance)
     #+ requires torch >= 2.0; works well, enable it once the model runs end to end
     # "deepspeed": "your_json_path",   # train with DeepSpeed; requires the path to a ds_config.json
     #+ when using DeepSpeed with the Trainer, make sure ds_config.json matches the Trainer settings (learning rate, batch size, gradient accumulation steps, etc.)
diff --git a/src/models/ocr_model/utils/functional.py b/src/models/ocr_model/utils/functional.py
index a2710aa..9cb19ab 100644
--- a/src/models/ocr_model/utils/functional.py
+++ b/src/models/ocr_model/utils/functional.py
@@ -1,12 +1,8 @@
 import torch
-from functools import partial
-from datasets import load_dataset
-
 from transformers import DataCollatorForLanguageModeling
 from typing import List, Dict, Any
 
-from .transforms import train_transform
-from ..model.TexTeller import TexTeller
+from .transforms import train_transform, inference_transform
 from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
 
 
@@ -44,41 +40,20 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
     return batch
 
 
-def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values']) samples['pixel_values'] = processed_img return samples +def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + processed_img = inference_transform(samples['pixel_values']) + samples['pixel_values'] = processed_img + return samples + + def filter_fn(sample, tokenizer=None) -> bool: return ( sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10 ) - - -if __name__ == '__main__': - dataset = load_dataset( - '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', - 'cleaned_formulas' - )['train'].select(range(20)) - tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') - - map_fn = partial(tokenize_fn, tokenizer=tokenizer) - collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) - - tokenized_formula = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names) - tokenized_formula = tokenized_formula.to_dict() - # tokenized_formula['pixel_values'] = dataset['image'] - # tokenized_formula = dataset.from_dict(tokenized_formula) - tokenized_dataset = tokenized_formula.with_transform(img_transform_fn) - - dataset_dict = tokenized_dataset[:] - dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())] - batch = collate_fn_with_tokenizer(dataset_list) - - from ..model.TexTeller import TexTeller - model = TexTeller() - out = model(**batch) - - pause = 1 diff --git a/src/models/ocr_model/utils/ocr_aug.py b/src/models/ocr_model/utils/ocr_aug.py index 78bdd48..0e364ca 100644 --- a/src/models/ocr_model/utils/ocr_aug.py +++ b/src/models/ocr_model/utils/ocr_aug.py @@ -7,9 +7,8 @@ def ocr_augmentation_pipeline(): ] ink_phase = [ - # 6ms InkColorSwap( - ink_swap_color="random", + ink_swap_color="lhy_custom", ink_swap_sequence_number_range=(5, 10), ink_swap_min_width_range=(2, 3), ink_swap_max_width_range=(100, 120), @@ -17,87 +16,78 @@ def ocr_augmentation_pipeline(): ink_swap_max_height_range=(100, 120), ink_swap_min_area_range=(10, 20), ink_swap_max_area_range=(400, 500), - p=0.1 + p=0.2 ), - # 10ms - Dithering( - dither=random.choice(["ordered", "floyd-steinberg"]), - order=(3, 5), - p=0.05 + LinesDegradation( + line_roi=(0.0, 0.0, 1.0, 1.0), + line_gradient_range=(32, 255), + line_gradient_direction=(0, 2), + line_split_probability=(0.2, 0.4), + line_replacement_value=(250, 255), + line_min_length=(30, 40), + line_long_to_short_ratio=(5, 7), + line_replacement_probability=(0.4, 0.5), + line_replacement_thickness=(1, 3), + p=0.2 ), - # 10ms - InkBleed( - intensity_range=(0.1, 0.2), - kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]), - severity=(0.4, 0.6), - p=0.2, + + # ============================ + OneOf( + [ + Dithering( + dither="floyd-steinberg", + order=(3, 5), + ), + InkBleed( + intensity_range=(0.1, 0.2), + kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]), + severity=(0.4, 0.6), + ), + ], + p=0.2 ), - # 40ms + # ============================ + + # ============================ InkShifter( text_shift_scale_range=(18, 27), text_shift_factor_range=(1, 4), text_fade_range=(0, 2), blur_kernel_size=(5, 5), blur_sigma=0, - noise_type="random", - p=0.1 + noise_type="perlin", + p=0.2 ), - # 90ms - # Letterpress( - # n_samples=(100, 400), - # n_clusters=(200, 400), - # std_range=(500, 3000), - # value_range=(150, 224), - # value_threshold_range=(96, 128), - # blur=1, - # p=0.1 - 
# ), + # ============================ + ] paper_phase = [ - # 50ms - # OneOf( - # [ - # ColorPaper( - # hue_range=(0, 255), - # saturation_range=(10, 40), - # ), - # PatternGenerator( - # imgx=random.randint(256, 512), - # imgy=random.randint(256, 512), - # n_rotation_range=(10, 15), - # color="random", - # alpha_range=(0.25, 0.5), - # ), - # NoiseTexturize( - # sigma_range=(3, 10), - # turbulence_range=(2, 5), - # texture_width_range=(300, 500), - # texture_height_range=(300, 500), - # ), - # ], - # p=0.05 - # ), - # 10ms - BrightnessTexturize( + NoiseTexturize( # tested + sigma_range=(3, 10), + turbulence_range=(2, 5), + texture_width_range=(300, 500), + texture_height_range=(300, 500), + p=0.2 + ), + BrightnessTexturize( # tested texturize_range=(0.9, 0.99), deviation=0.03, - p=0.1 + p=0.2 ) ] post_phase = [ - # 13ms - ColorShift( + ColorShift( # tested color_shift_offset_x_range=(3, 5), color_shift_offset_y_range=(3, 5), color_shift_iterations=(2, 3), color_shift_brightness_range=(0.9, 1.1), color_shift_gaussian_kernel_range=(3, 3), - p=0.05 + p=0.2 ), - # 13ms - DirtyDrum( + + DirtyDrum( # tested line_width_range=(1, 6), line_concentration=random.uniform(0.05, 0.15), direction=random.randint(0, 2), @@ -105,9 +95,10 @@ def ocr_augmentation_pipeline(): noise_value=(64, 224), ksize=random.choice([(3, 3), (5, 5), (7, 7)]), sigmaX=0, - p=0.05, + p=0.2 ), - # 10ms + + # ===================================== OneOf( [ LightingGradient( @@ -128,121 +119,23 @@ def ocr_augmentation_pipeline(): gamma_range=(0.9, 1.1), ), ], - p=0.05 + p=0.2 ), - # 6ms - Jpeg( - quality_range=(25, 95), - p=0.1 - ), - # 12ms - Markup( - num_lines_range=(2, 7), - markup_length_range=(0.5, 1), - markup_thickness_range=(1, 2), - markup_type=random.choice(["strikethrough", "crossed", "highlight", "underline"]), - markup_color="random", - single_word_mode=False, - repetitions=1, - p=0.05 - ), - # 65ms - # OneOf( - # [ - # BadPhotoCopy( - # noise_mask=None, - # noise_type=-1, - # noise_side="random", - # noise_iteration=(1, 2), - # noise_size=(1, 3), - # noise_value=(128, 196), - # noise_sparsity=(0.3, 0.6), - # noise_concentration=(0.1, 0.6), - # blur_noise=random.choice([True, False]), - # blur_noise_kernel=random.choice([(3, 3), (5, 5), (7, 7)]), - # wave_pattern=random.choice([True, False]), - # edge_effect=random.choice([True, False]), - # ), - # ShadowCast( - # shadow_side="random", - # shadow_vertices_range=(1, 20), - # shadow_width_range=(0.3, 0.8), - # shadow_height_range=(0.3, 0.8), - # shadow_color=(0, 0, 0), - # shadow_opacity_range=(0.2, 0.9), - # shadow_iterations_range=(1, 2), - # shadow_blur_kernel_range=(101, 301), - # ), - # LowLightNoise( - # num_photons_range=(50, 100), - # alpha_range=(0.7, 1.0), - # beta_range=(10, 30), - # gamma_range=(1, 1.8), - # bias_range=(20, 40), - # dark_current_value=1.0, - # exposure_time=0.2, - # gain=0.1, - # ), - # ], - # p=0.05, - # ), - # 10ms + # ===================================== + + # ===================================== OneOf( [ - NoisyLines( - noisy_lines_direction="random", - noisy_lines_location="random", - noisy_lines_number_range=(5, 20), - noisy_lines_color=(0, 0, 0), - noisy_lines_thickness_range=(1, 2), - noisy_lines_random_noise_intensity_range=(0.01, 0.1), - noisy_lines_length_interval_range=(0, 100), - noisy_lines_gaussian_kernel_value_range=(3, 5), - noisy_lines_overlay_method="ink_to_paper", + SubtleNoise( + subtle_range=random.randint(5, 10), ), - BindingsAndFasteners( - overlay_types="darken", - foreground=None, - effect_type="random", - 
width_range="random", - height_range="random", - angle_range=(-30, 30), - ntimes=(2, 6), - nscales=(0.9, 1.0), - edge="random", - edge_offset=(10, 50), - use_figshare_library=0, + Jpeg( + quality_range=(85, 95), ), ], - p=0.05, - ), - # 20ms - OneOf( - [ - PageBorder( - page_border_width_height="random", - page_border_color=(0, 0, 0), - page_border_background_color=(0, 0, 0), - page_numbers="random", - page_rotation_angle_range=(-3, 3), - curve_frequency=(2, 8), - curve_height=(2, 4), - curve_length_one_side=(50, 100), - same_page_border=random.choice([0, 1]), - ), - Folding( - fold_x=None, - fold_deviation=(0, 0), - fold_count=random.randint(2, 8), - fold_noise=0.01, - fold_angle_range=(-360, 360), - gradient_width=(0.1, 0.2), - gradient_height=(0.01, 0.02), - backdrop_color=(0, 0, 0), - ), - ], - p=0.05 + p=0.2 ), + # ===================================== ] pipeline = AugraphyPipeline( @@ -250,7 +143,7 @@ def ocr_augmentation_pipeline(): paper_phase=paper_phase, post_phase=post_phase, pre_phase=pre_phase, - log=False, + log=False ) return pipeline \ No newline at end of file diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index 8ce3bd7..0fc862d 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -4,7 +4,7 @@ import numpy as np import cv2 from torchvision.transforms import v2 -from typing import List, Union +from typing import List from PIL import Image from ...globals import ( @@ -77,7 +77,6 @@ def trim_white_border(image: np.ndarray): return trimmed_image -# BUGY CODE def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray: randi = [random.randint(0, max_size) for _ in range(4)] pad_height_size = randi[1] + randi[3] @@ -131,9 +130,38 @@ def random_resize( ] +def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray: + # Get the center of the image to define the point of rotation + image_center = tuple(np.array(image.shape[1::-1]) / 2) + + # Generate a random angle within the specified range + angle = random.randint(min_angle, max_angle) + + # Get the rotation matrix for rotating the image around its center + rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + + # Determine the size of the rotated image + cos = np.abs(rotation_mat[0, 0]) + sin = np.abs(rotation_mat[0, 1]) + new_width = int((image.shape[0] * sin) + (image.shape[1] * cos)) + new_height = int((image.shape[0] * cos) + (image.shape[1] * sin)) + + # Adjust the rotation matrix to take into account translation + rotation_mat[0, 2] += (new_width / 2) - image_center[0] + rotation_mat[1, 2] += (new_height / 2) - image_center[1] + + # Rotate the image with the specified border color (white in this case) + rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)) + + return rotated_image + + def ocr_aug(image: np.ndarray) -> np.ndarray: + # 20%的概率进行随机旋转 + if random.random() < 0.2: + image = rotate(image, -5, 5) # 增加白边 - image = add_white_border(image, max_size=35).permute(1, 2, 0).numpy() + image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy() # 数据增强 image = train_pipeline(image) return image @@ -149,10 +177,7 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: # 裁剪掉白边 images = [trim_white_border(image) for image in images] - # 增加白边 - # images = [add_white_border(image, max_size=35) for image in images] - # 数据增强 - # images = [train_pipeline(image.permute(1, 2, 0).numpy()) for image in images] + # OCR 
augmentation images = [ocr_aug(image) for image in images] # general transform pipeline @@ -165,10 +190,11 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now" assert OCR_FIX_SIZE == True, "Only support fixed size images for now" + images = [np.array(img.convert('RGB')) for img in images] # 裁剪掉白边 images = [trim_white_border(image) for image in images] # general transform pipeline - images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image] + images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image] # padding to fixed size images = padding(images, OCR_IMG_SIZE) @@ -190,8 +216,6 @@ if __name__ == '__main__': ] imgs_path = [str(img_path) for img_path in imgs_path] imgs = convert2rgb(imgs_path) - # res = train_transform(imgs) - # res = [v2.functional.to_pil_image(img) for img in res] res = random_resize(imgs, 0.5, 1.5) pause = 1
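
Note on the train/eval transform split introduced in src/models/ocr_model/train/train.py: instead of attaching a single img_transform_fn to the whole tokenized dataset, each split now gets its own on-the-fly transform, so the training half runs the stochastic augmentation path while the eval half stays deterministic. A minimal sketch of that with_transform pattern, assuming a Hugging Face datasets.Dataset with a pixel_values column (the toy data and the doubling transform below are illustrative placeholders, not the repo's real transforms):

```python
from datasets import Dataset

# Toy stand-in for the tokenized formula dataset (illustrative only).
dataset = Dataset.from_dict({
    "pixel_values": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]],
    "labels": [[1], [2], [3], [4]],
})

def img_train_transform(batch):
    # Stochastic augmentation would go here (rotation, white border, Augraphy, ...).
    batch["pixel_values"] = [[v * 2 for v in img] for img in batch["pixel_values"]]
    return batch

def img_inf_transform(batch):
    # Deterministic preprocessing only, so eval metrics stay comparable between runs.
    return batch

split = dataset.train_test_split(test_size=0.25, seed=42)
train_dataset = split["train"].with_transform(img_train_transform)
eval_dataset = split["test"].with_transform(img_inf_transform)

# with_transform formats rows lazily at access time, so every epoch re-runs the
# training augmentation while the eval split is always preprocessed the same way.
print(train_dataset[0]["pixel_values"])
print(eval_dataset[0]["pixel_values"])
```

Because with_transform runs per access (unlike map, nothing is written to the cache), applying it after train_test_split lets each split keep its own transform without re-processing the tokenized data.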
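The Augraphy changes in src/models/ocr_model/utils/ocr_aug.py only build the pipeline; transforms.py consumes it by calling train_pipeline(image) on a trimmed numpy image inside ocr_aug. A self-contained sketch of that build-then-apply flow, restricted to augmenters and parameters that already appear in the diff (the blank canvas and the always-on probabilities are illustrative, and the exact Augraphy import surface may vary by version):

```python
import numpy as np
from augraphy import AugraphyPipeline, OneOf, Dithering, BrightnessTexturize, SubtleNoise, Jpeg

# One augmenter per phase, mirroring the structure used in ocr_aug.py.
ink_phase = [Dithering(dither="floyd-steinberg", order=(3, 5), p=1.0)]
paper_phase = [BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=1.0)]
post_phase = [
    # OneOf applies exactly one of its children when it fires.
    OneOf(
        [
            SubtleNoise(subtle_range=8),
            Jpeg(quality_range=(85, 95)),
        ],
        p=1.0,
    ),
]

pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase, log=False)

# A white uint8 canvas stands in for a trimmed formula crop.
image = np.full((64, 256, 3), 255, dtype=np.uint8)
augmented = pipeline(image)  # same call pattern as `image = train_pipeline(image)` in transforms.py
print(augmented.shape, augmented.dtype)
```

Wrapping mutually exclusive degradations in OneOf, rather than giving each one an independent probability, ensures at most one of them fires per image, which is presumably why Dithering/InkBleed and SubtleNoise/Jpeg are grouped that way in the updated pipeline.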