diff --git a/train_network.py b/train_network.py index 6d98037c..1a171325 100644 --- a/train_network.py +++ b/train_network.py @@ -12,10 +12,13 @@ import toml from tqdm import tqdm import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -428,8 +431,10 @@ class NetworkTrainer: # set top parameter requires_grad = True for gradient checkpointing works if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - else: - unet.parameters().__next__().requires_grad_(True) + + # set top parameter requires_grad = True for gradient checkpointing works + if not train_text_encoder: # train U-Net only + unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: