1) 加入了推理代码; 2) 整理了其他代码

This commit is contained in:
三洋三洋
2024-01-28 14:03:42 +00:00
parent c6d5c91955
commit 14125da26f
4 changed files with 52 additions and 18 deletions

View File

@@ -0,0 +1,32 @@
import torch
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List
from .model.TexTeller import TexTeller
from .utils.transforms import inference_transform
from ...globals import MAX_TOKEN_SIZE
def png2jpg(imgs: List[Image.Image]):
imgs = [img.convert('RGB') for img in imgs if img.mode in ("RGBA", "P")]
return imgs
def inference(model: TexTeller, imgs: List[Image.Image], tokenizer: RobertaTokenizerFast) -> List[str]:
imgs = png2jpg(imgs) if imgs[0].mode in ('RGBA' ,'P') else imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=3,
do_sample=False
)
pred = model.generate(pixel_values, generation_config=generate_config)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
return res
if __name__ == '__main__':
inference()

View File

@@ -1,19 +1,18 @@
from PIL import Image
from ....globals import ( from ....globals import (
VOCAB_SIZE, VOCAB_SIZE,
OCR_IMG_SIZE, OCR_IMG_SIZE,
OCR_IMG_CHANNELS OCR_IMG_CHANNELS,
) )
from transformers import ( from transformers import (
ViTConfig, ViTConfig,
ViTModel, ViTModel,
TrOCRConfig, TrOCRConfig,
TrOCRForCausalLM, TrOCRForCausalLM,
RobertaTokenizerFast, RobertaTokenizerFast,
VisionEncoderDecoderModel,
VisionEncoderDecoderModel
) )
@@ -38,9 +37,18 @@ class TexTeller(VisionEncoderDecoderModel):
if __name__ == "__main__": if __name__ == "__main__":
texteller = TexTeller() # texteller = TexTeller()
tokenizer = texteller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') from ..inference import inference
foo = ["Hello, my name is LHY.", "I am a researcher at the University of Science and Technology of China."] model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
bar = tokenizer(foo, return_special_tokens_mask=True) tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
img1 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/1.png')
img2 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/2.png')
img3 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/3.png')
img4 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/4.png')
img5 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/5.png')
img6 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/6.png')
res = inference(model, [img1, img2, img3, img4, img5, img6], tokenizer)
pause = 1 pause = 1

View File

@@ -1,6 +1,4 @@
import torch import torch
import datasets
from datasets import load_dataset from datasets import load_dataset
from functools import partial from functools import partial

View File

@@ -1,9 +1,8 @@
import torch import torch
import torchvision
from torchvision.transforms import v2 from torchvision.transforms import v2
from PIL import ImageChops, Image from PIL import ImageChops, Image
from typing import Any, Dict, List from typing import List
from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
@@ -11,12 +10,10 @@ from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN
def trim_white_border(image: Image.Image): def trim_white_border(image: Image.Image):
if image.mode == 'RGB': if image.mode == 'RGB':
bg_color = (255, 255, 255) bg_color = (255, 255, 255)
elif image.mode == 'RGBA':
bg_color = (255, 255, 255, 255)
elif image.mode == 'L': elif image.mode == 'L':
bg_color = 255 bg_color = 255
else: else:
raise ValueError("Unsupported image mode") raise ValueError("Only support RGB or L mode")
# 创建一个与图片一样大小的白色背景 # 创建一个与图片一样大小的白色背景
bg = Image.new(image.mode, image.size, bg_color) bg = Image.new(image.mode, image.size, bg_color)
# 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。 # 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。
@@ -25,8 +22,7 @@ def trim_white_border(image: Image.Image):
diff = ImageChops.add(diff, diff, 2.0, -100) diff = ImageChops.add(diff, diff, 2.0, -100)
# 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。 # 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。
bbox = diff.getbbox() bbox = diff.getbbox()
if bbox: return image.crop(bbox) if bbox else image
return image.crop(bbox)
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: