mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix a typo
This commit is contained in:
@@ -117,7 +117,7 @@ def main(args):
|
|||||||
)
|
)
|
||||||
args.batch_size = batch_size
|
args.batch_size = batch_size
|
||||||
ort_sess = ort.InferenceSession(
|
ort_sess = ort.InferenceSession(
|
||||||
model.SerializeToString(),
|
onnx_path,
|
||||||
providers=["CUDAExecutionProvider"]
|
providers=["CUDAExecutionProvider"]
|
||||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||||
else ["CPUExecutionProvider"],
|
else ["CPUExecutionProvider"],
|
||||||
@@ -154,7 +154,7 @@ def main(args):
|
|||||||
imgs = np.array([im for _, im in path_imgs])
|
imgs = np.array([im for _, im in path_imgs])
|
||||||
|
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy
|
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||||
else:
|
else:
|
||||||
probs = model(imgs, training=False)
|
probs = model(imgs, training=False)
|
||||||
probs = probs.numpy()
|
probs = probs.numpy()
|
||||||
|
|||||||
Reference in New Issue
Block a user