Files
TexTeller/texteller/models/ocr_model/utils/metrics.py

26 lines
803 B
Python
Raw Normal View History

2024-01-30 08:36:23 +00:00
import evaluate
import numpy as np
2024-04-21 00:48:24 +08:00
import os
from pathlib import Path
2024-01-30 08:36:23 +00:00
from typing import Dict
2024-04-21 00:48:24 +08:00
from transformers import EvalPrediction, RobertaTokenizer
2024-01-30 08:36:23 +00:00
2024-04-21 00:48:24 +08:00
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    """Compute the Google BLEU score of generated token ids against reference labels.

    Args:
        eval_preds: Evaluation output. ``predictions`` holds the predicted token
            ids and ``label_ids`` holds the reference token ids; ``-100`` entries
            are treated as ignore/padding markers and are not decoded.
        tokenizer: Tokenizer used to decode token ids back into text.

    Returns:
        The dict produced by the ``google_bleu`` metric's ``compute`` call.
    """
    # Switch into this module's directory so ``evaluate.load`` can resolve a
    # ``google_bleu`` metric located next to this file; otherwise it downloads the
    # metric from the Hugging Face Hub. The ``finally`` restores the caller's
    # working directory even if loading fails (e.g. no network access).
    prev_cwd = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    try:
        metric = evaluate.load('google_bleu')
    finally:
        os.chdir(prev_cwd)

    preds, labels = eval_preds.predictions, eval_preds.label_ids
    # If predictions arrive as a tuple of model outputs, assume the first
    # element holds the token ids (common Hugging Face convention — confirm
    # against the model in use).
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is an ignore index, not a real token id, and cannot be decoded.
    # Replace it with the pad token so ``skip_special_tokens`` strips it.
    # Fallback 1 is RoBERTa's pad id, the value this function used previously.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
    preds = np.where(preds == -100, pad_id, preds)
    labels = np.where(labels == -100, pad_id, labels)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=decoded_preds, references=decoded_labels)