完成了所有代码

This commit is contained in:
三洋三洋
2024-01-31 15:27:35 +00:00
parent ebac28a90d
commit ab1a05bf32
10 changed files with 19 additions and 116 deletions

View File

@@ -32,6 +32,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
def evaluate(model, tokenizer, eval_dataset, collate_fn):
eval_config = CONFIG.copy()
eval_config['predict_with_generate'] = True
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=1,
@@ -40,106 +41,54 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
# eval_config['use_cpu'] = True
eval_config['output_dir'] = 'debug_dir'
eval_config['predict_with_generate'] = True
eval_config['predict_with_generate'] = True
eval_config['dataloader_num_workers'] = 1
eval_config['jit_mode_eval'] = False
eval_config['torch_compile'] = False
eval_config['auto_find_batch_size'] = False
eval_config['generation_config'] = generate_config
eval_config['auto_find_batch_size'] = False
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
trainer = Seq2SeqTrainer(
model,
seq2seq_config,
eval_dataset=eval_dataset.select(range(100)),
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn,
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
)
res = trainer.evaluate()
pause = 1
...
print(res)
if __name__ == '__main__':
cur_path = os.getcwd()
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train']
# dataset = load_dataset(
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
# 'cleaned_formulas'
# )['train'].select(range(1000))
dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices()
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
# tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1)
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-80500')
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
enable_train = False
enable_evaluate = True
enable_train = True
enable_evaluate = False
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
os.chdir(cur_path)
'''
if __name__ == '__main__':
cur_path = os.getcwd()
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train']
pause = dataset[0]['image']
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-81000')
enable_train = False
enable_evaluate = True
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
os.chdir(cur_path)
'''
os.chdir(cur_path)