From a5c38e5d5b2f1e9421ce58fdaef0a62d3733eed1 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 1 Jun 2023 19:32:22 +0900
Subject: [PATCH] fix crashing when max_norm is disabled

---
 train_network.py | 43 +++++++++++++++++++++++++++----------------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/train_network.py b/train_network.py
index 191e6dd1..edbc915e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -25,16 +25,25 @@ from library.config_util import (
 )
 import library.huggingface_util as huggingface_util
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm
+from library.custom_train_functions import (
+    apply_snr_weight,
+    get_weighted_text_embeddings,
+    pyramid_noise_like,
+    apply_noise_offset,
+    max_norm,
+)
 
 # TODO: share this code with the other scripts
-def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
+def generate_step_logs(
+    args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}
-    if args.scale_weight_norms:
-        logs["keys_scaled"] = keys_scaled
-        logs["average_key_norm"] = mean_norm
-        logs["max_key_norm"] = maximum_norm
+
+    if keys_scaled is not None:
+        logs["max_norm/keys_scaled"] = keys_scaled
+        logs["max_norm/average_key_norm"] = mean_norm
+        logs["max_norm/max_key_norm"] = maximum_norm
 
     lrs = lr_scheduler.get_last_lr()
 
@@ -151,7 +160,7 @@ def train(args):
 
     # incorporate xformers or memory-efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-    
+
     # load the model for difference-based additional training
     import sys
 
@@ -200,14 +209,15 @@ def train(args):
     if args.dim_from_weights:
         network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
-        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
+        network = network_module.create_network(
+            1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs
+        )
     if network is None:
         return
 
     if hasattr(network, "prepare_network"):
         network.prepare_network(args)
-
     train_unet = not args.network_train_text_encoder_only
     train_text_encoder = not args.network_train_unet_only
     network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -587,7 +597,6 @@ def train(args):
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
-
             current_step.value = global_step
             with accelerator.accumulate(network):
                 on_step_start(text_encoder, unet)
@@ -659,10 +668,12 @@ def train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-            
+
             if args.scale_weight_norms:
-                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
-                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
+                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+            else:
+                keys_scaled, mean_norm, maximum_norm = None, None, None
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -698,9 +709,9 @@ def train(args):
             avr_loss = loss_total / len(loss_list)
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
-            if args.scale_weight_norms:
-                progress_bar.set_postfix(**max_mean_logs)
+            if args.scale_weight_norms:
+                progress_bar.set_postfix(**max_mean_logs)
 
             if args.logging_dir is not None:
                 logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
@@ -806,7 +817,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--scale_weight_norms",
         type=float,
         default=None,
-        help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point)",
+        help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
     )
     parser.add_argument(
         "--dropout",
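
The crash this patch fixes is an unbound-variable error: keys_scaled, mean_norm, and maximum_norm were assigned only inside the "if args.scale_weight_norms:" branch, yet the generate_step_logs call site passed all three unconditionally, so running without --scale_weight_norms raised UnboundLocalError. The fix binds the names to None in an else branch and makes generate_step_logs test "keys_scaled is not None" instead of the CLI flag. The sketch below is a minimal, self-contained illustration of that guard pattern; training_step and the placeholder norm values are hypothetical stand-ins, not the sd-scripts internals:

    # Minimal sketch of the guard pattern (hypothetical names, not the
    # sd-scripts API): bind the norm stats in every branch, defaulting to
    # None, and key the logging on the value rather than the CLI flag.
    from typing import Optional

    def generate_step_logs(
        current_loss: float,
        keys_scaled: Optional[int] = None,
        mean_norm: Optional[float] = None,
        maximum_norm: Optional[float] = None,
    ) -> dict:
        logs = {"loss/current": current_loss}
        # None means "not computed this step", so nothing extra is logged.
        if keys_scaled is not None:
            logs["max_norm/keys_scaled"] = keys_scaled
            logs["max_norm/average_key_norm"] = mean_norm
            logs["max_norm/max_key_norm"] = maximum_norm
        return logs

    def training_step(scale_weight_norms: Optional[float], loss: float) -> dict:
        if scale_weight_norms:
            # stand-in for max_norm(network.state_dict(), args.scale_weight_norms)
            keys_scaled, mean_norm, maximum_norm = 3, 0.98, 1.0
        else:
            # the fix: always bind the names so the unconditional
            # logging call below cannot raise UnboundLocalError
            keys_scaled, mean_norm, maximum_norm = None, None, None
        return generate_step_logs(loss, keys_scaled, mean_norm, maximum_norm)

    print(training_step(None, 0.123))  # {'loss/current': 0.123}
    print(training_step(1.0, 0.123))   # adds the max_norm/* keys

Keying the logging on the value rather than on args.scale_weight_norms also means generate_step_logs no longer consults argparse state, which is why the patch can move the max_norm/* entries behind a simple None check.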