diff --git a/train_network.py b/train_network.py index b482c80a..6f41d199 100644 --- a/train_network.py +++ b/train_network.py @@ -401,6 +401,8 @@ class NetworkTrainer: text_encoder, network, optimizer, train_dataloader, lr_scheduler ) text_encoders = [text_encoder] + + unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator else: network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( network, optimizer, train_dataloader, lr_scheduler