mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
restore sd_model arg for backward compat
This commit is contained in:
@@ -17,9 +17,13 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|||||||
|
|
||||||
|
|
||||||
def interrogate(args):
|
def interrogate(args):
|
||||||
|
weights_dtype = torch.float16
|
||||||
|
|
||||||
# いろいろ準備する
|
# いろいろ準備する
|
||||||
print(f"loading SD model: {args.sd_model}")
|
print(f"loading SD model: {args.sd_model}")
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, args.sd_model, DEVICE)
|
args.pretrained_model_name_or_path = args.sd_model
|
||||||
|
args.vae = None
|
||||||
|
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
|
||||||
|
|
||||||
print(f"loading LoRA: {args.model}")
|
print(f"loading LoRA: {args.model}")
|
||||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||||
@@ -41,9 +45,9 @@ def interrogate(args):
|
|||||||
else:
|
else:
|
||||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
||||||
|
|
||||||
text_encoder.to(DEVICE)
|
text_encoder.to(DEVICE, dtype=weights_dtype)
|
||||||
text_encoder.eval()
|
text_encoder.eval()
|
||||||
unet.to(DEVICE)
|
unet.to(DEVICE, dtype=weights_dtype)
|
||||||
unet.eval() # U-Netは呼び出さないので不要だけど
|
unet.eval() # U-Netは呼び出さないので不要だけど
|
||||||
|
|
||||||
# トークンをひとつひとつ当たっていく
|
# トークンをひとつひとつ当たっていく
|
||||||
@@ -79,9 +83,14 @@ def interrogate(args):
|
|||||||
orig_embs = get_all_embeddings(text_encoder)
|
orig_embs = get_all_embeddings(text_encoder)
|
||||||
|
|
||||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||||
network.to(DEVICE)
|
info = network.load_state_dict(weights_sd, strict=False)
|
||||||
|
print(f"Loading LoRA weights: {info}")
|
||||||
|
|
||||||
|
network.to(DEVICE, dtype=weights_dtype)
|
||||||
network.eval()
|
network.eval()
|
||||||
|
|
||||||
|
del unet
|
||||||
|
|
||||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||||
print("get text encoder embeddings with lora.")
|
print("get text encoder embeddings with lora.")
|
||||||
lora_embs = get_all_embeddings(text_encoder)
|
lora_embs = get_all_embeddings(text_encoder)
|
||||||
@@ -109,7 +118,10 @@ def interrogate(args):
|
|||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
train_util.add_sd_models_arguments(parser)
|
parser.add_argument("--v2", action='store_true',
|
||||||
|
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||||
|
parser.add_argument("--sd_model", type=str, default=None,
|
||||||
|
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
||||||
parser.add_argument("--model", type=str, default=None,
|
parser.add_argument("--model", type=str, default=None,
|
||||||
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
||||||
parser.add_argument("--batch_size", type=int, default=16,
|
parser.add_argument("--batch_size", type=int, default=16,
|
||||||
|
|||||||
Reference in New Issue
Block a user