完成了web,ray server,重构了代码

This commit is contained in:
三洋三洋
2024-02-08 13:48:34 +00:00
parent 07c4c3dc01
commit 04b99b8451
20 changed files with 245 additions and 57 deletions

59
src/models/globals.py Normal file
View File

@@ -0,0 +1,59 @@
# Mean and std of the (grayscale) formula images, used for input normalization.
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445
# ========================= OCR model parameters ============================= #
# Min/max width and height allowed for an input image.
# NOTE(review): MIN_HEIGHT and MIN_WIDTH are re-assigned further down in this
# module (to 12 and 30), so the 32/32 values here are dead — importers of this
# module see the later values. Confirm which pair is intended.
MIN_HEIGHT = 32
MAX_HEIGHT = 512
MIN_WIDTH = 32
MAX_WIDTH = 1280
# For comparison, LaTeX-OCR uses 32, 192, 32 and 672 respectively.
# Density (dpi) used when converting PDFs to images for the OCR dataset.
TEXIFY_INPUT_DENSITY = 100
# Vocabulary size of the OCR model's tokenizer.
VOCAB_SIZE = 10000
# Whether the OCR model uses a fixed input image size.
OCR_FIX_SIZE = True
# Fixed input image size used during OCR training (when OCR_FIX_SIZE is True).
OCR_IMG_SIZE = 448
# Max input width/height during OCR training (when OCR_FIX_SIZE is False).
OCR_IMG_MAX_HEIGHT = 512
OCR_IMG_MAX_WIDTH = 768
# Number of channels of the OCR model's input images.
OCR_IMG_CHANNELS = 1  # grayscale
# Longest token sequence in the OCR training dataset.
MAX_TOKEN_SIZE = 512  # the model's max embedding length is 512, so this must be 512
# MAX_TOKEN_SIZE = 600
# Random resize ratio range used for augmentation during OCR training.
MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75
# Minimum width/height required of OCR input images (filters out garbage data).
# NOTE(review): these silently overwrite the MIN_HEIGHT/MIN_WIDTH defined above.
MIN_HEIGHT = 12
MIN_WIDTH = 30
# ============================================================================= #
# ========================= Resizer model parameters ========================= #
# Density (dpi) used to render the images in the Resizer model's dataset.
RESIZER_INPUT_DENSITY = 200
LABEL_RATIO = 1.0 * TEXIFY_INPUT_DENSITY / RESIZER_INPUT_DENSITY
NUM_CLASSES = 1  # the model predicts via regression
NUM_CHANNELS = 1  # single-channel (grayscale) input images
# Fixed image size used during Resizer training.
RESIZER_IMG_SIZE = 448
# ============================================================================= #

View File

@@ -1,34 +0,0 @@
import torch
import cv2
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List
from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from .utils.helpers import convert2rgb
from ...globals import MAX_TOKEN_SIZE
def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
    """Run the TexTeller OCR model on a batch of formula images.

    Args:
        model: TexTeller vision-encoder-decoder model used for generation.
        imgs_path: Paths of the images to recognize.
        tokenizer: Tokenizer used to decode the generated token ids.

    Returns:
        One decoded LaTeX string per input image, in input order.
    """
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    # Stack per-image tensors into a single (batch, C, H, W) tensor.
    # NOTE(review): torch.stack requires all transformed images to share one
    # shape — presumably guaranteed by inference_transform; confirm.
    pixel_values = torch.stack(imgs)
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,
        do_sample=False,  # deterministic beam search, no sampling
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res


if __name__ == '__main__':
    # Bug fix: the original called `inference()` with no arguments here, which
    # can only raise TypeError. Fail fast with an actionable message instead.
    raise SystemExit(
        "inference() requires a model, a list of image paths and a tokenizer; "
        "import this module and call inference(model, imgs_path, tokenizer)."
    )

View File

@@ -1,6 +1,7 @@
from PIL import Image
from pathlib import Path
from ....globals import (
from models.globals import (
VOCAB_SIZE,
OCR_IMG_SIZE,
OCR_IMG_CHANNELS,
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
@classmethod
def from_pretrained(cls, model_path: str):
return VisionEncoderDecoderModel.from_pretrained(model_path)
model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod
def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
return RobertaTokenizerFast.from_pretrained(tokenizer_path)
tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
if __name__ == "__main__":
# texteller = TexTeller()
from ..inference import inference
from ..utils.inference import inference
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

View File

@@ -11,7 +11,7 @@ from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.metrics import bleu_metric
from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -82,10 +82,10 @@ if __name__ == '__main__':
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
enable_train = True
enable_evaluate = False
enable_train = False
enable_evaluate = True
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate:

View File

@@ -4,12 +4,14 @@ from typing import List
from PIL import Image
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
# 输出的np.ndarray的格式为[H, W, C]通道在第三维通道的排列顺序为RGB
processed_images = []
for path in image_paths:
# 读取图片
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
# 如果是 BGR (3通道), 转换为 RGB
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(Image.fromarray(image))
processed_images.append(image)
return processed_images

View File

@@ -0,0 +1,39 @@
import torch
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE
def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    """Recognize the LaTeX source of a batch of formula images.

    Args:
        model: TexTeller vision-encoder-decoder model used for generation.
        tokenizer: Tokenizer used to decode the generated token ids.
        imgs_path: Paths of the images to recognize.
        use_cuda: If True, move both model and inputs to the 'cuda' device.
        num_beams: Beam width for beam-search decoding (1 = greedy).

    Returns:
        One decoded LaTeX string per input image, in input order.
    """
    # Inference mode: disables dropout and other train-time behavior.
    model.eval()

    # Load, transform and batch the images into one (batch, C, H, W) tensor.
    transformed = inference_transform(convert2rgb(imgs_path))
    pixel_values = torch.stack(transformed)

    if use_cuda:
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')

    decoding_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,  # deterministic decoding, no sampling
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    token_ids = model.generate(pixel_values, generation_config=decoding_config)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

View File

@@ -7,7 +7,7 @@ from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from ....globals import (
from ...globals import (
OCR_IMG_CHANNELS,
OCR_IMG_SIZE,
OCR_FIX_SIZE,

View File

@@ -10,7 +10,7 @@ from transformers import (
from ..utils import preprocess_fn
from ..model.Resizer import Resizer
from ....globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
from ...globals import NUM_CHANNELS, NUM_CLASSES, RESIZER_IMG_SIZE
def train():

View File

@@ -2,7 +2,7 @@ import torch
from torchvision.transforms import v2
from PIL import Image, ImageChops
from ....globals import (
from ...globals import (
IMAGE_MEAN, IMAGE_STD,
LABEL_RATIO,
RESIZER_IMG_SIZE,

View File

@@ -1,6 +1,6 @@
from datasets import load_dataset
from ...ocr_model.model.TexTeller import TexTeller
from ....globals import VOCAB_SIZE
from ...globals import VOCAB_SIZE
if __name__ == '__main__':