完成了web,ray server,重构了代码

This commit is contained in:
三洋三洋
2024-02-08 13:48:34 +00:00
parent 07c4c3dc01
commit 04b99b8451
20 changed files with 245 additions and 57 deletions

View File

@@ -1,34 +0,0 @@
import torch
import cv2
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List
from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from .utils.helpers import convert2rgb
from ...globals import MAX_TOKEN_SIZE
def inference(model: "TexTeller", imgs_path: List[str], tokenizer: "RobertaTokenizerFast") -> List[str]:
    """Run formula-OCR inference on a batch of image files.

    Args:
        model: TexTeller vision-encoder-decoder model used for generation.
        imgs_path: filesystem paths of the images to recognise.
        tokenizer: tokenizer that supplies the special-token ids for
            generation and decodes the predicted token ids.

    Returns:
        One decoded string per input image; an empty list for empty input.
    """
    # Guard the empty batch explicitly: torch.stack([]) raises RuntimeError.
    if not imgs_path:
        return []
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)
    # Deterministic beam search (3 beams), capped at MAX_TOKEN_SIZE new tokens.
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    # skip_special_tokens drops the pad/eos/bos markers from the output text.
    return tokenizer.batch_decode(pred, skip_special_tokens=True)
if __name__ == '__main__':
    # BUG FIX: the original called `inference()` with no arguments, which can
    # only raise TypeError because model, imgs_path and tokenizer are all
    # required and have no defaults. There is no sensible default model or
    # tokenizer to construct here, so fail with a clear message instead of an
    # opaque traceback.
    raise SystemExit(
        'inference() requires a model, a list of image paths and a tokenizer; '
        'import this module and call inference(...) from code instead.'
    )

View File

@@ -1,6 +1,7 @@
from PIL import Image
from pathlib import Path
from ....globals import (
from models.globals import (
VOCAB_SIZE,
OCR_IMG_SIZE,
OCR_IMG_CHANNELS,
@@ -29,16 +30,18 @@ class TexTeller(VisionEncoderDecoderModel):
@classmethod
def from_pretrained(cls, model_path: str):
return VisionEncoderDecoderModel.from_pretrained(model_path)
model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod
def get_tokenizer(cls, tokenizer_path: str) -> RobertaTokenizerFast:
return RobertaTokenizerFast.from_pretrained(tokenizer_path)
tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
if __name__ == "__main__":
# texteller = TexTeller()
from ..inference import inference
from ..utils.inference import inference
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

View File

@@ -11,7 +11,7 @@ from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.metrics import bleu_metric
from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -82,10 +82,10 @@ if __name__ == '__main__':
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
enable_train = True
enable_evaluate = False
enable_train = False
enable_evaluate = True
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate:

View File

@@ -4,12 +4,14 @@ from typing import List
from PIL import Image
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
# Each returned np.ndarray is laid out [H, W, C]; channels occupy the third dimension and are ordered RGB.
processed_images = []
for path in image_paths:
# Read the image as stored on disk (IMREAD_UNCHANGED preserves the original channel count and depth).
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
# If BGR (3 channels), convert to RGB.
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(Image.fromarray(image))
processed_images.append(image)
return processed_images

View File

@@ -0,0 +1,39 @@
import torch
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs_path: List[str],
use_cuda: bool,
num_beams: int = 1,
) -> List[str]:
model.eval()
imgs = convert2rgb(imgs_path)
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if use_cuda:
model = model.to('cuda')
pixel_values = pixel_values.to('cuda')
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
pred = model.generate(pixel_values, generation_config=generate_config)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
return res

View File

@@ -7,7 +7,7 @@ from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from ....globals import (
from ...globals import (
OCR_IMG_CHANNELS,
OCR_IMG_SIZE,
OCR_FIX_SIZE,