tmp commit
@@ -55,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
         model,
         seq2seq_config,

-        eval_dataset=eval_dataset,
+        eval_dataset=eval_dataset.select(range(100)),
         tokenizer=tokenizer,
         data_collator=collate_fn,
         compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
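Note: the only functional change in this hunk is capping the evaluation set at 100 samples, so the generate-and-BLEU pass inside evaluate() stays cheap while iterating. For reference, datasets.Dataset.select returns an index-view over the requested rows without copying the underlying Arrow table; a minimal sketch with toy data (names here are illustrative, not from this repo):

    from datasets import Dataset

    ds = Dataset.from_dict({'latex_formula': [f'x^{i}' for i in range(1000)]})
    eval_subset = ds.select(range(100))  # lightweight view of the first 100 rows
    print(len(eval_subset))              # -> 100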
@@ -73,20 +73,20 @@ if __name__ == '__main__':
     os.chdir(script_dirpath)


-    # dataset = load_dataset(
-    #     '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
-    #     'cleaned_formulas'
-    # )['train']
     dataset = load_dataset(
         '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
         'cleaned_formulas'
-    )['train'].select(range(1000))
+    )['train']
+    # dataset = load_dataset(
+    #     '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
+    #     'cleaned_formulas'
+    # )['train'].select(range(1000))

     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

     map_fn = partial(tokenize_fn, tokenizer=tokenizer)
-    # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False)
-    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False)
+    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
+    # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1)
     tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)

     split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
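Note: this hunk drops the .select(range(1000)) debugging cap so the full train split is tokenized, and swaps the map call from the single-process, cache-busting variant (num_proc=1, load_from_cache_file=False) to the parallel cached one (num_proc=8, load_from_cache_file=True). A hedged sketch of that batched-map pattern with a stand-in tokenize function (the real tokenize_fn lives in this repo):

    from datasets import Dataset

    def toy_tokenize(samples):
        # with batched=True the function receives a dict of columns, not single rows
        return {'input_ids': [[len(s)] for s in samples['latex_formula']]}

    ds = Dataset.from_dict({'latex_formula': ['a + b', 'x^2']})
    tokenized = ds.map(
        toy_tokenize,
        batched=True,
        remove_columns=ds.column_names,  # drop the raw text column after tokenizing
        num_proc=1,                      # raise for parallel workers on the real dataset
        load_from_cache_file=True,       # reuse the Arrow cache across runs
    )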
@@ -105,3 +105,41 @@ if __name__ == '__main__':

     os.chdir(cur_path)

+
+'''
+if __name__ == '__main__':
+    cur_path = os.getcwd()
+    script_dirpath = Path(__file__).resolve().parent
+    os.chdir(script_dirpath)
+
+
+    dataset = load_dataset(
+        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
+        'cleaned_formulas'
+    )['train']
+
+    pause = dataset[0]['image']
+    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
+
+    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
+    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
+    tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
+
+    split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
+    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
+    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
+    # model = TexTeller()
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-81000')
+
+    enable_train = False
+    enable_evaluate = True
+    if enable_train:
+        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
+    if enable_evaluate:
+        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
+
+
+    os.chdir(cur_path)
+
+
+'''
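Note: everything added in this hunk is wrapped in a ''' string, i.e. it is kept as an inert reference copy of the old evaluate-only entrypoint (fixed checkpoint path, enable_train/enable_evaluate switches) rather than executed. One behavior worth remembering from it: with_transform, unlike map, applies its function lazily on every access, which is why it suits per-epoch image augmentation. A small sketch of that laziness with a toy transform (not the repo's actual train_transform):

    from datasets import Dataset

    ds = Dataset.from_dict({'pixel_values': [[1, 2], [3, 4]]})

    def double(batch):
        # with_transform hands over a batch dict even for single-row access
        batch['pixel_values'] = [[2 * v for v in row] for row in batch['pixel_values']]
        return batch

    lazy = ds.with_transform(double)  # nothing is precomputed or cached
    print(lazy[0]['pixel_values'])    # -> [2, 4], computed at access time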
@@ -38,6 +38,7 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[

     # Left-shift labels and decoder_attention_mask
     batch['labels'] = left_move(batch['labels'], -100)
+    # batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)

     # Convert the list of Images into a single tensor of shape (B, C, H, W)
     batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
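Note: the added line arrives already commented out, so the decoder_attention_mask is deliberately no longer left-shifted alongside labels; only labels move, putting token t+1 at supervision position t. Since left_move is defined in this repo (see the archived block in the next hunk), here is a worked example of what it does to a label row, -100 being the index HF losses ignore:

    import torch

    def left_move(x: torch.Tensor, pad_val):
        # shift every row one step left, filling the vacated last column
        lefted_x = torch.ones_like(x)
        lefted_x[:, :-1] = x[:, 1:]
        lefted_x[:, -1] = pad_val
        return lefted_x

    labels = torch.tensor([[101, 7, 8, 102]])
    print(left_move(labels, -100))  # -> tensor([[   7,    8,  102, -100]])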
@@ -76,3 +77,47 @@ if __name__ == '__main__':

     pause = 1

+
+'''
+def left_move(x: torch.Tensor, pad_val):
+    assert len(x.shape) == 2, 'x should be 2-dimensional'
+    lefted_x = torch.ones_like(x)
+    lefted_x[:, :-1] = x[:, 1:]
+    lefted_x[:, -1] = pad_val
+    return lefted_x
+
+
+def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
+    assert tokenizer is not None, 'tokenizer should not be None'
+    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
+    tokenized_formula['pixel_values'] = samples['image']
+    return tokenized_formula
+
+
+def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
+    assert tokenizer is not None, 'tokenizer should not be None'
+    pixel_values = [dic.pop('pixel_values') for dic in samples]
+
+    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    batch = clm_collator(samples)
+    batch['pixel_values'] = pixel_values
+    batch['decoder_input_ids'] = batch.pop('input_ids')
+    batch['decoder_attention_mask'] = batch.pop('attention_mask')
+
+    # Left-shift labels and decoder_attention_mask
+    batch['labels'] = left_move(batch['labels'], -100)
+    batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
+
+    # Convert the list of Images into a single tensor of shape (B, C, H, W)
+    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
+    return batch
+
+
+def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    processed_img = train_transform(samples['pixel_values'])
+    samples['pixel_values'] = processed_img
+    return samples
+
+'''
+
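Note: as in the previous file, this entire hunk sits inside a ''' string: it archives the old helpers (left_move, tokenize_fn, collate_fn, img_preprocess) without executing them. The collate pattern it preserves is the interesting part: DataCollatorForLanguageModeling with mlm=False pads input_ids/attention_mask and clones input_ids into labels with pad positions set to -100, after which the wrapper renames keys to the decoder_* names a vision-encoder-decoder model expects. A minimal sketch of the collator's behavior (the tokenizer checkpoint is just an example):

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tok = AutoTokenizer.from_pretrained('roberta-base')  # any tokenizer with a pad token
    collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

    samples = [tok('a + b'), tok('x ^ 2 + y ^ 2')]
    batch = collator(samples)
    # input_ids are padded to a common length; labels equal input_ids
    # except at pad positions, which become -100
    print(batch['input_ids'].shape)
    print((batch['labels'] == -100).sum())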