From 83c7e03d050fc25f47a591c4ddfe28abdabc7ae7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 22:45:28 +0900 Subject: [PATCH] Fix network_weights not working in train_network --- gen_img_diffusers.py | 2 +- networks/lora.py | 11 +++++++++++ train_network.py | 8 ++++---- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a0469766..af83ce47 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2285,7 +2285,7 @@ def main(args): if not args.network_merge: network.apply_to(text_encoder, unet) - info = network.load_state_dict(weights_sd, False) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") if args.opt_channels_last: diff --git a/networks/lora.py b/networks/lora.py index c5372688..4e0573d0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -593,6 +593,17 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: print("enable LoRA for text encoder") diff --git a/train_network.py b/train_network.py index a7b167bf..c79b0922 100644 --- a/train_network.py +++ b/train_network.py @@ -194,14 +194,14 @@ def train(args): if network is None: return - if args.network_weights is not None: - print("load network weights from:", args.network_weights) - network.load_weights(args.network_weights) - train_unet = not args.network_train_text_encoder_only train_text_encoder = not args.network_train_unet_only network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + print(f"load network weights from {args.network_weights}: {info}") + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable()