From b8ad17902f91d0293daf6a685224d2bb59f9b301 Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Wed, 8 Feb 2023 23:09:59 +0800 Subject: [PATCH] fix get_hidden_states expected scalar Error again --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f247c74e..928ad321 100644 --- a/train_network.py +++ b/train_network.py @@ -257,7 +257,7 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() text_encoder.train()