Fix: full_fp16 combined with clip_skip==2 was not working

This commit is contained in:
Kohya S
2023-01-08 18:49:34 +09:00
parent 80af4c0c42
commit 82e585cf01
2 changed files with 5 additions and 3 deletions

View File

@@ -230,7 +230,8 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)