Completed all the code
@@ -38,7 +38,6 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
 
     # Left-shift labels and decoder_attention_mask
     batch['labels'] = left_move(batch['labels'], -100)
-    # batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
 
     # Convert the list of Images into one tensor with shape (B, C, H, W)
     batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
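
For reference, left_move (its earlier definition appears in the removed block below) shifts each row of a 2-D tensor one position to the left and fills the vacated last column with a pad value; applying it to labels aligns each target token with the decoder position that predicts it. A quick self-contained illustration:

import torch

def left_move(x: torch.Tensor, pad_val):
    # Shift every row left by one; the vacated last column gets pad_val
    assert len(x.shape) == 2, 'x should be 2-dimensional'
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x

labels = torch.tensor([[101, 7, 8, 9, 102]])  # e.g. [BOS, t1, t2, t3, EOS]
print(left_move(labels, -100))                # tensor([[7, 8, 9, 102, -100]])

Padding with -100 matters because -100 is the default ignore_index of PyTorch's cross-entropy loss, so the shifted-out position contributes nothing to training.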
@@ -76,48 +75,3 @@ if __name__ == '__main__':
     out = model(**batch)
 
     pause = 1
-
-
-'''
-def left_move(x: torch.Tensor, pad_val):
-    assert len(x.shape) == 2, 'x should be 2-dimensional'
-    lefted_x = torch.ones_like(x)
-    lefted_x[:, :-1] = x[:, 1:]
-    lefted_x[:, -1] = pad_val
-    return lefted_x
-
-
-def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
-    assert tokenizer is not None, 'tokenizer should not be None'
-    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
-    tokenized_formula['pixel_values'] = samples['image']
-    return tokenized_formula
-
-
-def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
-    assert tokenizer is not None, 'tokenizer should not be None'
-    pixel_values = [dic.pop('pixel_values') for dic in samples]
-
-    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
-    batch = clm_collator(samples)
-    batch['pixel_values'] = pixel_values
-    batch['decoder_input_ids'] = batch.pop('input_ids')
-    batch['decoder_attention_mask'] = batch.pop('attention_mask')
-
-    # Left-shift labels and decoder_attention_mask
-    batch['labels'] = left_move(batch['labels'], -100)
-    batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
-
-    # Convert the list of Images into one tensor with shape (B, C, H, W)
-    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
-    return batch
-
-
-def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-    processed_img = train_transform(samples['pixel_values'])
-    samples['pixel_values'] = processed_img
-    return samples
-'''
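
The removed block documents how the old collate_fn built batches: DataCollatorForLanguageModeling(mlm=False) pads input_ids/attention_mask and emits causal-LM labels (pad positions set to -100), which the code then renames to decoder_input_ids/decoder_attention_mask for the encoder-decoder model. A minimal sketch of that collation step, using the GPT-2 tokenizer purely as a placeholder (this repo's actual tokenizer is not shown in the diff):

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Placeholder tokenizer; any fast tokenizer with a pad token behaves the same way.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

samples = [tokenizer(t, return_special_tokens_mask=True)
           for t in (r'\frac{a}{b}', r'x^2 + y^2 = z^2')]

clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
# With mlm=False, labels are a copy of input_ids in which pad positions are
# replaced by -100, so padding is ignored by the cross-entropy loss.
print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape)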
@@ -4,7 +4,6 @@ import numpy as np
 import cv2
 
 from torchvision.transforms import v2
 from PIL import ImageChops, Image
 from typing import List, Union
 
 from ....globals import (
@@ -107,7 +106,7 @@ def random_resize(
     ]
 
 
-def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
+def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     # Trim the white borders
     images = [trim_white_border(image) for image in images]
     # general transform pipeline
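
trim_white_border is defined elsewhere in this file and is not part of the diff. For orientation only, the standard ImageChops idiom such a function typically follows (a hypothetical sketch, not the repo's implementation):

from PIL import Image, ImageChops

def trim_white_border_sketch(image: Image.Image) -> Image.Image:
    # Difference against an all-white canvas of the same size; getbbox()
    # then yields the bounding box of every non-white pixel, which we crop to.
    background = Image.new(image.mode, image.size, 'white')
    diff = ImageChops.difference(image, background)
    bbox = diff.getbbox()
    return image.crop(bbox) if bbox else image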
@@ -117,16 +116,16 @@ def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
     return images
 
 
-def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
 
     # random resize first
-    # images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
+    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
     return general_transform(images)
 
 
-def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]:
+def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
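
random_resize comes from earlier in this same file (see the @@ -107 hunk above); un-commenting the call restores scale augmentation during training. As a hypothetical sketch of that kind of augmentation, assuming it rescales each image by a ratio drawn uniformly between the two bounds:

import random
from typing import List

import cv2
import numpy as np

def random_resize_sketch(images: List[np.ndarray],
                         min_ratio: float, max_ratio: float) -> List[np.ndarray]:
    # Hypothetical stand-in for this repo's random_resize: rescale each image
    # by an independent ratio drawn uniformly from [min_ratio, max_ratio].
    resized = []
    for img in images:
        ratio = random.uniform(min_ratio, max_ratio)
        h, w = img.shape[:2]
        new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))  # cv2 expects (W, H)
        resized.append(cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR))
    return resized

Randomizing scale only in train_transform while leaving inference_transform deterministic keeps evaluation reproducible while making the recognizer robust to formula size.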