support sai model spec

This commit is contained in:
Kohya S
2023-08-06 21:50:05 +09:00
parent cd54af019a
commit c142dadb46
15 changed files with 746 additions and 64 deletions

View File

@@ -5,9 +5,11 @@
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec
import library.model_util as model_util
import library.sdxl_model_util as sdxl_model_util
import lora
@@ -197,6 +199,13 @@ def svd(args):
"ss_network_args": json.dumps(net_kwargs),
}
if not args.no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, False, True, False, time.time(), title=title
)
metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
@@ -243,6 +252,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
return parser