Support onnx runtime

三洋三洋
2024-06-22 21:51:51 +08:00
parent 8da3fd7418
commit 9638c0030d
5 changed files with 65 additions and 26 deletions

View File

@@ -6,7 +6,7 @@ det_server_url = "http://127.0.0.1:8000/fdet"
img_path = "/your/image/path/" img_path = "/your/image/path/"
with open(img_path, 'rb') as img: with open(img_path, 'rb') as img:
files = {'img': img} files = {'img': img}
response = requests.post(det_server_url, files=files) response = requests.post(rec_server_url, files=files)
# response = requests.post(rec_server_url, files=files) # response = requests.post(det_server_url, files=files)
print(response.text) print(response.text)
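
After this change the example posts to the recognition endpoint by default, with the detection call left as the commented alternative. A hedged client sketch that exercises both services (rec_server_url is an assumption; only det_server_url appears in this file):

import requests

det_server_url = "http://127.0.0.1:8000/fdet"     # from the hunk context above
rec_server_url = "http://127.0.0.1:8000/predict"  # assumed route, not shown in this file

img_path = "/your/image/path/"
with open(img_path, 'rb') as img:
    rec_response = requests.post(rec_server_url, files={'img': img})
with open(img_path, 'rb') as img:  # reopen: the first post consumed the file handle
    det_response = requests.post(det_server_url, files={'img': img})
print(rec_response.text)
print(det_response.text)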

View File

@@ -1,4 +1,5 @@
 from pathlib import Path
+from optimum.onnxruntime import ORTModelForVision2Seq

 from ...globals import (
     VOCAB_SIZE,
@@ -10,25 +11,29 @@ from ...globals import (
 from transformers import (
     RobertaTokenizerFast,
     VisionEncoderDecoderModel,
-    VisionEncoderDecoderConfig,
+    VisionEncoderDecoderConfig
 )


 class TexTeller(VisionEncoderDecoderModel):
     REPO_NAME = 'OleehyO/TexTeller'

     def __init__(self):
-        config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json")
+        config = VisionEncoderDecoderConfig.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/trocr-small')
         config.encoder.image_size = FIXED_IMG_SIZE
         config.encoder.num_channels = IMG_CHANNELS
-        config.decoder.vocab_size = VOCAB_SIZE
-        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
+        config.decoder.vocab_size=VOCAB_SIZE
+        config.decoder.max_position_embeddings=MAX_TOKEN_SIZE
         super().__init__(config=config)

     @classmethod
-    def from_pretrained(cls, model_path: str = None):
+    def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
         if model_path is None or model_path == 'default':
-            return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
+            if not use_onnx:
+                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
+            else:
+                use_gpu = True if onnx_provider == 'cuda' else False
+                return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider")
         model_path = Path(model_path).resolve()
         return VisionEncoderDecoderModel.from_pretrained(str(model_path))
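
A hedged usage sketch of the extended from_pretrained (the import path is inferred from the repository layout and may differ; keyword names come from the diff). Note that, per the code above, use_onnx only takes effect for the default checkpoint; an explicit model_path still loads the plain transformers backend:

# Usage sketch; the import path is an assumption.
from models.ocr_model.model.TexTeller import TexTeller

torch_model = TexTeller.from_pretrained('default')  # VisionEncoderDecoderModel
onnx_model = TexTeller.from_pretrained(
    'default',
    use_onnx=True,          # switch to ORTModelForVision2Seq
    onnx_provider='cuda'    # any other value falls back to CPUExecutionProvider
)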

View File

@@ -20,6 +20,8 @@ def inference(
 ) -> List[str]:
     if imgs == []:
         return []
-    model.eval()
+    if hasattr(model, 'eval'):
+        # not onnx session, turn model.eval()
+        model.eval()
     if isinstance(imgs[0], str):
         imgs = convert2rgb(imgs)
@@ -29,6 +31,8 @@ def inference(
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
-    model = model.to(accelerator)
+    if hasattr(model, 'eval'):
+        # not onnx session, move weights to device
+        model = model.to(accelerator)
     pixel_values = pixel_values.to(accelerator)
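
Both guards rely on duck typing: a torch model exposes eval() and to(), while the ONNX-backed object is expected not to, so one inference path serves both backends. A minimal sketch of the pattern, under that assumption:

import torch

def prepare_for_inference(model, pixel_values: torch.Tensor, accelerator: str):
    # Only a plain torch model needs eval() and an explicit device move;
    # an ONNX session manages its own execution provider.
    if hasattr(model, 'eval'):
        model.eval()
        model = model.to(accelerator)
    pixel_values = pixel_values.to(accelerator)  # inputs move in both cases
    return model, pixel_values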

View File

@@ -1,3 +1,4 @@
+import sys
 import argparse
 import tempfile
 import time
@@ -17,6 +18,10 @@ from models.det_model.inference import PredictConfig
 from models.ocr_model.utils.to_katex import to_katex

+PYTHON_VERSION = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
+LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages'
+CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib'
+
 parser = argparse.ArgumentParser()
 parser.add_argument(
     '-ckpt', '--checkpoint_dir', type=str
@@ -31,6 +36,7 @@ parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
 parser.add_argument('--inference-mode', type=str, default='cpu')
 parser.add_argument('--num_beams', type=int, default=1)
+parser.add_argument('-onnx', action='store_true', help='using onnx runtime')

 args = parser.parse_args()
 if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
@@ -41,7 +47,7 @@ if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
     num_replicas=args.num_replicas,
     ray_actor_options={
         "num_cpus": args.ncpu_per_replica,
-        "num_gpus": args.ngpu_per_replica
+        "num_gpus": args.ngpu_per_replica * 1.0 / 2
     }
 )
 class TexTellerRecServer:
@@ -50,13 +56,15 @@ class TexTellerRecServer:
         checkpoint_path: str,
         tokenizer_path: str,
         inf_mode: str = 'cpu',
+        use_onnx: bool = False,
         num_beams: int = 1
     ) -> None:
-        self.model = TexTeller.from_pretrained(checkpoint_path)
+        self.model = TexTeller.from_pretrained(checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode)
         self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
         self.inf_mode = inf_mode
         self.num_beams = num_beams
-        self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
+        if not use_onnx:
+            self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model

     def predict(self, image_nparray) -> str:
@@ -65,14 +73,28 @@ class TexTellerRecServer:
             accelerator=self.inf_mode, num_beams=self.num_beams
         )[0])


-@serve.deployment(num_replicas=args.num_replicas)
+@serve.deployment(
+    num_replicas=args.num_replicas,
+    ray_actor_options={
+        "num_cpus": args.ncpu_per_replica,
+        "num_gpus": args.ngpu_per_replica * 1.0 / 2,
+        "runtime_env": {
+            "env_vars": {
+                "LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"
+            }
+        }
+    },
+)
 class TexTellerDetServer:
     def __init__(
-        self
+        self,
+        inf_mode='cpu'
     ):
         self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
-        self.latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+        self.latex_det_model = InferenceSession(
+            "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
+            providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider']
+        )

     async def predict(self, image_nparray) -> str:
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -120,11 +142,12 @@ if __name__ == '__main__':
     rec_server = TexTellerRecServer.bind(
         ckpt_dir, tknz_dir,
         inf_mode=args.inference_mode,
+        use_onnx=args.onnx,
         num_beams=args.num_beams
     )

     det_server = None
     if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
-        det_server = TexTellerDetServer.bind()
+        det_server = TexTellerDetServer.bind(args.inference_mode)
     ingress = Ingress.bind(det_server, rec_server)
     # ingress_handle = serve.run(ingress, route_prefix="/predict")
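
The halved num_gpus lets a recognition replica and a detection replica share one physical GPU, and the LD_LIBRARY_PATH entry presumably points ONNX Runtime at the pip-installed cuDNN libraries inside each Ray worker. The provider choice itself is the standard onnxruntime mechanism; a minimal sketch of the selection both servers use:

from onnxruntime import InferenceSession

def make_session(onnx_path: str, inf_mode: str = 'cpu') -> InferenceSession:
    # onnxruntime tries providers in list order, so a single-element list
    # pins the backend; 'cuda' here mirrors the server's inference mode flag.
    providers = ['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider']
    return InferenceSession(onnx_path, providers=providers)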

View File

@@ -50,17 +50,20 @@ fail_gif_html = '''
 '''

 @st.cache_resource
-def get_texteller():
-    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
+def get_texteller(use_onnx, accelerator):
+    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'], use_onnx=use_onnx, onnx_provider=accelerator)

 @st.cache_resource
 def get_tokenizer():
     return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])

 @st.cache_resource
-def get_det_models():
+def get_det_models(accelerator):
     infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
-    latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+    latex_det_model = InferenceSession(
+        "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
+        providers=['CUDAExecutionProvider'] if accelerator == 'cuda' else ['CPUExecutionProvider']
+    )
     return infer_config, latex_det_model

 @st.cache_resource()
@st.cache_resource() @st.cache_resource()
@@ -141,18 +144,22 @@ with st.sidebar:
         on_change=change_side_bar
     )

+    st.markdown("## Speedup Setting")
+    use_onnx = st.toggle("ONNX Runtime")
+
 ############################## </sidebar> ##############################

 ################################ <page> ################################

-texteller = get_texteller()
+texteller = get_texteller(use_onnx, accelerator)
 tokenizer = get_tokenizer()
 latex_rec_models = [texteller, tokenizer]

 if inf_mode == "Paragraph recognition":
-    infer_config, latex_det_model = get_det_models()
+    infer_config, latex_det_model = get_det_models(accelerator)
     lang_ocr_models = get_ocr_models(accelerator)

 st.markdown(html_string, unsafe_allow_html=True)
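
Threading use_onnx and accelerator through the loaders matters because st.cache_resource keys its cache on argument values: each backend/device combination is constructed once and reused across reruns. A small self-contained sketch of that behavior (load_backend is hypothetical):

import streamlit as st

@st.cache_resource
def load_backend(use_onnx: bool, accelerator: str):
    # Stand-in for heavy model construction; cached per (use_onnx, accelerator).
    return object()

a = load_backend(False, 'cpu')   # first call builds and caches
b = load_backend(False, 'cpu')   # cache hit: a is b
c = load_backend(True, 'cuda')   # new key: the ONNX variant is built separately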