Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)

Commit: fix crashing when max_norm is disabled
@@ -25,16 +25,25 @@ from library.config_util import (
 )
 import library.huggingface_util as huggingface_util
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm
+from library.custom_train_functions import (
+    apply_snr_weight,
+    get_weighted_text_embeddings,
+    pyramid_noise_like,
+    apply_noise_offset,
+    max_norm,
+)
 
 
 # TODO 他のスクリプトと共通化する (TODO: share this with the other scripts)
-def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
+def generate_step_logs(
+    args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}
-    if args.scale_weight_norms:
-        logs["keys_scaled"] = keys_scaled
-        logs["average_key_norm"] = mean_norm
-        logs["max_key_norm"] = maximum_norm
+
+    if keys_scaled is not None:
+        logs["max_norm/keys_scaled"] = keys_scaled
+        logs["max_norm/average_key_norm"] = mean_norm
+        logs["max_norm/max_key_norm"] = maximum_norm
 
     lrs = lr_scheduler.get_last_lr()
 
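The functional change in this hunk is the guard in generate_step_logs: it now tests keys_scaled is not None instead of args.scale_weight_norms, so the norm statistics are logged only when the caller actually computed them, and the log keys are namespaced under max_norm/. A minimal standalone sketch of the resulting behavior (a simplified stand-in with hypothetical values; the lr_scheduler handling is omitted):

def step_logs(current_loss, avr_loss, keys_scaled=None, mean_norm=None, maximum_norm=None):
    # simplified stand-in for generate_step_logs above
    logs = {"loss/current": current_loss, "loss/average": avr_loss}
    if keys_scaled is not None:  # only when the caller computed norm stats
        logs["max_norm/keys_scaled"] = keys_scaled
        logs["max_norm/average_key_norm"] = mean_norm
        logs["max_norm/max_key_norm"] = maximum_norm
    return logs

print(step_logs(0.12, 0.11))                 # no max_norm/* entries
print(step_logs(0.12, 0.11, 8, 0.93, 1.07))  # adds the three max_norm/* entries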
@@ -151,7 +160,7 @@ def train(args):
 
     # モデルに xformers とか memory efficient attention を組み込む (hook xformers / memory efficient attention into the model)
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
 
     # 差分追加学習のためにモデルを読み込む (load a model for incremental difference training)
     import sys
 
@@ -200,14 +209,15 @@ def train(args):
     if args.dim_from_weights:
         network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
-        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
+        network = network_module.create_network(
+            1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs
+        )
     if network is None:
         return
 
     if hasattr(network, "prepare_network"):
         network.prepare_network(args)
 
-
     train_unet = not args.network_train_text_encoder_only
     train_text_encoder = not args.network_train_unet_only
     network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -587,7 +597,6 @@ def train(args):
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
-
             current_step.value = global_step
             with accelerator.accumulate(network):
                 on_step_start(text_encoder, unet)
@@ -659,10 +668,12 @@ def train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
             if args.scale_weight_norms:
                 keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
                 max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+            else:
+                keys_scaled, mean_norm, maximum_norm = None, None, None
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
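The added else branch is the fix named in the commit title: with --scale_weight_norms left at its default of None, the condition is falsy, the three names were never bound, and the generate_step_logs(args, ..., keys_scaled, mean_norm, maximum_norm) call below raised UnboundLocalError. Judging by the logging here, max_norm(...) scales oversized key norms down to the threshold and returns the number of keys scaled plus the mean and maximum norms. A minimal standalone reproduction of the bug pattern (illustrative only, not the project's code):

def train_step(scale_weight_norms=None):
    if scale_weight_norms:
        keys_scaled = 42  # bound only when the option is set
    return keys_scaled    # unbound when the branch above was skipped

train_step(1.0)  # fine
train_step()     # UnboundLocalError: local variable 'keys_scaled' referenced before assignment

Binding all three names to None in the else branch makes the later logging call unconditionally safe.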
@@ -698,9 +709,9 @@ def train(args):
                 avr_loss = loss_total / len(loss_list)
                 logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
-                if args.scale_weight_norms:
-                    progress_bar.set_postfix(**max_mean_logs)
-
+
+                if args.scale_weight_norms:
+                    progress_bar.set_postfix(**max_mean_logs)
 
                 if args.logging_dir is not None:
                     logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
@@ -806,7 +817,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--scale_weight_norms",
         type=float,
         default=None,
-        help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point)",
+        help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
     )
     parser.add_argument(
         "--dropout",
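The changed help line only appends a Japanese rendering of the English text (roughly: "scale the weight values to prevent exploding gradients; 1 is a suitable initial value"). Because default=None is exactly what routes the training loop into the new else branch above, here is a quick standalone check of the flag's parsing (hypothetical snippet, not part of the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--scale_weight_norms", type=float, default=None)

assert parser.parse_args([]).scale_weight_norms is None  # flag omitted: feature disabled
assert parser.parse_args(["--scale_weight_norms", "1"]).scale_weight_norms == 1.0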