Dropout and Max Norm Regularization for LoRA training (#545)

* Instantiate max_norm

* minor

* Move to end of step

* argparse

* metadata

* phrasing

* Sqrt ratio and logging

* fix logging

* Dropout test

* Dropout Args

* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Author: AI-Casanova
Date: 2023-06-01 00:58:38 -05:00
Committed by: GitHub
Parent commit: 5931948adb
Commit: 9c7237157d
4 changed files with 77 additions and 9 deletions


@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
        noise += noise_perlin # broadcast for each batch
    return noise / noise.std() # Scaled back to roughly unit variance
"""
def max_norm(state_dict, max_norm_value):
    # Rescale each LoRA down/up pair so the norm of their merged update stays at or below max_norm_value.
    downkeys = []
    upkeys = []
    norms = []
    keys_scaled = 0

    for key in state_dict.keys():
        if "lora_down" in key and "weight" in key:
            downkeys.append(key)
            upkeys.append(key.replace("lora_down", "lora_up"))

    for i in range(len(downkeys)):
        down = state_dict[downkeys[i]].cuda()
        up = state_dict[upkeys[i]].cuda()

        # Merge the pair into the full weight update so its norm can be measured.
        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired.cpu() / norm.cpu()
        sqrt_ratio = ratio ** 0.5
        if ratio != 1:
            keys_scaled += 1
            state_dict[upkeys[i]] *= sqrt_ratio
            state_dict[downkeys[i]] *= sqrt_ratio
        scalednorm = updown.norm() * ratio
        norms.append(scalednorm.item())

    return keys_scaled, sum(norms) / len(norms), max(norms)
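
The square root of the ratio is applied to both factors because the merged update is their product: multiplying each matrix by sqrt(ratio) scales `up @ down` by exactly `ratio`. A minimal sketch of that identity (plain toy tensors, not the LoRA state dict handled above; the small-norm clamp is omitted):

import torch

# Hypothetical toy LoRA pair; any shapes with a shared inner dimension work.
up = torch.randn(8, 4)
down = torch.randn(4, 16)
max_norm_value = 1.0

norm = (up @ down).norm()
ratio = torch.clamp(norm, max=max_norm_value) / norm        # <= 1 when the norm is too large
scaled = (up * ratio**0.5) @ (down * ratio**0.5)             # same as ratio * (up @ down)
print(norm.item(), scaled.norm().item())                     # scaled norm is capped at max_norm_value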


@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
@@ -60,6 +60,7 @@ class LoRAModule(torch.nn.Module):
        self.multiplier = multiplier
        self.org_module = org_module # remove in applying
        self.dropout = dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
@@ -67,7 +68,10 @@ class LoRAModule(torch.nn.Module):
        del self.org_module

    def forward(self, x):
        if self.dropout:
            return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
        else:
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


class LoRAInfModule(LoRAModule):
@@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs):
    return down_lr_weight, mid_lr_weight, up_lr_weight


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
@@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
        conv_block_dims=conv_block_dims,
        conv_block_alphas=conv_block_alphas,
        varbose=True,
        dropout=dropout,
    )

    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -681,6 +686,7 @@ class LoRANetwork(torch.nn.Module):
        modules_alpha=None,
        module_class=LoRAModule,
        varbose=False,
        dropout=None
    ) -> None:
        """
        LoRA network: there are a lot of arguments, but they follow the patterns below
@@ -697,6 +703,8 @@ class LoRANetwork(torch.nn.Module):
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        print(f"Neuron dropout: p={self.dropout}")

        if modules_dim is not None:
            print(f"create LoRA network from weights")
@@ -755,7 +763,7 @@ class LoRANetwork(torch.nn.Module):
                    skipped.append(lora_name)
                    continue

                lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
                loras.append(lora)

        return loras, skipped
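
Taken together, these changes thread a `dropout` probability from `create_network` down to every `LoRAModule`, where it is applied only to the low-rank branch, so the frozen original projection still sees the full activation. A minimal standalone sketch of that forward pass (a hypothetical toy module with assumed shapes, not the class from this diff):

import torch

class TinyLoRALinear(torch.nn.Module):
    # Hypothetical simplified module: frozen original Linear plus a LoRA down/up pair.
    def __init__(self, org_linear, rank=4, multiplier=1.0, scale=1.0, dropout=None):
        super().__init__()
        self.org = org_linear
        self.lora_down = torch.nn.Linear(org_linear.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, org_linear.out_features, bias=False)
        self.multiplier, self.scale, self.dropout = multiplier, scale, dropout

    def forward(self, x):
        h = self.lora_down(x)
        if self.dropout:
            # Dropout touches only the LoRA activations, never the original path.
            h = torch.nn.functional.dropout(h, p=self.dropout)
        return self.org(x) + self.lora_up(h) * self.multiplier * self.scale

layer = TinyLoRALinear(torch.nn.Linear(16, 16), dropout=0.1)
out = layer(torch.randn(2, 16))  # LoRA branch is randomly dropped each call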


@@ -25,12 +25,16 @@ from library.config_util import (
)

import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm


# TODO: share this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
    logs = {"loss/current": current_loss, "loss/average": avr_loss}

    if args.scale_weight_norms:
        logs["keys_scaled"] = keys_scaled
        logs["average_key_norm"] = mean_norm
        logs["max_key_norm"] = maximum_norm

    lrs = lr_scheduler.get_last_lr()
@@ -196,13 +200,14 @@ def train(args):
    if args.dim_from_weights:
        network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
    else:
        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
    if network is None:
        return

    if hasattr(network, "prepare_network"):
        network.prepare_network(args)

    train_unet = not args.network_train_text_encoder_only
    train_text_encoder = not args.network_train_unet_only
    network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -375,6 +380,8 @@ def train(args):
"ss_face_crop_aug_range": args.face_crop_aug_range, "ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight, "ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma, "ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_dropout": args.dropout,
} }
if use_user_config: if use_user_config:
@@ -580,6 +587,7 @@ def train(args):
        network.on_epoch_start(text_encoder, unet)

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            with accelerator.accumulate(network):
                on_step_start(text_encoder, unet)
@@ -652,6 +660,10 @@ def train(args):
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if args.scale_weight_norms:
                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
@@ -686,9 +698,12 @@ def train(args):
            avr_loss = loss_total / len(loss_list)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            if args.scale_weight_norms:
                progress_bar.set_postfix(**max_mean_logs)

            if args.logging_dir is not None:
                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
@@ -787,6 +802,18 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true", action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
) )
parser.add_argument(
"--scale_weight_norms",
type=float,
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point)",
)
parser.add_argument(
"--dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
)
parser.add_argument( parser.add_argument(
"--base_weights", "--base_weights",
type=str, type=str,
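
Both new options are plain float flags. A minimal sketch of how they parse (a throwaway parser mirroring only the two options added above; the real setup_parser defines many more arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--scale_weight_norms", type=float, default=None)
parser.add_argument("--dropout", type=float, default=None)

# e.g. cap merged LoRA norms at 1.0 and drop 10% of LoRA activations each step
args = parser.parse_args(["--scale_weight_norms", "1.0", "--dropout", "0.1"])
print(args.scale_weight_norms, args.dropout)  # 1.0 0.1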