diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index e025c74e..19c63acf 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1995,11 +1995,12 @@ def main(args): network_weight = args.network_weights[i] print("load network weights from:", network_weight) - from safetensors.torch import safe_open - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + if model_util.is_safetensors(network_weight): + from safetensors.torch import safe_open + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs) else: diff --git a/train_network.py b/train_network.py index 88014ddb..393d8f9d 100644 --- a/train_network.py +++ b/train_network.py @@ -245,7 +245,7 @@ def train(args): "ss_lr_scheduler": args.lr_scheduler, "ss_network_module": args.network_module, "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim - "ss_network_alpha": args.network_alpha, # some networks may not use this value + "ss_network_alpha": args.network_alpha, # some networks may not use this value "ss_mixed_precision": args.mixed_precision, "ss_full_fp16": bool(args.full_fp16), "ss_v2": bool(args.v2), @@ -264,6 +264,7 @@ def train(args): "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_training_comment": args.training_comment # will not be updated after training } # uncomment if another network is added @@ -448,12 +449,14 @@ if __name__ == '__main__': parser.add_argument("--network_dim", type=int, default=None, help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') parser.add_argument("--network_alpha", type=float, default=1, - help='alpha for LoRA weight scaling, 0 for no scaling (same as old version) / LoRaの重み調整のalpha値、0で調整なし(旧バージョンと同じ)') + help='alpha for LoRA weight scaling, default 1, 0 for no scaling (same as old version) / LoRaの重み調整のalpha値、デフォルト1、0で調整なし(旧バージョンと同じ)') parser.add_argument("--network_args", type=str, default=None, nargs='*', help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument("--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する") + parser.add_argument("--training_comment", type=str, default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() train(args)