Merge branch 'add_ocr_model'
@@ -12,8 +12,25 @@ MIN_WIDTH = 32
 MAX_WIDTH = 1280
 # In LaTeX-OCR these are 32, 192, 32 and 672 respectively
-# Density value used to render the images in the OCR model's dataset (the images are actually rendered at a density of 100, not 80)
-TEXIFY_INPUT_DENSITY = 80
+# Density (dpi) used when converting PDFs to images for the OCR model's dataset
+TEXIFY_INPUT_DENSITY = 100
+
+# Vocabulary size of the OCR model's tokenizer
+VOCAB_SIZE = 10000
+
+# Whether the OCR model fixes the size of its input images
+OCR_FIX_SIZE = True
+# Fixed input image size used during OCR model training (when OCR_FIX_SIZE is True)
+OCR_IMG_SIZE = 448
+# Maximum input image width and height during OCR model training (when OCR_FIX_SIZE is False)
+OCR_IMG_MAX_HEIGHT = 512
+OCR_IMG_MAX_WIDTH = 768
+
+# Number of channels of the OCR model's input images
+OCR_IMG_CHANNELS = 1  # grayscale
+
+# Maximum token length in the OCR model's training dataset
+MAX_TOKEN_SIZE = 600
+
 # ============================================================================= #
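TEXIFY_INPUT_DENSITY records the dpi at which the dataset's PDFs were rendered to images. A minimal sketch of how such a constant could be consumed, assuming pdf2image as the rendering backend (the actual conversion code is not part of this diff):

from pdf2image import convert_from_path  # assumption: pdf2image is not shown in this commit

TEXIFY_INPUT_DENSITY = 100  # mirrors the value set above

def render_pdf(pdf_path: str):
    # Render each page at the dataset's dpi so new images match the training distribution.
    return convert_from_path(pdf_path, dpi=TEXIFY_INPUT_DENSITY)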
6 src/models/ocr_model/README.md Normal file
@@ -0,0 +1,6 @@
* Encoder-Decoder architecture
* The encoder uses Deit_{BASE}
* The decoder uses RoBERTa_{LARGE}
* The decoder's tokenizer is also taken from RoBERTa_{LARGE}
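A minimal sketch of how the pairing described above can be assembled with Hugging Face transformers, assuming pretrained Hub checkpoints for DeiT and RoBERTa (the TexTeller class added later in this commit instead builds an untrained ViT/TrOCR pair from configs):

from transformers import VisionEncoderDecoderModel, RobertaTokenizerFast

# Hub checkpoint names are illustrative; the commit itself trains from scratch.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/deit-base-distilled-patch16-224",  # vision encoder
    "roberta-large",                             # text decoder; cross-attention is added automatically
)
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-large")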
32 src/models/ocr_model/inference.py Normal file
@@ -0,0 +1,32 @@
import torch
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List

from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from ...globals import MAX_TOKEN_SIZE


def png2jpg(imgs: List[Image.Image]):
    imgs = [img.convert('RGB') for img in imgs if img.mode in ("RGBA", "P")]
    return imgs


def inference(model: TexTeller, imgs: List[Image.Image], tokenizer: RobertaTokenizerFast) -> List[str]:
    imgs = png2jpg(imgs) if imgs[0].mode in ('RGBA', 'P') else imgs
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,
        do_sample=False
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res


if __name__ == '__main__':
    inference()
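A minimal usage sketch for this entry point, assuming a trained checkpoint directory and the retrained tokenizer directory added elsewhere in this commit (all paths and import roots are illustrative):

from PIL import Image

from src.models.ocr_model.inference import inference          # import path is illustrative
from src.models.ocr_model.model.TexTeller import TexTeller

model = TexTeller.from_pretrained("path/to/train_result/checkpoint")
tokenizer = TexTeller.get_tokenizer("path/to/roberta-tokenizer-550Kformulas")

imgs = [Image.open("formula.png")]
latex = inference(model, imgs, tokenizer)   # list of decoded LaTeX strings
print(latex[0])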
54 src/models/ocr_model/model/TexTeller.py Normal file
@@ -0,0 +1,54 @@
from PIL import Image

from ....globals import (
    VOCAB_SIZE,
    OCR_IMG_SIZE,
    OCR_IMG_CHANNELS,
)

from transformers import (
    ViTConfig,
    ViTModel,
    TrOCRConfig,
    TrOCRForCausalLM,
    RobertaTokenizerFast,
    VisionEncoderDecoderModel,
)


class TexTeller(VisionEncoderDecoderModel):
    def __init__(self, decoder_path=None, tokenizer_path=None):
        encoder = ViTModel(ViTConfig(
            image_size=OCR_IMG_SIZE,
            num_channels=OCR_IMG_CHANNELS
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str):
        return VisionEncoderDecoderModel.from_pretrained(model_path)

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
        return RobertaTokenizerFast.from_pretrained(tokenizer_path)


if __name__ == "__main__":
    # texteller = TexTeller()
    from ..inference import inference
    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

    img1 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/1.png')
    img2 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/2.png')
    img3 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/3.png')
    img4 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/4.png')
    img5 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/5.png')
    img6 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/6.png')

    res = inference(model, [img1, img2, img3, img4, img5, img6], tokenizer)
    pause = 1
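Because from_pretrained here delegates to VisionEncoderDecoderModel.from_pretrained, a checkpoint written with save_pretrained loads back as a plain VisionEncoderDecoderModel rather than a TexTeller instance. A minimal round-trip sketch, assuming a writable local directory and an illustrative import root:

from src.models.ocr_model.model.TexTeller import TexTeller   # import path is illustrative

model = TexTeller()                        # randomly initialised ViT encoder + TrOCR decoder
model.save_pretrained("texteller-init")    # directory name is illustrative

reloaded = TexTeller.from_pretrained("texteller-init")
print(type(reloaded).__name__)             # VisionEncoderDecoderModel, not TexTeller
print(reloaded.config.decoder.vocab_size)  # 10000, taken from VOCAB_SIZE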
158 src/models/ocr_model/train/train.py Normal file
@@ -0,0 +1,158 @@
import os

from functools import partial
from pathlib import Path

from datasets import load_dataset
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer
from ..model.TexTeller import TexTeller
from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess


training_args = TrainingArguments(
    seed=42,                          # random seed, to make experiments reproducible
    use_cpu=False,                    # whether to train on the CPU (running on the CPU when first testing the code makes debugging easier)
    # data_seed=42,                   # also fix the data sampler's sampling
    # full_determinism=True,          # make the whole training run fully deterministic (hurts training, debug only)

    output_dir="train_result",        # output directory
    overwrite_output_dir=False,       # if the output directory already exists, do not delete its contents
    report_to=["tensorboard"],        # log to TensorBoard,
    #+ view the logs from the command line with: tensorboard --logdir ./logs

    logging_dir=None,                 # directory for the TensorBoard log files (use the default)
    log_level="info",                 # other options: 'debug', 'info', 'warning', 'error' and 'critical' (from lowest to highest)
    logging_strategy="steps",         # log once every fixed number of steps
    logging_steps=500,                # logging interval in steps; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)
    #+ usually kept the same as eval_steps
    logging_nan_inf_filter=False,     # log steps where the loss is nan or inf

    num_train_epochs=10,              # total number of training epochs
    # max_steps=3,                    # maximum number of training steps; if set,
    #+ num_train_epochs is ignored (usually for debugging)

    # label_names = ['your_label_name'],  # name of the label field in the data loader; defaults to 'labels' if not set

    per_device_train_batch_size=128,  # batch size per GPU
    per_device_eval_batch_size=16,    # evaluation batch size per GPU
    auto_find_batch_size=True,        # automatically search for a workable batch size (exponential decay)

    optim = 'adamw_torch',            # many AdamW variants are also available (more efficient than classic AdamW)
    #+ once optim is set there is no need to pass an optimizer to the Trainer
    lr_scheduler_type="cosine",       # learning-rate scheduler
    warmup_ratio=0.1,                 # warmup as a fraction of the total training steps (with 1000 steps, the first 100 ramp the lr from 0 up to the configured value)
    # warmup_steps=500,               # number of warmup steps; conflicts with warmup_ratio
    weight_decay=0,                   # weight decay
    learning_rate=5e-5,               # learning rate
    max_grad_norm=1.0,                # gradient clipping; keeps the gradient norm at or below 1.0 (default 1.0)
    fp16=False,                       # whether to train in fp16 (generally not recommended, the loss blows up easily)
    bf16=False,                       # whether to train in bf16 (recommended if the hardware supports it)
    gradient_accumulation_steps=2,    # gradient accumulation steps; emulates a large batch size when the per-device batch cannot be made large
    gradient_checkpointing=False,     # if True, drop some intermediates needed for backward during the forward pass and recompute them, easing GPU memory pressure at the cost of a slower forward
    label_smoothing_factor=0.0,       # soft labels; 0 means disabled
    # debug='underflow_overflow',     # check for under/overflow during training and warn when it happens (usually for debugging)
    jit_mode_eval=True,               # use PyTorch JIT tracing during eval (can speed the model up, but the model must be static or this errors out)
    torch_compile=True,               # compile the model with torch.compile (better training and inference performance)
    #+ requires torch > 2.0; works very well, enable it once the model runs correctly
    # deepspeed='your_json_path',     # train with DeepSpeed; needs the path to ds_config.json
    #+ when using DeepSpeed with the Trainer, make sure ds_config.json matches the Trainer settings (learning rate, batch size, gradient accumulation steps, ...)
    #+ otherwise very strange and hard-to-find bugs appear

    dataloader_pin_memory=True,       # speeds up moving data between CPU and GPU
    dataloader_num_workers=16,        # data loading is single-process by default; usually set to 4 * number of GPUs
    dataloader_drop_last=True,        # drop the last minibatch to keep the training gradients stable

    evaluation_strategy="steps",      # evaluation strategy, either "steps" or "epoch"
    eval_steps=500,                   # if evaluation_strategy="steps"
    #+ defaults to logging_steps; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)

    save_strategy="steps",            # checkpoint saving strategy
    save_steps=500,                   # checkpoint interval in steps; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)
    save_total_limit=5,               # maximum number of saved checkpoints; the oldest one is deleted once the limit is exceeded

    load_best_model_at_end=True,      # whether to load the best model at the end of training
    #+ when True, the checkpoint with the best evaluation result is kept
    #+ when True, evaluation_strategy must match save_strategy and save_steps must be a multiple of eval_steps
    metric_for_best_model="eval_loss",  # metric used to pick the best model (must be used together with load_best_model_at_end)
    #+ may be any value from the dict returned by compute_metrics
    #+ note: the Trainer prefixes the keys of that dict, by default with "eval_"
    greater_is_better=False,          # smaller metric values are better (must be used together with metric_for_best_model)

    do_train=True,                    # whether to train, usually for debugging
    do_eval=True,                     # whether to evaluate, usually for debugging

    remove_unused_columns=False,      # whether to drop unused columns (features), defaults to True
    #+ dropping unused columns makes it easier to unpack inputs into the model's call function
    #+ note: remove_unused_columns matches the dataset's column_names against the parameter names of the model's forward method and drops any feature whose name does not appear there
    #+ so if a feature is renamed inside dataset.with_transform(..), the remove step deletes the original data first, leaving an empty dataset and breaking any later slicing
    #+ for example, the loaded dataset's image feature is called "images" while the forward parameter is called "pixel_values";
    #+ if the with_transform callback builds the other forward arguments from "images" and only then renames it to "pixel_values", the whole pipeline breaks,
    #+ because with remove_unused_columns=True the column check runs first and the "images" feature is deleted outright (so the transform_fn never sees it)
    #+ good practice: rename such features up front with dataset.rename_column

    push_to_hub=False,                # whether to push to the Hub after training; first run huggingface-cli login to set up the credentials, which are then stored in the cache folder
)


def main():
    # dataset = load_dataset(
    #     '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
    #     'cleaned_formulas'
    # )['train'].select(range(500))
    dataset = load_dataset(
        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
        'cleaned_formulas'
    )['train']

    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
    # tokenized_formula = tokenized_formula.to_dict()
    # tokenized_formula['pixel_values'] = dataset['image']
    # tokenized_dataset = dataset.from_dict(tokenized_formula)
    tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
    split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']

    model = TexTeller()

    trainer = Trainer(
        model,
        training_args,

        train_dataset=train_dataset,
        eval_dataset=eval_dataset,

        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=None)


    """
    Another example of a metric function:

    # Setup evaluation
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)
    """
    pause = 1
    model.generate()


if __name__ == '__main__':
    cur_path = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    main()

    os.chdir(cur_path)
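The remove_unused_columns comment above warns that a column renamed inside with_transform can be dropped before the transform ever sees it. A minimal sketch of the recommended ordering, with purely illustrative column names and values:

from datasets import Dataset

# Toy dataset with the "wrong" column name; contents are illustrative.
ds = Dataset.from_dict({"images": [[0.0, 1.0]], "latex_formula": ["x^2"]})

# Rename up front, before any with_transform, so a Trainer running with
# remove_unused_columns=True keeps the feature (it now matches the forward argument name).
ds = ds.rename_column("images", "pixel_values")

ds = ds.with_transform(lambda batch: batch)  # stand-in for the real image preprocessing
print(ds[0].keys())                          # dict_keys(['pixel_values', 'latex_formula'])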
77 src/models/ocr_model/utils/preprocess.py Normal file
@@ -0,0 +1,77 @@
import torch
from datasets import load_dataset

from functools import partial
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from ...ocr_model.model.TexTeller import TexTeller
from .transforms import train_transform


def left_move(x: torch.Tensor, pad_val):
    assert len(x.shape) == 2, 'x should be 2-dimensional'
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
    tokenized_formula['pixel_values'] = samples['image']
    return tokenized_formula


def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    pixel_values = [dic.pop('pixel_values') for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
    batch['decoder_attention_mask'] = batch.pop('attention_mask')

    # shift labels and decoder_attention_mask one position to the left
    batch['labels'] = left_move(batch['labels'], -100)
    batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)

    # turn the list of images into a single tensor of shape (B, C, H, W)
    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
    return batch


def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = train_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples


if __name__ == '__main__':
    dataset = load_dataset(
        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
        'cleaned_formulas'
    )['train'].select(range(20))
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    tokenized_formula = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names)
    tokenized_formula = tokenized_formula.to_dict()
    # tokenized_formula['pixel_values'] = dataset['image']
    # tokenized_formula = dataset.from_dict(tokenized_formula)
    tokenized_dataset = tokenized_formula.with_transform(img_preprocess)

    dataset_dict = tokenized_dataset[:]
    dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())]
    batch = collate_fn_with_tokenizer(dataset_list)

    from ..model.TexTeller import TexTeller
    model = TexTeller()
    out = model(**batch)

    pause = 1
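left_move shifts every row one position to the left and fills the freed last column with a padding value; this is how labels and decoder_attention_mask are realigned for next-token prediction. A small standalone check of that behaviour:

import torch

def left_move(x: torch.Tensor, pad_val):
    # same logic as above, repeated here so the snippet runs on its own
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x

labels = torch.tensor([[0, 11, 12, 2]])   # <s>, tok, tok, </s> (ids are illustrative)
print(left_move(labels, -100))            # tensor([[  11,   12,    2, -100]])

mask = torch.tensor([[1, 1, 1, 1]])
print(left_move(mask, 0))                 # tensor([[1, 1, 1, 0]])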
67 src/models/ocr_model/utils/transforms.py Normal file
@@ -0,0 +1,67 @@
import torch

from torchvision.transforms import v2
from PIL import ImageChops, Image
from typing import List

from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD


def trim_white_border(image: Image.Image):
    if image.mode == 'RGB':
        bg_color = (255, 255, 255)
    elif image.mode == 'L':
        bg_color = 255
    else:
        raise ValueError("Only support RGB or L mode")
    # create a white background the same size as the image
    bg = Image.new(image.mode, image.size, bg_color)
    # compute the difference between the image and the background; border regions matching the background come out black
    diff = ImageChops.difference(image, bg)
    # boost the contrast of the difference image so non-background regions stand out; this helps when finding the bounding box, but the parameters may need tuning per image
    diff = ImageChops.add(diff, diff, 2.0, -100)
    # get the bounding box of the non-black region of the difference image; if one is found, crop the original image to it
    bbox = diff.getbbox()
    return image.crop(bbox) if bbox else image


def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
    images = [trim_white_border(image) for image in images]
    transforms = v2.Compose([
        v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
        #+ returns a list of torchvision Images whose length is the batch size
        #+ so the output at the end of the whole Compose pipeline is also a list of torchvision Images
        #+ note: it does not return a single batched torchvision Image; the batch dimension stays outside
        v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
        v2.Grayscale(),  # convert to grayscale (depends on the task)

        v2.Resize(  # resize onto a fixed square
            size=OCR_IMG_SIZE - 1,  # size must be smaller than max_size
            interpolation=v2.InterpolationMode.BICUBIC,
            max_size=OCR_IMG_SIZE,
            antialias=True
        ),

        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),

        # v2.ToPILImage()  # for checking that the transformed result looks right (debug only)
    ])

    images = transforms(images)  # imgs: List[PIL.Image.Image]
    images = [
        v2.functional.pad(
            img,
            padding=[0, 0, OCR_IMG_SIZE - img.shape[2], OCR_IMG_SIZE - img.shape[1]]
        )
        for img in images
    ]
    return images


def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
    return train_transform(images)
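The Resize above keeps the aspect ratio while capping the longer side at OCR_IMG_SIZE, and the pad that follows fills the right and bottom edges so every tensor comes out OCR_IMG_SIZE x OCR_IMG_SIZE. A standalone shape check of that combination, assuming a torchvision version that ships the transforms.v2 API:

import torch
from torchvision.transforms import v2

OCR_IMG_SIZE = 448  # mirrors the value in globals.py

img = torch.randint(0, 256, (1, 100, 300), dtype=torch.uint8)  # short, wide grayscale dummy image

resize = v2.Resize(size=OCR_IMG_SIZE - 1, max_size=OCR_IMG_SIZE, antialias=True)
resized = resize(img)
print(resized.shape)  # torch.Size([1, 149, 448]): the longer side is capped at max_size, ratio kept

padded = v2.functional.pad(
    resized,
    padding=[0, 0, OCR_IMG_SIZE - resized.shape[2], OCR_IMG_SIZE - resized.shape[1]],
)
print(padded.shape)   # torch.Size([1, 448, 448])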
@@ -6,6 +6,7 @@ from ....globals import (
     IMAGE_MEAN, IMAGE_STD,
     LABEL_RATIO,
     RESIZER_IMG_SIZE,
+    NUM_CHANNELS
 )

 from typing import (
@@ -37,6 +38,7 @@ def preprocess_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
     imgs = [trim_white_border(img) for img in imgs]
     labels = [float(img.height * LABEL_RATIO) for img in imgs]

+    assert NUM_CHANNELS == 1, "Only support grayscale images"
     transform = v2.Compose([
         v2.ToImage(),
         v2.ToDtype(torch.uint8, scale=True),
9740 src/models/tokenizer/roberta-tokenizer-550Kformulas/merges.txt Normal file
File diff suppressed because it is too large

@@ -0,0 +1,15 @@
{
  "bos_token": "<s>",
  "cls_token": "<s>",
  "eos_token": "</s>",
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "unk_token": "<unk>"
}

19830 src/models/tokenizer/roberta-tokenizer-550Kformulas/tokenizer.json Normal file
File diff suppressed because it is too large

@@ -0,0 +1,57 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": true,
  "cls_token": "<s>",
  "eos_token": "</s>",
  "errors": "replace",
  "mask_token": "<mask>",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "tokenizer_class": "RobertaTokenizer",
  "trim_offsets": true,
  "unk_token": "<unk>"
}
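A minimal sketch of loading this retrained tokenizer directory and inspecting its special tokens, assuming the relative path of the directory added in this commit:

from transformers import RobertaTokenizerFast

# Path mirrors the directory added in this commit; adjust to your checkout.
tok = RobertaTokenizerFast.from_pretrained("src/models/tokenizer/roberta-tokenizer-550Kformulas")

print(tok.vocab_size)                               # expected to be roughly VOCAB_SIZE = 10000
print(tok.bos_token, tok.eos_token, tok.pad_token)  # <s> </s> <pad>
print(tok("\\frac{a}{b}").input_ids)                # ids wrapped with <s> ... </s>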
File diff suppressed because one or more lines are too long

21 src/models/tokenizer/roberta-tokenizer-raw/config.json Normal file
@@ -0,0 +1,21 @@
{
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "type_vocab_size": 1,
  "vocab_size": 50265
}

50001 src/models/tokenizer/roberta-tokenizer-raw/merges.txt Normal file
File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

1 src/models/tokenizer/roberta-tokenizer-raw/vocab.json Normal file
File diff suppressed because one or more lines are too long
29 src/models/tokenizer/test_long_formulas.txt Normal file
@@ -0,0 +1,29 @@
\begin{aligned}
&\begin{aligned}(\tau\lambda)\psi(a)(\lambda^{-1}\tau)(X,Y,\xi,\eta)=(\tau\lambda)\psi(a)(-\tau Y,\tau X,-\tau\eta,\tau\xi)\end{aligned} \\
&=(\tau\lambda)\bigg(\begin{pmatrix}-a\tau\eta_1&-\tau y_3&-\tau\overline{y}_2\\-\tau\overline{y}_3&-a^{-1}\tau\eta_2&-a^{-1}\tau y_1\\-\tau y_2&-a^{-1}\tau\overline{y}_1&-a^{-1}\tau\eta_3\end{pmatrix},\begin{pmatrix}a^{-1}\tau\xi_1&\tau x_3&\tau\overline{x}_2\\\tau\overline{x}_3&a\tau\xi_2&a\tau x_1\\\tau x_2&a\tau\overline{x}_1&a\tau\xi_3\end{pmatrix},-a\tau\eta,a^{-1}\tau\xi\bigg) \\
&\left.=\left(\begin{pmatrix}\tau a^{-1}\xi_1&x_3&\overline{x}_2\\\overline{x}_3&\tau a\xi_2&\tau ax_1\\x_2&\tau a\overline{x}_1&\tau a\xi_3\end{pmatrix}\right.,\begin{pmatrix}\tau a\eta_1&y_3&\overline{y}_2\\\overline{y}_3&\tau a^{-1}\eta_2&\tau a^{-1}y_1\\y_2&\tau a^{-1}\overline{y}_1&\tau a^{-1}\eta_3\end{pmatrix},\tau a^{-1}\xi,\tau a\eta\right) \\
&=\psi(\tau a^{-1}).
\end{aligned}

\begin{aligned}
&\begin{aligned}-L_{X_{13}}&=\left(\frac{1}{2}\sin\alpha\cos\beta\sin2\gamma+\cos\alpha\tan\beta\sin^2\gamma-\frac{1}{2}\sin\alpha\sin\beta\tan\beta\sin2\gamma\right)\frac{\partial}{\partial\alpha}\end{aligned} \\
&\begin{aligned}+\left(\frac12\cos\alpha\sin\beta\sin2\gamma-\sin\alpha\sin^2\beta\cos^2\gamma-\sin\alpha\cos^2\beta\sin^2\gamma\right)\frac\partial{\partial\beta}\end{aligned} \\
&\begin{aligned}+\left(\frac14\sin\alpha\sin2\beta\sin2\gamma-\frac12\sin\alpha\tan\beta\sin2\gamma+\cos\alpha\sec\beta\sin^2\gamma\right)\frac{\partial}{\partial\gamma}\end{aligned} \\
&+\left(\left(\frac12\sin\alpha\sin2\beta\cos^2\gamma+\frac12\sin\alpha\sin2\beta-\frac12\cos\alpha\cos\beta\sin2\gamma\right)z_{12}\right. \\
&+(\sin\alpha\cos2\beta\cos\gamma+\cos\alpha\sin\beta\sin\gamma)\biggr)\frac{\partial}{\partial z_{12}} \\
&+\left(\left(\frac12\sin\alpha\sin2\beta\cos2\gamma-\cos\alpha\cos\beta\sin2\gamma\right)z_{13}+(\sin\alpha\cos2\beta\cos\gamma\right. \\
&\left.\left.+\cos\alpha\sin\beta\sin\gamma\right)z_{23}+\left(\frac12\sin\alpha\sin2\beta\sin2\gamma+\cos\alpha\cos\beta\cos2\gamma\right)\right)\frac{\partial}{\partial z_{13}} \\
&+\left(\left(-\frac12\sin\alpha\sin2\beta-\frac12\sin\alpha\sin2\beta\sin^2\gamma-\frac12\cos\alpha\cos\beta\sin2\gamma\right)z_{23}\right. \\
&+(\sin\alpha\cos2\beta\sin\gamma-\cos\alpha\sin\beta\cos\gamma)\Bigg)\frac{\partial}{\partial z_{23}}.
\end{aligned}

\begin{aligned}
&\sum_S(-1)^{|S|}\frac{1-\prod_{i\notin S}\left(\frac{X_i(1+X_i)}{Q+X_i}\right)^{m+1}}{1-\prod_{i\notin S}\frac{X_i(1+X_i)}{Q+X_i}}\prod_iX_i \\
&\times\prod_{i\in S}X_{i}^{m+n-1}(1+X_{i})^{m+1}(Q+X_{i})^{-m}(X_{i}+r+Q)^{n-1} \\
&\times\prod_{i\notin S}(1+X_i)(Q+rX_i+QX_i)^{n-1} \\
&&\times\prod_{1\leq i<j\leq n,\{i,j\}\cap S\neq\emptyset}\left(\frac{Y_j(1+Y_j)}{Q+rY_j+QY_j}-\frac{Y_i(1+Y_i)}{Q+rY_i+QY_i}\right) \\
&&&\times\sum_{k\notin S}(Q-X_{k}^{2})X_{k}^{-1}(1+X_{k})^{-1} \\
&&&\times\prod_{\overset{1\leq i\leq k-1}{i\notin S}}\frac{(Q+(Q+r)X_k+X_i+X_iX_k)(X_iX_k-Q)}{(Q+rX_k+QX_k)(Q+rX_i+QX_i)} \\
&&&\times\prod_{\overset{k+1\leq i\leq n}{i\notin S}}\frac{(Q+(Q+r)X_k+X_i+X_iX_k)(Q-X_iX_k)}{(Q+rX_k+QX_k)(Q+rX_i+QX_i)} \\
&&&&\times\prod_{1\leq i<j\leq n,i,j\notin S\cup\{k\}}\left(\frac{X_j(1+X_j)}{Q+rX_j+QX_j}-\frac{X_i(1+X_i)}{Q+rX_i+QX_i}\right).
\end{aligned}
11 src/models/tokenizer/train/train.py Normal file
@@ -0,0 +1,11 @@
from datasets import load_dataset
from ...ocr_model.model.TexTeller import TexTeller
from ....globals import VOCAB_SIZE


if __name__ == '__main__':
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-raw')
    dataset = load_dataset("/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py", "cleaned_formulas")['train']
    new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=dataset['latex_formula'], vocab_size=VOCAB_SIZE)
    new_tokenizer.save_pretrained('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
    pause = 1
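train_new_from_iterator keeps RoBERTa's BPE algorithm but relearns the merges on the formula corpus, so formulas should split into far fewer pieces than with the raw tokenizer. A quick comparison sketch, assuming both tokenizer directories from this commit are available locally:

from transformers import RobertaTokenizerFast

# Paths mirror the directories added in this commit; adjust to your checkout.
raw = RobertaTokenizerFast.from_pretrained("src/models/tokenizer/roberta-tokenizer-raw")
new = RobertaTokenizerFast.from_pretrained("src/models/tokenizer/roberta-tokenizer-550Kformulas")

formula = r"\frac{1}{2}\sin\alpha\cos\beta\sin2\gamma"
print(len(raw.tokenize(formula)))  # many sub-word pieces: LaTeX commands get split apart
print(len(new.tokenize(formula)))  # noticeably fewer pieces after retraining on formulas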