mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Support the SAI (Stability AI) model spec metadata
This commit is contained in:
@@ -39,6 +39,7 @@ from library.custom_train_functions import (
|
||||
class NetworkTrainer:
|
||||
def __init__(self):
|
||||
self.vae_scale_factor = 0.18215
|
||||
self.is_sdxl = False
|
||||
|
||||
# TODO: share this with the other scripts
|
||||
def generate_step_logs(
|
||||
@@ -217,7 +218,7 @@ class NetworkTrainer:
|
||||
|
||||
# Incorporate xformers or memory-efficient attention into the model
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||
|
||||
# Load the model for differential additional training
|
||||
@@ -401,7 +402,7 @@ class NetworkTrainer:
|
||||
)
|
||||
text_encoders = [text_encoder]
|
||||
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
else:
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
@@ -660,16 +661,8 @@ class NetworkTrainer:
|
||||
metadata = {k: str(v) for k, v in metadata.items()}
|
||||
|
||||
# make minimum metadata for filtering
|
||||
minimum_keys = [
|
||||
"ss_v2",
|
||||
"ss_base_model_version",
|
||||
"ss_network_module",
|
||||
"ss_network_dim",
|
||||
"ss_network_alpha",
|
||||
"ss_network_args",
|
||||
]
|
||||
minimum_metadata = {}
|
||||
for key in minimum_keys:
|
||||
for key in train_util.SS_METADATA_MINIMUM_KEYS:
|
||||
if key in metadata:
|
||||
minimum_metadata[key] = metadata[key]
|
||||
|
||||
@@ -687,7 +680,9 @@ class NetworkTrainer:
|
||||
init_kwargs = {}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
@@ -709,7 +704,11 @@ class NetworkTrainer:
|
||||
metadata["ss_steps"] = str(steps)
|
||||
metadata["ss_epoch"] = str(epoch_no)
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
||||
metadata_to_save.update(sai_metadata)
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user