diff --git a/train_network.py b/train_network.py index f247c74e..928ad321 100644 --- a/train_network.py +++ b/train_network.py @@ -257,7 +257,7 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train()