mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support sai model spec
This commit is contained in:
@@ -5,9 +5,11 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec
|
||||
import library.model_util as model_util
|
||||
import library.sdxl_model_util as sdxl_model_util
|
||||
import lora
|
||||
@@ -197,6 +199,13 @@ def svd(args):
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
}
|
||||
|
||||
if not args.no_metadata:
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
None, args.v2, args.v_parameterization, False, True, False, time.time(), title=title
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
|
||||
@@ -243,6 +252,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user