Merge pull request #1525 from Akegarasu/sd3

make guidance_scale keep float in args
This commit is contained in:
Kohya S.
2024-08-29 22:08:57 +09:00
committed by GitHub

View File

@@ -324,7 +324,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# ensure the hidden state will require grad
if args.gradient_checkpointing: