fix get_hidden_states expected scalar Error again

This commit is contained in:
Isotr0py
2023-02-08 23:09:59 +08:00
committed by GitHub
parent 9a9ac79edf
commit b8ad17902f

View File

@@ -257,7 +257,7 @@ def train(args):
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
text_encoder.train()