From 683680e5c831392eb16850cc1077f56c86169a79 Mon Sep 17 00:00:00 2001 From: A2va <49582555+A2va@users.noreply.github.com> Date: Sun, 9 Apr 2023 21:52:02 +0200 Subject: [PATCH 1/2] Fixes --- networks/lora_interrogator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 2891798b..d53f1028 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -2,6 +2,7 @@ from tqdm import tqdm from library import model_util +import library.train_util as train_util import argparse from transformers import CLIPTokenizer import torch @@ -18,14 +19,14 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def interrogate(args): # いろいろ準備する print(f"loading SD model: {args.sd_model}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + text_encoder, vae, unet, _ = train_util.load_target_model(args, args.sd_model, DEVICE) print(f"loading LoRA: {args.model}") - network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) + network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい has_te_weight = False - for key in network.weights_sd.keys(): + for key in weights_sd.keys(): if 'lora_te' in key: has_te_weight = True break From 87163cff8b220f65a480dfdd6096c1b21e8881a6 Mon Sep 17 00:00:00 2001 From: A2va <49582555+A2va@users.noreply.github.com> Date: Mon, 17 Apr 2023 09:16:07 +0200 Subject: [PATCH 2/2] Fix missing pretrained_model_name_or_path --- networks/lora_interrogator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index d53f1028..093b0e81 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -108,10 +108,8 @@ def interrogate(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') - parser.add_argument("--sd_model", type=str, default=None, - help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") + + train_util.add_sd_models_arguments(parser) parser.add_argument("--model", type=str, default=None, help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") parser.add_argument("--batch_size", type=int, default=16,