完成了web,ray server,重构了代码
This commit is contained in:
@@ -4,12 +4,14 @@ from typing import List
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
|
||||
# 输出的np.ndarray的格式为:[H, W, C](通道在第三维),通道的排列顺序为RGB
|
||||
processed_images = []
|
||||
|
||||
for path in image_paths:
|
||||
# 读取图片
|
||||
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if image is None:
|
||||
print(f"Image at {path} could not be read.")
|
||||
continue
|
||||
@@ -32,6 +34,6 @@ def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||
# 如果是 BGR (3通道), 转换为 RGB
|
||||
elif channels == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
processed_images.append(Image.fromarray(image))
|
||||
processed_images.append(image)
|
||||
|
||||
return processed_images
|
||||
39
src/models/ocr_model/utils/inference.py
Normal file
39
src/models/ocr_model/utils/inference.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||
from typing import List
|
||||
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from models.ocr_model.utils.transforms import inference_transform
|
||||
from models.ocr_model.utils.helpers import convert2rgb
|
||||
from models.globals import MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    """Run batched OCR inference and decode the predictions to strings.

    Args:
        model: TexTeller image-to-sequence model used for generation.
        tokenizer: Tokenizer whose pad/eos/bos ids configure generation and
            which decodes the predicted token ids.
        imgs_path: Paths of the input images; they are loaded and converted
            to RGB by ``convert2rgb`` before the model transform is applied.
        use_cuda: When True, move both the model and the input batch to CUDA.
        num_beams: Beam-search width; 1 (default) means greedy decoding.

    Returns:
        One decoded string per successfully loaded image (special tokens
        stripped). Returns ``[]`` when no image could be loaded.
    """
    model.eval()

    imgs = convert2rgb(imgs_path)
    # convert2rgb skips unreadable files, so the list can come back empty;
    # torch.stack([]) would raise RuntimeError, so bail out early instead.
    if not imgs:
        return []

    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if use_cuda:
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,  # deterministic decoding (greedy / beam search)
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res
|
||||
@@ -7,7 +7,7 @@ from torchvision.transforms import v2
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
|
||||
from ....globals import (
|
||||
from ...globals import (
|
||||
OCR_IMG_CHANNELS,
|
||||
OCR_IMG_SIZE,
|
||||
OCR_FIX_SIZE,
|
||||
|
||||
Reference in New Issue
Block a user