tmp commit

This commit is contained in:
三洋三洋
2024-01-31 10:20:27 +00:00
parent 1fba652766
commit ebac28a90d
2 changed files with 91 additions and 8 deletions

View File

@@ -38,6 +38,7 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
# 左移labels和decoder_attention_mask
batch['labels'] = left_move(batch['labels'], -100)
# batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
# 把list of Image转成一个tensor with (B, C, H, W)
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
@@ -76,3 +77,47 @@ if __name__ == '__main__':
pause = 1
'''
def left_move(x: torch.Tensor, pad_val):
assert len(x.shape) == 2, 'x should be 2-dimensional'
lefted_x = torch.ones_like(x)
lefted_x[:, :-1] = x[:, 1:]
lefted_x[:, -1] = pad_val
return lefted_x
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
tokenized_formula['pixel_values'] = samples['image']
return tokenized_formula
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
pixel_values = [dic.pop('pixel_values') for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
batch['pixel_values'] = pixel_values
batch['decoder_input_ids'] = batch.pop('input_ids')
batch['decoder_attention_mask'] = batch.pop('attention_mask')
# 左移labels和decoder_attention_mask
batch['labels'] = left_move(batch['labels'], -100)
batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
# 把list of Image转成一个tensor with (B, C, H, W)
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
return batch
def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
'''