tmp commit
This commit is contained in:
@@ -55,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
model,
|
||||
seq2seq_config,
|
||||
|
||||
eval_dataset=eval_dataset,
|
||||
eval_dataset=eval_dataset.select(range(100)),
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collate_fn,
|
||||
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
|
||||
@@ -73,20 +73,20 @@ if __name__ == '__main__':
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
|
||||
# dataset = load_dataset(
|
||||
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
# 'cleaned_formulas'
|
||||
# )['train']
|
||||
dataset = load_dataset(
|
||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
'cleaned_formulas'
|
||||
)['train'].select(range(1000))
|
||||
)['train']
|
||||
# dataset = load_dataset(
|
||||
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
# 'cleaned_formulas'
|
||||
# )['train'].select(range(1000))
|
||||
|
||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||
|
||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||
# tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
|
||||
# tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1)
|
||||
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
|
||||
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
||||
@@ -105,3 +105,41 @@ if __name__ == '__main__':
|
||||
|
||||
os.chdir(cur_path)
|
||||
|
||||
|
||||
'''
|
||||
if __name__ == '__main__':
|
||||
cur_path = os.getcwd()
|
||||
script_dirpath = Path(__file__).resolve().parent
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
|
||||
dataset = load_dataset(
|
||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
'cleaned_formulas'
|
||||
)['train']
|
||||
|
||||
pause = dataset[0]['image']
|
||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||
|
||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
|
||||
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
|
||||
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-81000')
|
||||
|
||||
enable_train = False
|
||||
enable_evaluate = True
|
||||
if enable_train:
|
||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||
if enable_evaluate:
|
||||
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
||||
|
||||
|
||||
os.chdir(cur_path)
|
||||
|
||||
|
||||
'''
|
||||
Reference in New Issue
Block a user