diff --git a/train_network.py b/train_network.py index a4125e9f..edd3ff94 100644 --- a/train_network.py +++ b/train_network.py @@ -145,7 +145,7 @@ class NetworkTrainer: latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings(