Fixed a bug: when a sample contains a very long formula (its token count can exceed 2048), embedding the labels raised an index-out-of-range error.
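The fix drops samples whose tokenized formula would overflow the label/position embedding table before they reach training. Below is a minimal, self-contained sketch of how such a length filter is typically wired into a `datasets` pipeline; the 2048 limit, the `length_filter` name, and the tokenizer loading are illustrative assumptions, not the repository's exact code:

    from functools import partial
    from datasets import Dataset
    from transformers import AutoTokenizer

    MAX_TOKEN_SIZE = 2048  # assumed: the decoder's maximum label/position embedding size

    def length_filter(sample, tokenizer=None) -> bool:
        # Keep only samples whose formula tokenizes well below the limit;
        # the -10 margin leaves room for special tokens added by the tokenizer.
        return len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10

    # Hypothetical usage (checkpoint name and data are placeholders):
    # tokenizer = AutoTokenizer.from_pretrained('...')
    # dataset = Dataset.from_dict({'latex_formula': [...]})
    # dataset = dataset.filter(partial(length_filter, tokenizer=tokenizer))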
@@ -1,13 +1,13 @@
import torch
import numpy as np

from functools import partial
from datasets import load_dataset

from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from ..model.TexTeller import TexTeller
from .transforms import train_transform
from ..model.TexTeller import TexTeller
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE


def left_move(x: torch.Tensor, pad_val):
@@ -50,6 +50,13 @@ def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    return samples


def filter_fn(sample, tokenizer=None) -> bool:
    return (
        sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
        and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
    )


if __name__ == '__main__':
    dataset = load_dataset(
        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',