updated API usage (supports remote calls)

This commit is contained in:
三洋三洋
2024-02-27 07:13:36 +00:00
parent b4537944d0
commit 3527a4af47
3 changed files with 22 additions and 10 deletions

View File

@@ -3,9 +3,12 @@ import requests
 url = "http://127.0.0.1:8000/predict"
 img_path = "/your/image/path/"
+with open(img_path, 'rb') as img:
+    files = {'img': img}
+    response = requests.post(url, files=files)
-data = {"img_path": img_path}
-response = requests.post(url, json=data)
+# data = {"img_path": img_path}
+# response = requests.post(url, json=data)
 print(response.text)

View File

@@ -1,7 +1,8 @@
 import torch
+import numpy as np
 from transformers import RobertaTokenizerFast, GenerationConfig
-from typing import List
+from typing import List, Union
 from models.ocr_model.model.TexTeller import TexTeller
 from models.ocr_model.utils.transforms import inference_transform
@@ -12,12 +13,15 @@ from models.globals import MAX_TOKEN_SIZE
 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
-    imgs_path: List[str],
+    imgs_path: Union[List[str], List[np.ndarray]],
     use_cuda: bool,
     num_beams: int = 1,
 ) -> List[str]:
     model.eval()
-    imgs = convert2rgb(imgs_path)
+    if isinstance(imgs_path[0], str):
+        imgs = convert2rgb(imgs_path)
+    else:  # already numpy array (rgb format)
+        imgs = imgs_path
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)

View File

@@ -1,5 +1,7 @@
 import argparse
 import time
+import numpy as np
+import cv2
 from starlette.requests import Request
 from ray import serve
@@ -51,8 +53,8 @@ class TexTellerServer:
         self.model = self.model.to('cuda') if use_cuda else self.model
-    def predict(self, image_path: str) -> str:
-        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]
+    def predict(self, image_nparray) -> str:
+        return inference(self.model, self.tokenizer, [image_nparray], self.use_cuda, self.num_beam)[0]
 @serve.deployment()
@@ -61,9 +63,12 @@ class Ingress:
         self.texteller_server = texteller_server
     async def __call__(self, request: Request) -> str:
-        msg = await request.json()
-        img_path: str = msg['img_path']
-        pred = await self.texteller_server.predict.remote(img_path)
+        form = await request.form()
+        img_rb = await form['img'].read()
+        img_nparray = np.frombuffer(img_rb, np.uint8)
+        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
+        pred = await self.texteller_server.predict.remote(img_nparray)
         return pred