diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index a597fb5..05e7ec2 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -55,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn): model, seq2seq_config, - eval_dataset=eval_dataset, + eval_dataset=eval_dataset.select(range(100)), tokenizer=tokenizer, data_collator=collate_fn, compute_metrics=partial(bleu_metric, tokenizer=tokenizer) @@ -73,20 +73,20 @@ if __name__ == '__main__': os.chdir(script_dirpath) - # dataset = load_dataset( - # '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', - # 'cleaned_formulas' - # )['train'] dataset = load_dataset( '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', 'cleaned_formulas' - )['train'].select(range(1000)) + )['train'] + # dataset = load_dataset( + # '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', + # 'cleaned_formulas' + # )['train'].select(range(1000)) tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') map_fn = partial(tokenize_fn, tokenizer=tokenizer) - # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False) - tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False) + tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True) + # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1) tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn) split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42) @@ -105,3 +105,41 @@ if __name__ == '__main__': os.chdir(cur_path) + +''' +if __name__ == '__main__': + cur_path = os.getcwd() + script_dirpath = Path(__file__).resolve().parent + os.chdir(script_dirpath) + + + dataset = load_dataset( + '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', + 'cleaned_formulas' + )['train'] + + pause = dataset[0]['image'] + tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') + + map_fn = partial(tokenize_fn, tokenizer=tokenizer) + tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8) + tokenized_dataset = tokenized_dataset.with_transform(img_preprocess) + + split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42) + train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] + collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) + # model = TexTeller() + model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-81000') + + enable_train = False + enable_evaluate = True + if enable_train: + train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) + if enable_evaluate: + evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer) + + + os.chdir(cur_path) + + +''' \ No newline at end of file diff --git a/src/models/ocr_model/utils/functional.py b/src/models/ocr_model/utils/functional.py index b7cecc1..748f42c 100644 --- a/src/models/ocr_model/utils/functional.py +++ b/src/models/ocr_model/utils/functional.py @@ -38,6 +38,7 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[ # 左移labels和decoder_attention_mask batch['labels'] = left_move(batch['labels'], -100) + # batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0) # 把list of Image转成一个tensor with (B, C, H, W) batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0) @@ -76,3 +77,47 @@ if __name__ == '__main__': pause = 1 + +''' +def left_move(x: torch.Tensor, pad_val): + assert len(x.shape) == 2, 'x should be 2-dimensional' + lefted_x = torch.ones_like(x) + lefted_x[:, :-1] = x[:, 1:] + lefted_x[:, -1] = pad_val + return lefted_x + + +def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]: + assert tokenizer is not None, 'tokenizer should not be None' + tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True) + tokenized_formula['pixel_values'] = samples['image'] + return tokenized_formula + + +def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]: + assert tokenizer is not None, 'tokenizer should not be None' + pixel_values = [dic.pop('pixel_values') for dic in samples] + + clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + batch = clm_collator(samples) + batch['pixel_values'] = pixel_values + batch['decoder_input_ids'] = batch.pop('input_ids') + batch['decoder_attention_mask'] = batch.pop('attention_mask') + + # 左移labels和decoder_attention_mask + batch['labels'] = left_move(batch['labels'], -100) + batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0) + + # 把list of Image转成一个tensor with (B, C, H, W) + batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0) + return batch + + +def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + processed_img = train_transform(samples['pixel_values']) + samples['pixel_values'] = processed_img + return samples + +''' +