merge dev后调整了项目结构

This commit is contained in:
三洋三洋
2024-04-21 00:48:24 +08:00
parent eac7f455d6
commit 9b7e392c66
66 changed files with 190 additions and 124855 deletions

View File

@@ -1,17 +1,23 @@
import evaluate
import numpy as np
from transformers import EvalPrediction, RobertaTokenizer
from typing import Dict
import os
def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu') # Needs network access here, so this call can hang
from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    """Compute the ``google_bleu`` metric for a set of evaluation predictions.

    Args:
        eval_preds: Evaluation output. ``predictions`` is decoded as-is (no
            argmax is applied), so it is presumably already token ids --
            confirm against the caller. ``label_ids`` holds the references.
        tokenizer: Tokenizer used to decode both predictions and labels
            back to text.

    Returns:
        Whatever ``metric.compute`` returns for the decoded predictions and
        references.
    """
    # The metric is looked up by the relative name 'google_bleu' with this
    # file's directory as the working directory (presumably so a metric script
    # stored next to this file is found -- confirm). os.chdir is process-wide,
    # so the previous directory is restored in ``finally`` even when loading
    # fails.
    cur_dir = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    try:
        metric = evaluate.load('google_bleu')  # Will download the metric from huggingface if not already downloaded
    finally:
        os.chdir(cur_dir)

    # Predictions are used directly; they are not reduced with argmax here.
    preds, labels = eval_preds.predictions, eval_preds.label_ids

    # Replace -100 entries with token id 1 before decoding. Presumably -100
    # marks ignored label positions and 1 is the tokenizer's padding id --
    # TODO confirm.
    labels = np.where(labels == -100, 1, labels)

    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=preds, references=labels)