updated API usage (supports remote calls)
@@ -3,9 +3,12 @@ import requests
 url = "http://127.0.0.1:8000/predict"
 
 img_path = "/your/image/path/"
+with open(img_path, 'rb') as img:
+    files = {'img': img}
+    response = requests.post(url, files=files)
 
-data = {"img_path": img_path}
+# data = {"img_path": img_path}
 
-response = requests.post(url, json=data)
+# response = requests.post(url, json=data)
 
 print(response.text)
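Taken together, the new client-side flow is a short standalone script: the image is now uploaded as a multipart file instead of passing its path as JSON (which only worked when client and server shared a filesystem, hence "supports remote calls"). A minimal sketch of the updated usage, keeping the placeholder path from the diff:

```python
# Updated client usage: upload the image bytes as a multipart file.
import requests

url = "http://127.0.0.1:8000/predict"
img_path = "/your/image/path/"  # placeholder, as in the diff

with open(img_path, 'rb') as img:
    files = {'img': img}
    response = requests.post(url, files=files)

print(response.text)
```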
@@ -1,7 +1,8 @@
 import torch
+import numpy as np
 from transformers import RobertaTokenizerFast, GenerationConfig
-from typing import List
+from typing import List, Union
 
 from models.ocr_model.model.TexTeller import TexTeller
 from models.ocr_model.utils.transforms import inference_transform
@@ -12,12 +13,15 @@ from models.globals import MAX_TOKEN_SIZE
 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
-    imgs_path: List[str],
+    imgs_path: Union[List[str], List[np.ndarray]],
     use_cuda: bool,
     num_beams: int = 1,
 ) -> List[str]:
     model.eval()
-    imgs = convert2rgb(imgs_path)
+    if isinstance(imgs_path[0], str):
+        imgs = convert2rgb(imgs_path)
+    else:  # already a numpy array (RGB format)
+        imgs = imgs_path
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
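With the widened `Union[List[str], List[np.ndarray]]` signature, callers can pass either file paths (decoded via `convert2rgb`, as before) or images already decoded to RGB arrays. A minimal usage sketch; `TexTeller.from_pretrained()` and `TexTeller.get_tokenizer()` are assumed loader helpers, and the paths are placeholders:

```python
import cv2

model = TexTeller.from_pretrained()    # assumed default-checkpoint loader
tokenizer = TexTeller.get_tokenizer()  # assumed default tokenizer loader

# Old-style call: file paths, decoded internally by convert2rgb.
preds = inference(model, tokenizer, ['/path/to/formula.png'], use_cuda=False)

# New-style call: RGB ndarrays, e.g. decoded with OpenCV (BGR -> RGB).
img = cv2.cvtColor(cv2.imread('/path/to/formula.png'), cv2.COLOR_BGR2RGB)
preds = inference(model, tokenizer, [img], use_cuda=False)
```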
@@ -1,5 +1,7 @@
 import argparse
 import time
+import numpy as np
+import cv2
 
 from starlette.requests import Request
 from ray import serve
@@ -51,8 +53,8 @@ class TexTellerServer:
 
         self.model = self.model.to('cuda') if use_cuda else self.model
 
-    def predict(self, image_path: str) -> str:
-        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]
+    def predict(self, image_nparray) -> str:
+        return inference(self.model, self.tokenizer, [image_nparray], self.use_cuda, self.num_beam)[0]
 
 
 @serve.deployment()
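`predict` now receives the decoded ndarray through a Ray Serve handle call (`.remote(...)`) rather than a path. For context, a sketch of how the two deployments might be composed under Ray Serve 2.x's `bind`/`run` API; the constructor argument and route prefix are assumptions, not taken from this commit:

```python
from ray import serve

# Hypothetical wiring; TexTellerServer's real constructor signature may differ.
server = TexTellerServer.bind(use_cuda=False)  # assumed argument
ingress = Ingress.bind(server)
serve.run(ingress, route_prefix="/predict")    # matches the client URL above
```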
@@ -61,9 +63,12 @@ class Ingress:
         self.texteller_server = texteller_server
 
     async def __call__(self, request: Request) -> str:
-        msg = await request.json()
-        img_path: str = msg['img_path']
-        pred = await self.texteller_server.predict.remote(img_path)
+        form = await request.form()
+        img_rb = await form['img'].read()
+
+        img_nparray = np.frombuffer(img_rb, np.uint8)
+        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
+        pred = await self.texteller_server.predict.remote(img_nparray)
         return pred
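One caveat in this hunk as extracted: `np.frombuffer` returns a flat 1-D byte buffer, while `cv2.cvtColor` expects a decoded H x W x 3 image, so a decode step appears to be missing between the two lines. A sketch of the full upload-to-RGB path; the `cv2.imdecode` call is our assumption, not shown in the commit:

```python
import numpy as np
import cv2

def bytes_to_rgb(img_rb: bytes) -> np.ndarray:
    buf = np.frombuffer(img_rb, np.uint8)        # 1-D byte buffer from the upload
    bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)    # assumed decode step (H x W x 3, BGR)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR; the model wants RGB
```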