This commit is contained in:
A2va
2023-04-09 21:52:02 +02:00
parent 5050971ac6
commit 683680e5c8

View File

@@ -2,6 +2,7 @@
from tqdm import tqdm from tqdm import tqdm
from library import model_util from library import model_util
import library.train_util as train_util
import argparse import argparse
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
import torch import torch
@@ -18,14 +19,14 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args): def interrogate(args):
# いろいろ準備する # いろいろ準備する
print(f"loading SD model: {args.sd_model}") print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) text_encoder, vae, unet, _ = train_util.load_target_model(args, args.sd_model, DEVICE)
print(f"loading LoRA: {args.model}") print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい # text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False has_te_weight = False
for key in network.weights_sd.keys(): for key in weights_sd.keys():
if 'lora_te' in key: if 'lora_te' in key:
has_te_weight = True has_te_weight = True
break break