make guidance_scale keep float in args

This commit is contained in:
Akegarasu
2024-08-29 14:50:29 +08:00
parent a61cf73a5c
commit 6c0e8a5a17

View File

@@ -324,7 +324,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# ensure the hidden state will require grad
if args.gradient_checkpointing: