fix device

This commit is contained in:
minux302
2024-11-29 14:40:38 +00:00
4 changed files with 7 additions and 6 deletions

View File

@@ -266,7 +266,7 @@ def train(args):
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors)
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet.train()
if args.gradient_checkpointing: