fix a typo

This commit is contained in:
Isotr0py
2023-10-08 20:49:03 +08:00
parent 70fe7e18be
commit b8b84021e5

View File

@@ -117,7 +117,7 @@ def main(args):
) )
args.batch_size = batch_size args.batch_size = batch_size
ort_sess = ort.InferenceSession( ort_sess = ort.InferenceSession(
model.SerializeToString(), onnx_path,
providers=["CUDAExecutionProvider"] providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers() if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"], else ["CPUExecutionProvider"],
@@ -154,7 +154,7 @@ def main(args):
imgs = np.array([im for _, im in path_imgs]) imgs = np.array([im for _, im in path_imgs])
if args.onnx: if args.onnx:
probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
else: else:
probs = model(imgs, training=False) probs = model(imgs, training=False)
probs = probs.numpy() probs = probs.numpy()