Fixed up training and added data augmentation
@@ -21,7 +21,7 @@ def left_move(x: torch.Tensor, pad_val):
 def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
     assert tokenizer is not None, 'tokenizer should not be None'
     tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
-    tokenized_formula['pixel_values'] = [np.array(sample) for sample in samples['image']]
+    tokenized_formula['pixel_values'] = samples['image']
     return tokenized_formula
 
 
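The change above stops eagerly converting images to numpy inside tokenize_fn: pixel_values now carries the raw PIL images through, and the np.array conversion is deferred to the transform stage (see train_transform below). For reference, a minimal usage sketch, assuming a Hugging Face datasets dataset with 'latex_formula' and 'image' columns; the tokenizer name and dataset path are placeholders:

    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')        # placeholder tokenizer
    dataset = load_dataset('imagefolder', data_dir='data/')  # placeholder dataset

    # batched=True delivers samples as Dict[str, List[Any]], matching the signature
    tokenized = dataset.map(tokenize_fn, batched=True, fn_kwargs={'tokenizer': tokenizer})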
@@ -5,6 +5,7 @@ import cv2
 
 from torchvision.transforms import v2
 from typing import List, Union
+from augraphy import *
 from PIL import Image
 
 from ...globals import (
@@ -15,12 +16,13 @@ from ...globals import (
     MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
 )
 
+train_pipeline = default_augraphy_pipeline()
 
 general_transform_pipeline = v2.Compose([
     v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
     # returns a List of torchvision.Image; the length of the list is the batch_size
     # so at the end of the whole Compose pipeline, the output is also a List of torchvision.Image
     # note: it does not return one whole torchvision.Image; the batch_size dimension is kept outside
     v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
     v2.Grayscale(),  # convert to grayscale (depends on the specific task)
 
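For context on the two new lines: augraphy synthesizes document-style degradation (ink bleed, paper texture, scan noise), a good match for printed-formula OCR, and default_augraphy_pipeline() bundles the library's stock ink/paper/post phases into one AugraphyPipeline. A narrower hand-rolled pipeline is a common alternative when the defaults are too destructive for thin math strokes; a minimal sketch, assuming augraphy's phase-based API (the specific augmentations and probabilities are illustrative, not what this commit uses):

    from augraphy import AugraphyPipeline, InkBleed, NoiseTexturize, BrightnessTexturize

    # the ink phase perturbs the "printed" strokes; paper and post phases perturb the page
    train_pipeline = AugraphyPipeline(
        ink_phase=[InkBleed(p=0.5)],
        paper_phase=[NoiseTexturize(p=0.5)],
        post_phase=[BrightnessTexturize(p=0.5)],
    )

    # the pipeline consumes and returns a uint8 numpy image; recent augraphy
    # versions make the pipeline object callable, older ones use
    # train_pipeline.augment(image)["output"]
    # augmented = train_pipeline(image)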
@@ -74,7 +76,15 @@ def trim_white_border(image: np.ndarray):
     return trimmed_image
 
 
-def padding(images: List[torch.Tensor], required_size: int):
+def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
+    randi = [random.randint(0, max_size) for _ in range(4)]
+    return v2.functional.pad(
+        image,
+        padding=randi
+    )
+
+
+def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
     images = [
         v2.functional.pad(
             img,
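Two details of v2.functional.pad are worth noting for the new add_white_border helper: a 4-element padding list is read as [left, top, right, bottom], and the default fill is 0, which on a uint8 image is a black border, so a literally white border needs an explicit fill=255. The v2 functional ops are also defined for tensors and PIL Images rather than raw numpy arrays, so the np.ndarray annotation above may need a conversion first, depending on the torchvision version. A sketch of an explicitly white-filled variant (the function name and fill choice are my additions, not the commit's):

    import random
    import torch
    from torchvision.transforms import v2

    def add_white_border_filled(image: torch.Tensor, max_size: int) -> torch.Tensor:
        # pad each side by an independent random amount, filling with white (255)
        left, top, right, bottom = (random.randint(0, max_size) for _ in range(4))
        return v2.functional.pad(image, padding=[left, top, right, bottom], fill=255)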
@@ -107,9 +117,19 @@ def random_resize(
     ]
 
 
-def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
+    assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
+    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+
+    images = [np.array(img.convert('RGB')) for img in images]
+    # random resize first
+    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
     # trim white borders
     images = [trim_white_border(image) for image in images]
+    # add white borders
+    images = [add_white_border(image, max_size=35) for image in images]
+    # data augmentation
+    images = [train_pipeline(image) for image in images]
     # general transform pipeline
     images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
     # padding to fixed size
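The training path is now: PIL image → RGB numpy → random resize → trim the existing white margin → re-add a random margin of up to 35 px per side (so formula placement and margins vary) → augraphy noise → uint8/grayscale tensor conversion → fixed-size padding. A minimal smoke test, assuming this module's imports and globals are in scope (the canvas size and stroke are made up):

    import numpy as np
    from PIL import Image

    # a white canvas with one black bar, standing in for a rendered formula
    arr = np.full((64, 256, 3), 255, dtype=np.uint8)
    arr[28:36, 40:200] = 0

    out = train_transform([Image.fromarray(arr)])
    print(out[0].shape)  # expected: a (1, OCR_IMG_SIZE, OCR_IMG_SIZE) tensor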
@@ -117,21 +137,17 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     return images
 
 
-def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
-    assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
-    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-
-    # random resize first
-    images = [np.array(img.convert('RGB')) for img in images]
-    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
-    return general_transform(images)
-
-
 def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    return general_transform(images)
+    # trim white borders
+    images = [trim_white_border(image) for image in images]
+    # general transform pipeline
+    images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
+    # padding to fixed size
+    images = padding(images, OCR_IMG_SIZE)
+
+    return images
 
 
 if __name__ == '__main__':
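With general_transform gone, the two entry points split cleanly: train_transform owns every stochastic step, while inference_transform stays deterministic (trim, normalize, pad). A sketch of wiring the train path into a DataLoader collate function, assuming tokenize_fn left raw PIL images in pixel_values and that padding() returns equally sized tensors (collate_fn and its batch layout are hypothetical):

    import torch

    def collate_fn(rows):
        # rows: list of dataset records produced by tokenize_fn
        pixel_values = train_transform([row['pixel_values'] for row in rows])
        return {
            'pixel_values': torch.stack(pixel_values),        # (B, 1, H, W)
            'input_ids': [row['input_ids'] for row in rows],  # still unpadded here
        }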