From 569ca72fc4cda2f4ce30e43b1c62989e79e3c3b3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Nov 2023 11:59:30 -0500 Subject: [PATCH] Set grad enabled if is_train and train_text_encoder We only want to be enabling grad if we are training. --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index a4125e9f..edd3ff94 100644 --- a/train_network.py +++ b/train_network.py @@ -145,7 +145,7 @@ class NetworkTrainer: latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings(