Dropout and Max Norm Regularization for LoRA training (#545)

* Instantiate max_norm

* minor

* Move to end of step

* argparse

* metadata

* phrasing

* Sqrt ratio and logging

* fix logging

* Dropout test

* Dropout Args

* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Author: AI-Casanova
Date: 2023-06-01 00:58:38 -05:00
Committed by: GitHub
Parent: 5931948adb
Commit: 9c7237157d
4 changed files with 77 additions and 9 deletions


@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""
def max_norm(state_dict, max_norm_value):
    downkeys = []
    upkeys = []
    norms = []
    keys_scaled = 0

    for key in state_dict.keys():
        if "lora_down" in key and "weight" in key:
            downkeys.append(key)
            upkeys.append(key.replace("lora_down", "lora_up"))

    for i in range(len(downkeys)):
        down = state_dict[downkeys[i]].cuda()
        up = state_dict[upkeys[i]].cuda()

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired.cpu() / norm.cpu()
        sqrt_ratio = ratio ** 0.5
        if ratio != 1:
            keys_scaled += 1
            state_dict[upkeys[i]] *= sqrt_ratio
            state_dict[downkeys[i]] *= sqrt_ratio
        scalednorm = updown.norm() * ratio
        norms.append(scalednorm.item())

    return keys_scaled, sum(norms) / len(norms), max(norms)
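
The sqrt-ratio step above is why both factor matrices are scaled by ratio ** 0.5: the effective LoRA update is the product up @ down, so scaling each factor by the square root of the ratio scales the composed update by the ratio itself. A quick standalone check (illustrative only; the shapes here are made up, not taken from the commit):

    import torch

    up = torch.randn(320, 4)    # stand-in for a lora_up weight (out_features x rank)
    down = torch.randn(4, 768)  # stand-in for a lora_down weight (rank x in_features)
    ratio = 0.5

    scaled = (up * ratio ** 0.5) @ (down * ratio ** 0.5)
    print(torch.allclose(scaled, (up @ down) * ratio))  # True: the product is scaled by ratio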


@@ -3638,4 +3638,4 @@ class collater_class:
        # set epoch and step
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
        return examples[0]


@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
@@ -60,6 +60,7 @@ class LoRAModule(torch.nn.Module):
        self.multiplier = multiplier
        self.org_module = org_module # remove in applying
        self.dropout = dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
@@ -67,7 +68,10 @@ class LoRAModule(torch.nn.Module):
        del self.org_module

    def forward(self, x):
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        if self.dropout:
            return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale
        else:
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


class LoRAInfModule(LoRAModule):
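
Because the dropout sits between lora_down and lora_up, only the low-rank branch is perturbed; the frozen original path (org_forward) passes through untouched. A small standalone sketch of that placement (layer sizes and the 0.3 rate are arbitrary, not from the commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)
    lora_down = torch.nn.Linear(8, 4, bias=False)  # stand-in for the down projection
    lora_up = torch.nn.Linear(4, 8, bias=False)    # stand-in for the up projection

    h = F.dropout(lora_down(x), p=0.3)  # rank activations randomly zeroed, survivors scaled by 1/(1-p)
    delta = lora_up(h)                  # the perturbed low-rank update added on top of the frozen output
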
@@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs):
    return down_lr_weight, mid_lr_weight, up_lr_weight


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
@@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
        conv_block_dims=conv_block_dims,
        conv_block_alphas=conv_block_alphas,
        varbose=True,
        dropout=dropout,
    )

    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -681,6 +686,7 @@ class LoRANetwork(torch.nn.Module):
        modules_alpha=None,
        module_class=LoRAModule,
        varbose=False,
        dropout=None
    ) -> None:
        """
        LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -697,6 +703,8 @@ class LoRANetwork(torch.nn.Module):
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        print(f"Neuron dropout: p={self.dropout}")

        if modules_dim is not None:
            print(f"create LoRA network from weights")
@@ -755,7 +763,7 @@ class LoRANetwork(torch.nn.Module):
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
                            loras.append(lora)

            return loras, skipped
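
Taken together, the dropout value just threads from create_network through LoRANetwork into each LoRAModule. A hedged sketch of exercising one module on its own (the import path assumes the repository's networks package is importable from the repo root, and the Linear size is arbitrary):

    import torch
    from networks.lora import LoRAModule  # assumption: repo root is on sys.path

    linear = torch.nn.Linear(16, 16)
    lora = LoRAModule("lora_example", linear, multiplier=1.0, lora_dim=4, alpha=1, dropout=0.1)
    lora.apply_to()  # swaps linear.forward for the LoRA-augmented forward, as in the diff above

    y = linear(torch.randn(2, 16))  # base output plus the (dropped-out) low-rank update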


@@ -25,12 +25,16 @@ from library.config_util import (
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm


# TODO: unify this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
    logs = {"loss/current": current_loss, "loss/average": avr_loss}

    if args.scale_weight_norms:
        logs["keys_scaled"] = keys_scaled
        logs["average_key_norm"] = mean_norm
        logs["max_key_norm"] = maximum_norm

    lrs = lr_scheduler.get_last_lr()
@@ -196,13 +200,14 @@ def train(args):
    if args.dim_from_weights:
        network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
    else:
        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
    if network is None:
        return

    if hasattr(network, "prepare_network"):
        network.prepare_network(args)

    train_unet = not args.network_train_text_encoder_only
    train_text_encoder = not args.network_train_unet_only
    network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -375,6 +380,8 @@ def train(args):
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_dropout": args.dropout,
}
if use_user_config:
@@ -580,6 +587,7 @@ def train(args):
        network.on_epoch_start(text_encoder, unet)

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            with accelerator.accumulate(network):
                on_step_start(text_encoder, unet)
@@ -651,6 +659,10 @@ def train(args):
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if args.scale_weight_norms:
                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
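
One point worth noting about the hook above: max_norm receives network.state_dict() and multiplies its tensors in place, and because state_dict() returns tensors that share storage with the live parameters, the clamping changes the weights the next step actually trains on. A standalone illustration of that mechanism (not from the commit; the layer is arbitrary):

    import torch

    layer = torch.nn.Linear(4, 4, bias=False)
    before = layer.weight.detach().clone()

    sd = layer.state_dict()
    sd["weight"] *= 0.5  # in-place multiply on a tensor sharing storage with the parameter
    print(torch.allclose(layer.weight.detach(), before * 0.5))  # True: the live parameter was scaled
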
@@ -686,9 +698,12 @@ def train(args):
            avr_loss = loss_total / len(loss_list)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if args.scale_weight_norms:
                progress_bar.set_postfix(**max_mean_logs)

            if args.logging_dir is not None:
                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
@@ -787,6 +802,18 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
parser.add_argument(
"--scale_weight_norms",
type=float,
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point)",
)
parser.add_argument(
"--dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
)
parser.add_argument(
"--base_weights",
type=str,
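
For reference, a minimal check of how the two new options parse (this snippet is not part of the commit; it only mirrors the add_argument calls above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--scale_weight_norms", type=float, default=None)
    parser.add_argument("--dropout", type=float, default=None)

    # Both features stay off unless a value is passed on the command line.
    args = parser.parse_args(["--scale_weight_norms", "1.0", "--dropout", "0.1"])
    print(args.scale_weight_norms, args.dropout)  # 1.0 0.1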