mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Dropout and Max Norm Regularization for LoRA training (#545)
* Instantiate max_norm
* minor
* Move to end of step
* argparse
* metadata
* phrasing
* Sqrt ratio and logging
* fix logging
* Dropout test
* Dropout Args
* Dropout changed to affect LoRA only

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
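At a glance, the commit touches three places: library/custom_train_functions.py gains a max_norm() helper, the LoRAModule/LoRANetwork classes are threaded with a dropout parameter, and the training script wires both up as new command-line flags (--scale_weight_norms and --dropout). Max norm regularization rescales each lora_up/lora_down pair whenever the norm of their product exceeds the cap; dropout zeroes the low-rank activations between lora_down and lora_up each step, affecting only the LoRA path. A sketch (illustrative, not part of the commit) of why both factors are multiplied by the square root of the scaling ratio: the effective weight update is up @ down, so scaling each factor by sqrt(r) scales the product by exactly r.

    import torch

    up = torch.randn(8, 4)     # stand-in for a lora_up weight
    down = torch.randn(4, 16)  # stand-in for a lora_down weight
    r = 0.5                    # desired_norm / current_norm

    # (sqrt(r) * up) @ (sqrt(r) * down) == r * (up @ down)
    assert torch.allclose((up * r**0.5) @ (down * r**0.5), (up @ down) * r)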
@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
         noise += noise_perlin  # broadcast for each batch
     return noise / noise.std()  # Scaled back to roughly unit variance
 """
+
+
+def max_norm(state_dict, max_norm_value):
+    downkeys = []
+    upkeys = []
+    norms = []
+    keys_scaled = 0
+
+    for key in state_dict.keys():
+        if "lora_down" in key and "weight" in key:
+            downkeys.append(key)
+            upkeys.append(key.replace("lora_down", "lora_up"))
+    for i in range(len(downkeys)):
+        down = state_dict[downkeys[i]].cuda()
+        up = state_dict[upkeys[i]].cuda()
+        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+        else:
+            updown = up @ down
+        norm = updown.norm().clamp(min=max_norm_value / 2)
+        desired = torch.clamp(norm, max=max_norm_value)
+        ratio = desired.cpu() / norm.cpu()
+        sqrt_ratio = ratio**0.5
+        if ratio != 1:
+            keys_scaled += 1
+            state_dict[upkeys[i]] *= sqrt_ratio
+            state_dict[downkeys[i]] *= sqrt_ratio
+        scalednorm = updown.norm() * ratio
+        norms.append(scalednorm.item())
+
+    return keys_scaled, sum(norms) / len(norms), max(norms)
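A usage sketch for max_norm() (toy key names and shapes, not from the commit): it rescales pairs in place in the given state dict and returns the count of scaled pairs plus the mean and max post-scaling norms. Note it calls .cuda() on every pair, so a GPU is assumed.

    import torch

    sd = {
        "lora_unet_example.lora_down.weight": torch.randn(4, 320),
        "lora_unet_example.lora_up.weight": torch.randn(320, 4) * 10,  # deliberately oversized
    }
    keys_scaled, mean_norm, max_seen = max_norm(sd, max_norm_value=1.0)
    print(keys_scaled, mean_norm, max_seen)  # the oversized pair is scaled; its norm ends up capped at 1.0

The remaining hunks thread the new dropout argument through the LoRA module and network classes.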
@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
     replaces forward method of the original Linear, instead of replacing the original Linear module.
     """
 
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
         self.lora_name = lora_name
@@ -60,6 +60,7 @@ class LoRAModule(torch.nn.Module):
 
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
+        self.dropout = dropout
 
     def apply_to(self):
         self.org_forward = self.org_module.forward
@@ -67,6 +68,9 @@ class LoRAModule(torch.nn.Module):
         del self.org_module
 
     def forward(self, x):
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        if self.dropout:
+            return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
+        else:
+            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
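For intuition, a self-contained sketch of the patched forward path (stand-in layers, not the module itself): dropout is applied to the rank-dim activations between lora_down and lora_up, so the frozen base layer's output passes through untouched and only the LoRA branch is regularized.

    import torch
    import torch.nn.functional as F

    org_forward = torch.nn.Linear(16, 8)            # stands in for the frozen base layer
    lora_down = torch.nn.Linear(16, 4, bias=False)  # rank-4 down projection
    lora_up = torch.nn.Linear(4, 8, bias=False)     # rank-4 up projection
    multiplier, scale, p = 1.0, 1.0, 0.3

    x = torch.randn(2, 16)
    out = org_forward(x) + lora_up(F.dropout(lora_down(x), p=p)) * multiplier * scale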
@@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs):
     return down_lr_weight, mid_lr_weight, up_lr_weight
 
 
-def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
@@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_block_dims=conv_block_dims,
         conv_block_alphas=conv_block_alphas,
         varbose=True,
+        dropout=dropout,
     )
 
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -681,6 +686,7 @@ class LoRANetwork(torch.nn.Module):
         modules_alpha=None,
         module_class=LoRAModule,
         varbose=False,
+        dropout=None,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -697,6 +703,8 @@ class LoRANetwork(torch.nn.Module):
         self.alpha = alpha
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
+        self.dropout = dropout
+        print(f"Neuron dropout: p={self.dropout}")
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
@@ -755,7 +763,7 @@ class LoRANetwork(torch.nn.Module):
                 skipped.append(lora_name)
                 continue
 
-            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
+            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
             loras.append(lora)
         return loras, skipped
 
@@ -25,12 +25,16 @@ from library.config_util import (
 )
 import library.huggingface_util as huggingface_util
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
+from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm
 
 
 # TODO unify this with the other scripts
-def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}
+    if args.scale_weight_norms:
+        logs["keys_scaled"] = keys_scaled
+        logs["average_key_norm"] = mean_norm
+        logs["max_key_norm"] = maximum_norm
 
     lrs = lr_scheduler.get_last_lr()
 
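With --scale_weight_norms set, the per-step log dict grows three keys. Roughly what generate_step_logs returns in that case (numbers made up for illustration):

    logs = {
        "loss/current": 0.123,
        "loss/average": 0.117,
        "keys_scaled": 3,          # pairs rescaled this step
        "average_key_norm": 0.98,  # mean post-scaling norm
        "max_key_norm": 1.0,       # capped at the --scale_weight_norms value
        # ...plus the learning-rate entries the function appends below this point
    }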
@@ -196,13 +200,14 @@ def train(args):
     if args.dim_from_weights:
         network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
-        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
+        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
     if network is None:
         return
 
     if hasattr(network, "prepare_network"):
         network.prepare_network(args)
+
 
     train_unet = not args.network_train_text_encoder_only
     train_text_encoder = not args.network_train_unet_only
     network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -375,6 +380,8 @@ def train(args):
         "ss_face_crop_aug_range": args.face_crop_aug_range,
         "ss_prior_loss_weight": args.prior_loss_weight,
         "ss_min_snr_gamma": args.min_snr_gamma,
+        "ss_scale_weight_norms": args.scale_weight_norms,
+        "ss_dropout": args.dropout,
     }
 
     if use_user_config:
@@ -580,6 +587,7 @@ def train(args):
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
+
             current_step.value = global_step
             with accelerator.accumulate(network):
                 on_step_start(text_encoder, unet)
@@ -652,6 +660,10 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if args.scale_weight_norms:
+                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
+                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
@@ -686,9 +698,12 @@ def train(args):
                 avr_loss = loss_total / len(loss_list)
                 logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
+                if args.scale_weight_norms:
+                    progress_bar.set_postfix(**max_mean_logs)
+
 
                 if args.logging_dir is not None:
-                    logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+                    logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                     accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
@@ -787,6 +802,18 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
     )
+    parser.add_argument(
+        "--scale_weight_norms",
+        type=float,
+        default=None,
+        help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point)",
+    )
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=None,
+        help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
+    )
     parser.add_argument(
         "--base_weights",
         type=str,
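Taken together, a hypothetical invocation enabling both features could look like this (paths and the usual model/dataset flags elided):

    accelerate launch train_network.py \
        --scale_weight_norms=1.0 \
        --dropout=0.1 \
        ...

The help text suggests 1 as a starting point for --scale_weight_norms; --dropout is the probability of zeroing each low-rank unit per step, so values near 1 would effectively disable the LoRA path.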