mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix a typo
This commit is contained in:
@@ -117,7 +117,7 @@ def main(args):
|
||||
)
|
||||
args.batch_size = batch_size
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
onnx_path,
|
||||
providers=["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"],
|
||||
@@ -154,7 +154,7 @@ def main(args):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
if args.onnx:
|
||||
probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
Reference in New Issue
Block a user