mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fixes
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from library import model_util
|
from library import model_util
|
||||||
|
import library.train_util as train_util
|
||||||
import argparse
|
import argparse
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
import torch
|
import torch
|
||||||
@@ -18,14 +19,14 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|||||||
def interrogate(args):
|
def interrogate(args):
|
||||||
# いろいろ準備する
|
# いろいろ準備する
|
||||||
print(f"loading SD model: {args.sd_model}")
|
print(f"loading SD model: {args.sd_model}")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, args.sd_model, DEVICE)
|
||||||
|
|
||||||
print(f"loading LoRA: {args.model}")
|
print(f"loading LoRA: {args.model}")
|
||||||
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||||
|
|
||||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||||
has_te_weight = False
|
has_te_weight = False
|
||||||
for key in network.weights_sd.keys():
|
for key in weights_sd.keys():
|
||||||
if 'lora_te' in key:
|
if 'lora_te' in key:
|
||||||
has_te_weight = True
|
has_te_weight = True
|
||||||
break
|
break
|
||||||
|
|||||||
Reference in New Issue
Block a user