初步修改完成,但仍然有问题
This commit is contained in:
@@ -41,6 +41,7 @@ class TexTeller(VisionEncoderDecoderModel):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
pause = 1
|
||||||
# texteller = TexTeller()
|
# texteller = TexTeller()
|
||||||
# from ..utils.inference import inference
|
# from ..utils.inference import inference
|
||||||
# model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
|
# model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
|||||||
|
|
||||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||||
training_args = TrainingArguments(**CONFIG)
|
training_args = TrainingArguments(**CONFIG)
|
||||||
debug_mode = False
|
debug_mode = True
|
||||||
if debug_mode:
|
if debug_mode:
|
||||||
training_args.auto_find_batch_size = False
|
training_args.auto_find_batch_size = False
|
||||||
training_args.num_train_epochs = 2
|
training_args.num_train_epochs = 2
|
||||||
@@ -96,6 +96,10 @@ if __name__ == '__main__':
|
|||||||
# model = TexTeller()
|
# model = TexTeller()
|
||||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
|
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
|
||||||
|
|
||||||
|
# ================= debug =======================
|
||||||
|
foo = train_dataset[:3]
|
||||||
|
# ================= debug =======================
|
||||||
|
|
||||||
enable_train = True
|
enable_train = True
|
||||||
enable_evaluate = True
|
enable_evaluate = True
|
||||||
if enable_train:
|
if enable_train:
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import cv2
|
|||||||
|
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
from augraphy import *
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ...globals import (
|
from ...globals import (
|
||||||
@@ -15,8 +14,10 @@ from ...globals import (
|
|||||||
IMAGE_MEAN, IMAGE_STD,
|
IMAGE_MEAN, IMAGE_STD,
|
||||||
MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
|
MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
|
||||||
)
|
)
|
||||||
|
from .ocr_aug import ocr_augmentation_pipeline
|
||||||
|
|
||||||
train_pipeline = default_augraphy_pipeline(scan_only=True)
|
# train_pipeline = default_augraphy_pipeline(scan_only=True)
|
||||||
|
train_pipeline = ocr_augmentation_pipeline()
|
||||||
|
|
||||||
general_transform_pipeline = v2.Compose([
|
general_transform_pipeline = v2.Compose([
|
||||||
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
|
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
|
||||||
@@ -76,11 +77,24 @@ def trim_white_border(image: np.ndarray):
|
|||||||
return trimmed_image
|
return trimmed_image
|
||||||
|
|
||||||
|
|
||||||
|
# BUGGY CODE
|
||||||
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
|
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
|
||||||
randi = [random.randint(0, max_size) for _ in range(4)]
|
randi = [random.randint(0, max_size) for _ in range(4)]
|
||||||
|
pad_height_size = randi[1] + randi[3]
|
||||||
|
pad_width_size = randi[0] + randi[2]
|
||||||
|
if (pad_height_size + image.shape[0] < 30):
|
||||||
|
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
|
||||||
|
randi[1] += compensate_height
|
||||||
|
randi[3] += compensate_height
|
||||||
|
if (pad_width_size + image.shape[1] < 30):
|
||||||
|
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
|
||||||
|
randi[0] += compensate_width
|
||||||
|
randi[2] += compensate_width
|
||||||
return v2.functional.pad(
|
return v2.functional.pad(
|
||||||
image,
|
torch.from_numpy(image).permute(2, 0, 1),
|
||||||
padding=randi
|
padding=randi,
|
||||||
|
padding_mode='constant',
|
||||||
|
fill=(255, 255, 255)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -127,11 +141,12 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
|||||||
# 裁剪掉白边
|
# 裁剪掉白边
|
||||||
images = [trim_white_border(image) for image in images]
|
images = [trim_white_border(image) for image in images]
|
||||||
# 增加白边
|
# 增加白边
|
||||||
images = [add_white_border(image, max_size=35) for image in images]
|
# images = [add_white_border(image, max_size=35) for image in images]
|
||||||
# 数据增强
|
# 数据增强
|
||||||
images = [train_pipeline(image) for image in images]
|
# images = [train_pipeline(image.permute(1, 2, 0).numpy()) for image in images]
|
||||||
# general transform pipeline
|
# general transform pipeline
|
||||||
images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image]
|
images = general_transform_pipeline(images)
|
||||||
|
# images = [general_transform_pipeline(image) for image in images]
|
||||||
# padding to fixed size
|
# padding to fixed size
|
||||||
images = padding(images, OCR_IMG_SIZE)
|
images = padding(images, OCR_IMG_SIZE)
|
||||||
return images
|
return images
|
||||||
|
|||||||
Reference in New Issue
Block a user