Support onnx runtime

This commit is contained in:
三洋三洋
2024-06-22 21:51:51 +08:00
parent 8da3fd7418
commit 9638c0030d
5 changed files with 65 additions and 26 deletions

View File

@@ -20,7 +20,9 @@ def inference(
) -> List[str]:
if imgs == []:
return []
model.eval()
if hasattr(model, 'eval'):
# not onnx session, turn model.eval()
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already numpy array(rgb format)
@@ -29,7 +31,9 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
model = model.to(accelerator)
if hasattr(model, 'eval'):
# not onnx session, move weights to device
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(