diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py index adad913..24a9323 100644 --- a/src/models/ocr_model/model/TexTeller.py +++ b/src/models/ocr_model/model/TexTeller.py @@ -41,6 +41,7 @@ class TexTeller(VisionEncoderDecoderModel): if __name__ == "__main__": + pause = 1 # texteller = TexTeller() # from ..utils.inference import inference # model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt') diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 0a1a039..3693705 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -15,7 +15,7 @@ from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): training_args = TrainingArguments(**CONFIG) - debug_mode = False + debug_mode = True if debug_mode: training_args.auto_find_batch_size = False training_args.num_train_epochs = 2 @@ -96,6 +96,10 @@ if __name__ == '__main__': # model = TexTeller() model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt') + # ================= debug ======================= + foo = train_dataset[:3] + # ================= debug ======================= + enable_train = True enable_evaluate = True if enable_train: diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index 5bf21a0..9bfaa89 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -5,7 +5,6 @@ import cv2 from torchvision.transforms import v2 from typing import List, Union -from augraphy import * from PIL import Image from ...globals import ( @@ -15,8 +14,10 @@ from ...globals import ( IMAGE_MEAN, IMAGE_STD, MAX_RESIZE_RATIO, MIN_RESIZE_RATIO ) +from .ocr_aug import ocr_augmentation_pipeline -train_pipeline = default_augraphy_pipeline(scan_only=True) +# train_pipeline = 
default_augraphy_pipeline(scan_only=True) +train_pipeline = ocr_augmentation_pipeline() general_transform_pipeline = v2.Compose([ v2.ToImage(), # Convert to tensor, only needed if you had a PIL image @@ -76,11 +77,24 @@ def trim_white_border(image: np.ndarray): return trimmed_image +# BUGGY CODE def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray: randi = [random.randint(0, max_size) for _ in range(4)] + pad_height_size = randi[1] + randi[3] + pad_width_size = randi[0] + randi[2] + if (pad_height_size + image.shape[0] < 30): + compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1 + randi[1] += compensate_height + randi[3] += compensate_height + if (pad_width_size + image.shape[1] < 30): + compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1 + randi[0] += compensate_width + randi[2] += compensate_width return v2.functional.pad( - image, - padding=randi + torch.from_numpy(image).permute(2, 0, 1), + padding=randi, + padding_mode='constant', + fill=(255, 255, 255) ) @@ -127,11 +141,12 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: # 裁剪掉白边 images = [trim_white_border(image) for image in images] # 增加白边 - images = [add_white_border(image, max_size=35) for image in images] + # images = [add_white_border(image, max_size=35) for image in images] # 数据增强 - images = [train_pipeline(image) for image in images] + # images = [train_pipeline(image.permute(1, 2, 0).numpy()) for image in images] # general transform pipeline - images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image] + images = general_transform_pipeline(images) + # images = [general_transform_pipeline(image) for image in images] # padding to fixed size images = padding(images, OCR_IMG_SIZE) return images