fix a typo

This commit is contained in:
Isotr0py
2023-10-08 20:49:03 +08:00
parent 70fe7e18be
commit b8b84021e5

View File

@@ -117,7 +117,7 @@ def main(args):
)
args.batch_size = batch_size
ort_sess = ort.InferenceSession(
model.SerializeToString(),
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
@@ -154,7 +154,7 @@ def main(args):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
else:
probs = model(imgs, training=False)
probs = probs.numpy()