Mirror of https://github.com/kohya-ss/sd-scripts.git
Dropout and Max Norm Regularization for LoRA training (#545)
* Instantiate max_norm
* minor
* Move to end of step
* argparse
* metadata
* phrasing
* Sqrt ratio and logging
* fix logging
* Dropout test
* Dropout Args
* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
@@ -25,12 +25,16 @@ from library.config_util import (
 )
 import library.huggingface_util as huggingface_util
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
+from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm


 # TODO commonize this with the other scripts
-def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}
+    if args.scale_weight_norms:
+        logs["keys_scaled"] = keys_scaled
+        logs["average_key_norm"] = mean_norm
+        logs["max_key_norm"] = maximum_norm

     lrs = lr_scheduler.get_last_lr()
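With --scale_weight_norms set, generate_step_logs now records the three norm statistics alongside the loss. As a rough illustration of the dict this hands to accelerator.log each step (key names come from the patch above; all values here are made up):

    # Illustrative only; values are examples, not real training output
    logs = {
        "loss/current": 0.1042,      # this step's loss
        "loss/average": 0.0991,      # running average over the epoch
        "keys_scaled": 14,           # LoRA key pairs clipped this step
        "average_key_norm": 0.8321,  # mean norm across all key pairs
        "max_key_norm": 1.0,         # largest norm after scaling
    }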
@@ -196,13 +200,14 @@ def train(args):
     if args.dim_from_weights:
         network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
-        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
+        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
     if network is None:
         return

     if hasattr(network, "prepare_network"):
         network.prepare_network(args)
+

     train_unet = not args.network_train_text_encoder_only
     train_text_encoder = not args.network_train_unet_only
     network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
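The hunk above routes args.dropout into create_network, and the PR notes say dropout was "changed to affect LoRA only", i.e. it is applied inside the LoRA branch rather than to the base model. A minimal sketch of how a LoRA module might use that value; this is a hypothetical stand-in, not the repository's actual networks/lora.py:

    import torch.nn as nn
    import torch.nn.functional as F

    class LoRAModuleSketch(nn.Module):
        # Hypothetical LoRA wrapper that accepts the new dropout argument.
        def __init__(self, org_module: nn.Linear, lora_dim=4, alpha=1.0, dropout=None):
            super().__init__()
            self.lora_down = nn.Linear(org_module.in_features, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, org_module.out_features, bias=False)
            self.scale = alpha / lora_dim
            self.dropout = dropout
            self.org_forward = org_module.forward  # frozen base path

        def forward(self, x):
            lx = self.lora_down(x)
            # Dropout touches only the LoRA branch; the base output is untouched.
            if self.dropout is not None and self.training:
                lx = F.dropout(lx, p=self.dropout)
            return self.org_forward(x) + self.lora_up(lx) * self.scale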
@@ -375,6 +380,8 @@ def train(args):
            "ss_face_crop_aug_range": args.face_crop_aug_range,
            "ss_prior_loss_weight": args.prior_loss_weight,
            "ss_min_snr_gamma": args.min_snr_gamma,
+           "ss_scale_weight_norms": args.scale_weight_norms,
+           "ss_dropout": args.dropout,
        }

        if use_user_config:
@@ -580,6 +587,7 @@ def train(args):
         network.on_epoch_start(text_encoder, unet)

         for step, batch in enumerate(train_dataloader):
+
             current_step.value = global_step
             with accelerator.accumulate(network):
                 on_step_start(text_encoder, unet)
@@ -651,6 +659,10 @@ def train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
+
+            if args.scale_weight_norms:
+                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
+                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
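The PR notes mention "Sqrt ratio and logging": when a key pair's combined update exceeds the threshold, the correction is split as the square root of the ratio applied to each of the two matrices, so their product shrinks by the full ratio. A rough sketch of what a helper with this call signature could do, assuming plain linear LoRA weights; this is an illustration, not the code added to library/custom_train_functions.py:

    import torch

    def max_norm_sketch(state_dict, max_norm_value):
        # Pair each lora_down with its lora_up and alpha, then rescale any
        # pair whose combined update exceeds max_norm_value.
        down_keys = [k for k in state_dict if "lora_down" in k and "weight" in k]
        norms, keys_scaled = [], 0
        for dk in down_keys:
            uk = dk.replace("lora_down", "lora_up")
            ak = dk.replace("lora_down.weight", "alpha")
            down, up, alpha = state_dict[dk], state_dict[uk], state_dict[ak]
            scale = alpha / down.shape[0]  # alpha / rank
            # Effective weight update from this pair (linear layers only)
            updown = (up @ down) * scale
            norm = updown.norm()
            if norm > max_norm_value:
                # sqrt(ratio) on each matrix => product shrinks by the full ratio
                sqrt_ratio = (max_norm_value / norm) ** 0.5
                state_dict[dk] *= sqrt_ratio  # in-place, so the live module updates
                state_dict[uk] *= sqrt_ratio
                keys_scaled += 1
                norm = torch.tensor(max_norm_value)
            norms.append(norm.item())
        return keys_scaled, sum(norms) / len(norms), max(norms)

Returning the (keys_scaled, mean_norm, maximum_norm) triple matches what the training loop above logs after each optimizer step.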
@@ -686,9 +698,12 @@ def train(args):
             avr_loss = loss_total / len(loss_list)
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
+            if args.scale_weight_norms:
+                progress_bar.set_postfix(**max_mean_logs)
+

             if args.logging_dir is not None:
-                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                 accelerator.log(logs, step=global_step)

             if global_step >= args.max_train_steps:
@@ -787,6 +802,18 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
     )
+    parser.add_argument(
+        "--scale_weight_norms",
+        type=float,
+        default=None,
+        help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point)",
+    )
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=None,
+        help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
+    )
     parser.add_argument(
         "--base_weights",
         type=str,
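Both options default to None, so existing workflows are unaffected unless they opt in. A typical invocation of this script with the new flags enabled might look like the following (values are examples; the remaining training arguments are elided):

    accelerate launch train_network.py \
        --scale_weight_norms=1.0 \
        --dropout=0.1 \
        ...other training arguments...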