From 3527a4af4785523433c2f06b61e17a1b7ae9daca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Tue, 27 Feb 2024 07:13:36 +0000 Subject: [PATCH] updated API usage (supports remote calls) --- src/client_demo.py | 7 +++++-- src/models/ocr_model/utils/inference.py | 10 +++++++--- src/server.py | 15 ++++++++++----- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/client_demo.py b/src/client_demo.py index 5b68e2d..8ea236c 100644 --- a/src/client_demo.py +++ b/src/client_demo.py @@ -3,9 +3,12 @@ import requests url = "http://127.0.0.1:8000/predict" img_path = "/your/image/path/" +with open(img_path, 'rb') as img: + files = {'img': img} + response = requests.post(url, files=files) -data = {"img_path": img_path} +# data = {"img_path": img_path} -response = requests.post(url, json=data) +# response = requests.post(url, json=data) print(response.text) diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py index eaa4b52..5dcaf9f 100644 --- a/src/models/ocr_model/utils/inference.py +++ b/src/models/ocr_model/utils/inference.py @@ -1,7 +1,8 @@ import torch +import numpy as np from transformers import RobertaTokenizerFast, GenerationConfig -from typing import List +from typing import List, Union from models.ocr_model.model.TexTeller import TexTeller from models.ocr_model.utils.transforms import inference_transform @@ -12,12 +13,15 @@ from models.globals import MAX_TOKEN_SIZE def inference( model: TexTeller, tokenizer: RobertaTokenizerFast, - imgs_path: List[str], + imgs_path: Union[List[str], List[np.ndarray]], use_cuda: bool, num_beams: int = 1, ) -> List[str]: model.eval() - imgs = convert2rgb(imgs_path) + if isinstance(imgs_path[0], str): + imgs = convert2rgb(imgs_path) + else: # already numpy array(rgb format) + imgs = imgs_path imgs = inference_transform(imgs) pixel_values = torch.stack(imgs) diff --git a/src/server.py b/src/server.py index e838f27..7124134 
100644 --- a/src/server.py +++ b/src/server.py @@ -1,5 +1,7 @@ import argparse import time +import numpy as np +import cv2 from starlette.requests import Request from ray import serve @@ -51,8 +53,8 @@ class TexTellerServer: self.model = self.model.to('cuda') if use_cuda else self.model - def predict(self, image_path: str) -> str: - return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0] + def predict(self, image_nparray) -> str: + return inference(self.model, self.tokenizer, [image_nparray], self.use_cuda, self.num_beam)[0] @serve.deployment() @@ -61,9 +63,13 @@ class Ingress: self.texteller_server = texteller_server async def __call__(self, request: Request) -> str: - msg = await request.json() - img_path: str = msg['img_path'] - pred = await self.texteller_server.predict.remote(img_path) + form = await request.form() + img_rb = await form['img'].read() + + img_nparray = np.frombuffer(img_rb, np.uint8) + img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR) + img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB) + pred = await self.texteller_server.predict.remote(img_nparray) return pred