From 14125da26fa3a07e46e33e5431a31f6a728f1764 Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Sun, 28 Jan 2024 14:03:42 +0000
Subject: [PATCH] 1) Added inference code; 2) Cleaned up other code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/models/ocr_model/inference.py        | 32 ++++++++++++++++++++++++
 src/models/ocr_model/model/TexTeller.py  | 26 ++++++++++++-------
 src/models/ocr_model/utils/preprocess.py |  2 --
 src/models/ocr_model/utils/transforms.py | 10 +++-----
 4 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/src/models/ocr_model/inference.py b/src/models/ocr_model/inference.py
index e69de29..bc747ac 100644
--- a/src/models/ocr_model/inference.py
+++ b/src/models/ocr_model/inference.py
@@ -0,0 +1,32 @@
+import torch
+from transformers import RobertaTokenizerFast, GenerationConfig
+from PIL import Image
+from typing import List
+
+from .model.TexTeller import TexTeller
+from .utils.transforms import inference_transform
+from ...globals import MAX_TOKEN_SIZE
+
+
+def png2jpg(imgs: List[Image.Image]) -> List[Image.Image]:
+    # Convert RGBA/P images to RGB; images in other modes pass through unchanged.
+    return [img.convert('RGB') if img.mode in ('RGBA', 'P') else img for img in imgs]
+
+
+def inference(model: TexTeller, imgs: List[Image.Image], tokenizer: RobertaTokenizerFast) -> List[str]:
+    imgs = png2jpg(imgs)
+    imgs = inference_transform(imgs)
+    pixel_values = torch.stack(imgs)
+
+    generate_config = GenerationConfig(
+        max_new_tokens=MAX_TOKEN_SIZE,
+        num_beams=3,
+        do_sample=False
+    )
+    pred = model.generate(pixel_values, generation_config=generate_config)
+    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
+    return res
+
+
+if __name__ == '__main__':
+    pass  # inference() needs a model, images and a tokenizer; see TexTeller.py's __main__
diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py
index 19d98cd..7f38ec3 100644
--- a/src/models/ocr_model/model/TexTeller.py
+++ b/src/models/ocr_model/model/TexTeller.py
@@ -1,19 +1,18 @@
+from PIL import Image
+
 from ....globals import (
     VOCAB_SIZE,
     OCR_IMG_SIZE,
-    OCR_IMG_CHANNELS
+    OCR_IMG_CHANNELS,
 )
 
 from transformers import (
     ViTConfig,
     ViTModel,
-
     TrOCRConfig,
     TrOCRForCausalLM,
-
     RobertaTokenizerFast,
-
-    VisionEncoderDecoderModel
+    VisionEncoderDecoderModel,
 )
 
 
@@ -38,9 +37,18 @@
 
 
 if __name__ == "__main__":
-    texteller = TexTeller()
-    tokenizer = texteller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
-    foo = ["Hello, my name is LHY.", "I am a researcher at the University of Science and Technology of China."]
-    bar = tokenizer(foo, return_special_tokens_mask=True)
+    # texteller = TexTeller()
+    from ..inference import inference
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
+    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
+
+    img1 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/1.png')
+    img2 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/2.png')
+    img3 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/3.png')
+    img4 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/4.png')
+    img5 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/5.png')
+    img6 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/6.png')
+
+    res = inference(model, [img1, img2, img3, img4, img5, img6], tokenizer)
 
     pause = 1
diff --git a/src/models/ocr_model/utils/preprocess.py b/src/models/ocr_model/utils/preprocess.py
index 03dd6a5..1892e7a 100644
--- a/src/models/ocr_model/utils/preprocess.py
+++ b/src/models/ocr_model/utils/preprocess.py
@@ -1,6 +1,4 @@
 import torch
-import datasets
-
 from datasets import load_dataset
 from functools import partial
 
diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py
index d5a7f92..f7d6578 100644
--- a/src/models/ocr_model/utils/transforms.py
+++ b/src/models/ocr_model/utils/transforms.py
@@ -1,9 +1,8 @@
 import torch
-import torchvision
 from torchvision.transforms import v2
 from PIL import ImageChops, Image
-from typing import Any, Dict, List
+from typing import List
 
 from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
 
 
@@ -11,12 +10,10 @@
 def trim_white_border(image: Image.Image):
     if image.mode == 'RGB':
         bg_color = (255, 255, 255)
-    elif image.mode == 'RGBA':
-        bg_color = (255, 255, 255, 255)
     elif image.mode == 'L':
         bg_color = 255
     else:
-        raise ValueError("Unsupported image mode")
+        raise ValueError("Only RGB and L modes are supported")
     # Create a white background image of the same size as the input
     bg = Image.new(image.mode, image.size, bg_color)
     # Compute the difference between the image and the white background; white border regions come out black in the difference image.
@@ -25,8 +22,7 @@
     diff = ImageChops.add(diff, diff, 2.0, -100)
     # Find the bounding box of the non-black region in the difference image; if one exists, crop the image to it.
     bbox = diff.getbbox()
-    if bbox:
-        return image.crop(bbox)
+    return image.crop(bbox) if bbox else image
 
 
 def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
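
Usage note (not part of the commit): a minimal sketch of how the inference
entry point added above is intended to be called from outside the package.
It assumes the repository root is on PYTHONPATH; the checkpoint, tokenizer,
and image paths are hypothetical placeholders, not paths from this patch.

    from PIL import Image

    from src.models.ocr_model.model.TexTeller import TexTeller
    from src.models.ocr_model.inference import inference

    # Load trained weights and the matching tokenizer (placeholder paths).
    model = TexTeller.from_pretrained('path/to/checkpoint')
    tokenizer = TexTeller.get_tokenizer('path/to/roberta-tokenizer')

    # Batched OCR: inference() returns one decoded LaTeX string per image.
    imgs = [Image.open('formula1.png'), Image.open('formula2.png')]
    preds = inference(model, imgs, tokenizer)
    print(preds)

Since inference() itself normalizes RGBA/P images to RGB, applies
inference_transform, and decodes with beam search (num_beams=3,
do_sample=False), callers only need to supply PIL images.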