This commit is contained in:
三洋三洋
2024-03-26 08:16:28 +00:00
12 changed files with 44709 additions and 45 deletions

View File

@@ -1,13 +1,13 @@
import torch
import numpy as np
from functools import partial
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from ..model.TexTeller import TexTeller
from .transforms import train_transform
from ..model.TexTeller import TexTeller
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
def left_move(x: torch.Tensor, pad_val):
@@ -50,6 +50,13 @@ def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
return samples
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
)
if __name__ == '__main__':
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',

View File

@@ -4,7 +4,7 @@ from transformers import EvalPrediction, RobertaTokenizer
from typing import Dict
def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py') # 这里需要联网,所以会卡住
metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu') # 这里需要联网,所以会卡住
logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits