Support onnx runtime
@@ -20,7 +20,9 @@ def inference(
 ) -> List[str]:
     if imgs == []:
         return []
-    model.eval()
+    if hasattr(model, 'eval'):
+        # not onnx session, turn model.eval()
+        model.eval()
     if isinstance(imgs[0], str):
         imgs = convert2rgb(imgs)
     else:  # already numpy array(rgb format)
@@ -29,7 +31,9 @@ def inference(
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
 
-    model = model.to(accelerator)
+    if hasattr(model, 'eval'):
+        # not onnx session, move weights to device
+        model = model.to(accelerator)
     pixel_values = pixel_values.to(accelerator)
 
     generate_config = GenerationConfig(
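Both hunks rely on the same duck-typing check: a PyTorch module exposes eval() and to(), while an onnxruntime.InferenceSession does not, so hasattr(model, 'eval') is enough to route one inference path to either backend. A minimal sketch of that idea, not the repository's actual code (run_backbone and the single-input assumption are illustrative):

import torch

def run_backbone(model, pixel_values: torch.Tensor, accelerator: str = "cpu"):
    """Run either a torch module or an onnxruntime.InferenceSession on one batch."""
    if hasattr(model, "eval"):
        # PyTorch path: switch to eval mode and move weights and input to the device.
        model.eval()
        model = model.to(accelerator)
        with torch.no_grad():
            return model(pixel_values.to(accelerator))
    # ONNX Runtime path: the session already owns its execution providers,
    # so only the input needs converting to a numpy feed.
    input_name = model.get_inputs()[0].name  # assumes a single-input graph
    outputs = model.run(None, {input_name: pixel_values.cpu().numpy()})
    return outputs[0]

This also explains why the second hunk guards model.to(accelerator): only the torch module has a to() method, while the ONNX session's device placement is fixed by the execution providers it was constructed with.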