diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 093b0e81..beb25181 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -17,9 +17,13 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def interrogate(args): + weights_dtype = torch.float16 + # いろいろ準備する print(f"loading SD model: {args.sd_model}") - text_encoder, vae, unet, _ = train_util.load_target_model(args, args.sd_model, DEVICE) + args.pretrained_model_name_or_path = args.sd_model + args.vae = None + text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE) print(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) @@ -41,9 +45,9 @@ def interrogate(args): else: tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) - text_encoder.to(DEVICE) + text_encoder.to(DEVICE, dtype=weights_dtype) text_encoder.eval() - unet.to(DEVICE) + unet.to(DEVICE, dtype=weights_dtype) unet.eval() # U-Netは呼び出さないので不要だけど # トークンをひとつひとつ当たっていく @@ -79,9 +83,14 @@ def interrogate(args): orig_embs = get_all_embeddings(text_encoder) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) - network.to(DEVICE) + info = network.load_state_dict(weights_sd, strict=False) + print(f"Loading LoRA weights: {info}") + + network.to(DEVICE, dtype=weights_dtype) network.eval() + del unet + print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") print("get text encoder embeddings with lora.") lora_embs = get_all_embeddings(text_encoder) @@ -109,7 +118,10 @@ def interrogate(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--sd_model", type=str, default=None, + help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") parser.add_argument("--model", type=str, default=None, help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") parser.add_argument("--batch_size", type=int, default=16,