1) 加入了推理代码; 2) 整理了其他代码
This commit is contained in:
@@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||||
|
from PIL import Image
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from .model.TexTeller import TexTeller
|
||||||
|
from .utils.transforms import inference_transform
|
||||||
|
from ...globals import MAX_TOKEN_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def png2jpg(imgs: List["Image.Image"]) -> List["Image.Image"]:
    """Normalize a batch of PIL images to 'RGB' where needed.

    Palette ('P') and alpha ('RGBA') images are converted to 'RGB'; images in
    any other mode are returned unchanged. The input list is not mutated.

    BUG FIXED: the original comprehension used the mode test as a *filter*
    (`... for img in imgs if img.mode in ("RGBA", "P")`), so every image not
    in those two modes was silently dropped from the batch.
    """
    return [img.convert('RGB') if img.mode in ("RGBA", "P") else img
            for img in imgs]
|
||||||
|
|
||||||
|
|
||||||
|
def inference(model: "TexTeller", imgs: List["Image.Image"], tokenizer: "RobertaTokenizerFast") -> List[str]:
    """Run batched OCR inference, returning one decoded string per image.

    Args:
        model: a TexTeller (VisionEncoderDecoder-style) model exposing `.generate`.
        imgs: PIL images; 'RGBA'/'P' images are converted to 'RGB' first.
        tokenizer: tokenizer used to decode the generated token ids.

    Returns:
        A list of decoded strings, one per input image; `[]` for an empty batch.
    """
    if not imgs:
        # The original indexed imgs[0] unconditionally -> IndexError on [].
        return []
    # Convert per image: the original only probed imgs[0].mode, so a mixed
    # batch (e.g. first image RGB, later ones RGBA) was never converted.
    # Done inline so this fix does not depend on png2jpg being patched.
    imgs = [img.convert('RGB') if img.mode in ('RGBA', 'P') else img
            for img in imgs]

    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=3,        # beam search; deterministic because do_sample=False
        do_sample=False,
    )
    # Inference only: skip autograd bookkeeping (the original built an
    # unnecessary computation graph during generation).
    with torch.no_grad():
        pred = model.generate(pixel_values, generation_config=generate_config)
    return tokenizer.batch_decode(pred, skip_special_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # NOTE(review): the original called `inference()` with no arguments, which
    # always raises TypeError (model, imgs and tokenizer are required). Fail
    # fast with an actionable message until a real CLI is added here.
    raise SystemExit(
        'inference.py has no standalone CLI yet; '
        'call inference(model, imgs, tokenizer) from Python instead'
    )
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
|
from PIL import Image
|
||||||
|
|
||||||
from ....globals import (
|
from ....globals import (
|
||||||
VOCAB_SIZE,
|
VOCAB_SIZE,
|
||||||
OCR_IMG_SIZE,
|
OCR_IMG_SIZE,
|
||||||
OCR_IMG_CHANNELS
|
OCR_IMG_CHANNELS,
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
ViTConfig,
|
ViTConfig,
|
||||||
ViTModel,
|
ViTModel,
|
||||||
|
|
||||||
TrOCRConfig,
|
TrOCRConfig,
|
||||||
TrOCRForCausalLM,
|
TrOCRForCausalLM,
|
||||||
|
|
||||||
RobertaTokenizerFast,
|
RobertaTokenizerFast,
|
||||||
|
VisionEncoderDecoderModel,
|
||||||
VisionEncoderDecoderModel
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -38,9 +37,18 @@ class TexTeller(VisionEncoderDecoderModel):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Ad-hoc smoke test for the trained checkpoint.
    from ..inference import inference

    # NOTE(review): hard-coded developer paths — parameterize (argv/env)
    # before anyone else runs this.
    model = TexTeller.from_pretrained(
        '/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500'
    )
    tokenizer = TexTeller.get_tokenizer(
        '/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas'
    )

    # Replaces six copy-pasted Image.open(...) lines (img1 .. img6).
    imgs = [
        Image.open(f'/home/lhy/code/TeXify/src/models/ocr_model/model/{i}.png')
        for i in range(1, 7)
    ]

    res = inference(model, imgs, tokenizer)
    # Print the result instead of the `pause = 1` debugger-breakpoint leftover.
    print(res)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import datasets
|
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
|
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from PIL import ImageChops, Image
|
from PIL import ImageChops, Image
|
||||||
from typing import Any, Dict, List
|
from typing import List
|
||||||
|
|
||||||
from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
|
from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
|
||||||
|
|
||||||
@@ -11,12 +10,10 @@ from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN
|
|||||||
def trim_white_border(image: Image.Image):
|
def trim_white_border(image: Image.Image):
|
||||||
if image.mode == 'RGB':
|
if image.mode == 'RGB':
|
||||||
bg_color = (255, 255, 255)
|
bg_color = (255, 255, 255)
|
||||||
elif image.mode == 'RGBA':
|
|
||||||
bg_color = (255, 255, 255, 255)
|
|
||||||
elif image.mode == 'L':
|
elif image.mode == 'L':
|
||||||
bg_color = 255
|
bg_color = 255
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported image mode")
|
raise ValueError("Only support RGB or L mode")
|
||||||
# 创建一个与图片一样大小的白色背景
|
# 创建一个与图片一样大小的白色背景
|
||||||
bg = Image.new(image.mode, image.size, bg_color)
|
bg = Image.new(image.mode, image.size, bg_color)
|
||||||
# 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。
|
# 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。
|
||||||
@@ -25,8 +22,7 @@ def trim_white_border(image: Image.Image):
|
|||||||
diff = ImageChops.add(diff, diff, 2.0, -100)
|
diff = ImageChops.add(diff, diff, 2.0, -100)
|
||||||
# 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。
|
# 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。
|
||||||
bbox = diff.getbbox()
|
bbox = diff.getbbox()
|
||||||
if bbox:
|
return image.crop(bbox) if bbox else image
|
||||||
return image.crop(bbox)
|
|
||||||
|
|
||||||
|
|
||||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||||
|
|||||||
Reference in New Issue
Block a user