97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
import os
|
|
import argparse
|
|
import glob
|
|
import subprocess
|
|
|
|
import onnxruntime
|
|
from pathlib import Path
|
|
|
|
from models.det_model.inference import PredictConfig, predict_image
|
|
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
parser.add_argument(
|
|
"--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml"
|
|
)
|
|
parser.add_argument(
|
|
'--onnx_file',
|
|
type=str,
|
|
help="onnx model file path",
|
|
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
|
|
)
|
|
parser.add_argument("--image_dir", type=str, default='./testImgs')
|
|
parser.add_argument("--image_file", type=str)
|
|
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
|
|
parser.add_argument(
|
|
'--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True
|
|
)
|
|
|
|
|
|
def get_test_images(infer_dir, infer_img):
|
|
"""
|
|
Get image path list in TEST mode
|
|
"""
|
|
assert (
|
|
infer_img is not None or infer_dir is not None
|
|
), "--image_file or --image_dir should be set"
|
|
assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img)
|
|
assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir)
|
|
|
|
# infer_img has a higher priority
|
|
if infer_img and os.path.isfile(infer_img):
|
|
return [infer_img]
|
|
|
|
images = set()
|
|
infer_dir = os.path.abspath(infer_dir)
|
|
assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir)
|
|
exts = ['jpg', 'jpeg', 'png', 'bmp']
|
|
exts += [ext.upper() for ext in exts]
|
|
for ext in exts:
|
|
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
|
|
images = list(images)
|
|
|
|
assert len(images) > 0, "no image found in {}".format(infer_dir)
|
|
print("Found {} inference images in total.".format(len(images)))
|
|
|
|
return images
|
|
|
|
|
|
def download_file(url, filename):
|
|
print(f"Downloading {filename}...")
|
|
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
|
|
print("Download complete.")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
cur_path = os.getcwd()
|
|
script_dirpath = Path(__file__).resolve().parent
|
|
os.chdir(script_dirpath)
|
|
|
|
FLAGS = parser.parse_args()
|
|
|
|
if not os.path.exists(FLAGS.infer_cfg):
|
|
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
|
|
download_file(infer_cfg_url, FLAGS.infer_cfg)
|
|
|
|
if not os.path.exists(FLAGS.onnx_file):
|
|
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
|
|
download_file(onnx_file_url, FLAGS.onnx_file)
|
|
|
|
# load image list
|
|
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
|
|
|
if FLAGS.use_gpu:
|
|
predictor = onnxruntime.InferenceSession(
|
|
FLAGS.onnx_file, providers=['CUDAExecutionProvider']
|
|
)
|
|
else:
|
|
predictor = onnxruntime.InferenceSession(
|
|
FLAGS.onnx_file, providers=['CPUExecutionProvider']
|
|
)
|
|
# load infer config
|
|
infer_config = PredictConfig(FLAGS.infer_cfg)
|
|
|
|
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
|
|
|
|
os.chdir(cur_path)
|