完成了web,ray server,重构了代码
This commit is contained in:
59
src/models/globals.py
Normal file
59
src/models/globals.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# 公式图片(灰度化后)的均值和方差
|
||||
IMAGE_MEAN = 0.9545467
|
||||
IMAGE_STD = 0.15394445
|
||||
|
||||
|
||||
# ========================= ocr模型用的参数 ============================= #
|
||||
|
||||
# 输入图片的最大最小的宽和高
|
||||
MIN_HEIGHT = 32
|
||||
MAX_HEIGHT = 512
|
||||
MIN_WIDTH = 32
|
||||
MAX_WIDTH = 1280
|
||||
# LaTex-OCR中分别是 32、192、32、672
|
||||
|
||||
# ocr模型所用数据集,pdf转图片所用的Density值(dpi)
|
||||
TEXIFY_INPUT_DENSITY = 100
|
||||
|
||||
# ocr模型的tokenizer中的词典数量
|
||||
VOCAB_SIZE = 10000
|
||||
|
||||
# ocr模型是否固定输入图片的大小
|
||||
OCR_FIX_SIZE = True
|
||||
# ocr模型训练时,输入图片所固定的大小 (when OCR_FIX_SIZE is True)
|
||||
OCR_IMG_SIZE = 448
|
||||
# ocr模型训练时,输入图片最大的宽和高(when OCR_FIX_SIZE is False)
|
||||
OCR_IMG_MAX_HEIGHT = 512
|
||||
OCR_IMG_MAX_WIDTH = 768
|
||||
|
||||
# ocr模型输入图片的通道数
|
||||
OCR_IMG_CHANNELS = 1 # 灰度图
|
||||
|
||||
# ocr模型训练数据集的最长token数
|
||||
MAX_TOKEN_SIZE = 512 # 模型最长的embedding长度被设置成了512,所以这里必须是512
|
||||
# MAX_TOKEN_SIZE = 600
|
||||
|
||||
# ocr模型训练时随机缩放的比例
|
||||
MAX_RESIZE_RATIO = 1.15
|
||||
MIN_RESIZE_RATIO = 0.75
|
||||
|
||||
# ocr模型输入的图片要求的最低宽和高(过滤垃圾数据)
|
||||
MIN_HEIGHT = 12
|
||||
MIN_WIDTH = 30
|
||||
|
||||
# ============================================================================= #
|
||||
|
||||
|
||||
# ========================= Resizer模型用的参数 ============================= #
|
||||
|
||||
# Resizer模型所用数据集中,图片所用的Density渲染值
|
||||
RESIZER_INPUT_DENSITY = 200
|
||||
|
||||
LABEL_RATIO = 1.0 * TEXIFY_INPUT_DENSITY / RESIZER_INPUT_DENSITY
|
||||
|
||||
NUM_CLASSES = 1 # 模型使用回归预测
|
||||
NUM_CHANNELS = 1 # 输入单通道图片(灰度图)
|
||||
|
||||
# Resizer在训练时,图片所固定的的大小
|
||||
RESIZER_IMG_SIZE = 448
|
||||
# ============================================================================= #
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
|
||||
from .model.TexTeller import TexTeller
|
||||
from .utils.transforms import inference_transform
|
||||
from .utils.helpers import convert2rgb
|
||||
from ...globals import MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
|
||||
imgs = convert2rgb(imgs_path)
|
||||
imgs = inference_transform(imgs)
|
||||
pixel_values = torch.stack(imgs)
|
||||
|
||||
generate_config = GenerationConfig(
|
||||
max_new_tokens=MAX_TOKEN_SIZE,
|
||||
num_beams=3,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
)
|
||||
pred = model.generate(pixel_values, generation_config=generate_config)
|
||||
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
inference()
|
||||
@@ -1,6 +1,7 @@
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
from ....globals import (
|
||||
from models.globals import (
|
||||
VOCAB_SIZE,
|
||||
OCR_IMG_SIZE,
|
||||
OCR_IMG_CHANNELS,
|
||||
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
return VisionEncoderDecoderModel.from_pretrained(model_path)
|
||||
model_path = Path(model_path).resolve()
|
||||
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
|
||||
return RobertaTokenizerFast.from_pretrained(tokenizer_path)
|
||||
tokenizer_path = Path(tokenizer_path).resolve()
|
||||
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# texteller = TexTeller()
|
||||
from ..inference import inference
|
||||
from ..utils.inference import inference
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
|
||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||
@@ -82,10 +82,10 @@ if __name__ == '__main__':
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
|
||||
|
||||
enable_train = True
|
||||
enable_evaluate = False
|
||||
enable_train = False
|
||||
enable_evaluate = True
|
||||
if enable_train:
|
||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||
if enable_evaluate:
|
||||
|
||||
@@ -4,12 +4,14 @@ from typing import List
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
|
||||
# 输出的np.ndarray的格式为:[H, W, C](通道在第三维),通道的排列顺序为RGB
|
||||
processed_images = []
|
||||
|
||||
for path in image_paths:
|
||||
# 读取图片
|
||||
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if image is None:
|
||||
print(f"Image at {path} could not be read.")
|
||||
continue
|
||||
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||
# 如果是 BGR (3通道), 转换为 RGB
|
||||
elif channels == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
processed_images.append(Image.fromarray(image))
|
||||
processed_images.append(image)
|
||||
|
||||
return processed_images
|
||||
39
src/models/ocr_model/utils/inference.py
Normal file
39
src/models/ocr_model/utils/inference.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||
from typing import List
|
||||
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from models.ocr_model.utils.transforms import inference_transform
|
||||
from models.ocr_model.utils.helpers import convert2rgb
|
||||
from models.globals import MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def inference(
|
||||
model: TexTeller,
|
||||
tokenizer: RobertaTokenizerFast,
|
||||
imgs_path: List[str],
|
||||
use_cuda: bool,
|
||||
num_beams: int = 1,
|
||||
) -> List[str]:
|
||||
model.eval()
|
||||
imgs = convert2rgb(imgs_path)
|
||||
imgs = inference_transform(imgs)
|
||||
pixel_values = torch.stack(imgs)
|
||||
|
||||
if use_cuda:
|
||||
model = model.to('cuda')
|
||||
pixel_values = pixel_values.to('cuda')
|
||||
|
||||
|
||||
generate_config = GenerationConfig(
|
||||
max_new_tokens=MAX_TOKEN_SIZE,
|
||||
num_beams=num_beams,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
)
|
||||
pred = model.generate(pixel_values, generation_config=generate_config)
|
||||
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
|
||||
return res
|
||||
@@ -7,7 +7,7 @@ from torchvision.transforms import v2
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
|
||||
from ....globals import (
|
||||
from ...globals import (
|
||||
OCR_IMG_CHANNELS,
|
||||
OCR_IMG_SIZE,
|
||||
OCR_FIX_SIZE,
|
||||
|
||||
@@ -10,7 +10,7 @@ from transformers import (
|
||||
|
||||
from ..utils import preprocess_fn
|
||||
from ..model.Resizer import Resizer
|
||||
from ....globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
|
||||
from ...globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
|
||||
|
||||
|
||||
def train():
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from PIL import Image, ImageChops
|
||||
from ....globals import (
|
||||
from ...globals import (
|
||||
IMAGE_MEAN, IMAGE_STD,
|
||||
LABEL_RATIO,
|
||||
RESIZER_IMG_SIZE,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datasets import load_dataset
|
||||
from ...ocr_model.model.TexTeller import TexTeller
|
||||
from ....globals import VOCAB_SIZE
|
||||
from ...globals import VOCAB_SIZE
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user