diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 9345a0d5..925e7d97 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
         noise += noise_perlin # broadcast for each batch
     return noise / noise.std() # Scaled back to roughly unit variance
 """
+
+def max_norm(state_dict, max_norm_value):
+    downkeys = []
+    upkeys = []
+    norms = []
+    keys_scaled = 0
+
+    for key in state_dict.keys():
+        if "lora_down" in key and "weight" in key:
+            downkeys.append(key)
+            upkeys.append(key.replace("lora_down", "lora_up"))
+    for i in range(len(downkeys)):
+        down = state_dict[downkeys[i]].cuda()
+        up = state_dict[upkeys[i]].cuda()
+        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+        else:
+            updown = up @ down
+        norm = updown.norm().clamp(min=max_norm_value / 2)
+        desired = torch.clamp(norm, max=max_norm_value)
+        ratio = desired.cpu() / norm.cpu()
+        sqrt_ratio = ratio ** 0.5
+        if ratio != 1:
+            keys_scaled += 1
+            state_dict[upkeys[i]] *= sqrt_ratio
+            state_dict[downkeys[i]] *= sqrt_ratio
+        scalednorm = updown.norm() * ratio
+        norms.append(scalednorm.item())
+
+    return keys_scaled, sum(norms) / len(norms), max(norms)
+
diff --git a/library/train_util.py b/library/train_util.py
index d963537d..46c5c3b2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3638,4 +3638,4 @@ class collater_class:
         # set epoch and step
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
-        return examples[0]
+        return examples[0]
\ No newline at end of file
diff --git a/networks/lora.py b/networks/lora.py
index f761cce1..e09910e0 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
     replaces forward method of the original Linear, instead of replacing the original Linear module.
""" - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() self.lora_name = lora_name @@ -60,6 +60,7 @@ class LoRAModule(torch.nn.Module): self.multiplier = multiplier self.org_module = org_module # remove in applying + self.dropout = dropout def apply_to(self): self.org_forward = self.org_module.forward @@ -67,7 +68,10 @@ class LoRAModule(torch.nn.Module): del self.org_module def forward(self, x): - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.dropout: + return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale + else: + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale class LoRAInfModule(LoRAModule): @@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs): return down_lr_weight, mid_lr_weight, up_lr_weight -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_block_dims=conv_block_dims, conv_block_alphas=conv_block_alphas, varbose=True, + dropout=dropout, ) if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: @@ -681,6 +686,7 @@ class LoRANetwork(torch.nn.Module): modules_alpha=None, module_class=LoRAModule, varbose=False, + dropout=None ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -697,6 +703,8 @@ class LoRANetwork(torch.nn.Module): self.alpha = alpha self.conv_lora_dim = conv_lora_dim self.conv_alpha = conv_alpha + self.dropout = dropout + print(f"Neuron dropout: p={self.dropout}") if modules_dim is not None: print(f"create LoRA network from weights") @@ -755,7 +763,7 @@ class LoRANetwork(torch.nn.Module): skipped.append(lora_name) continue - lora = module_class(lora_name, child_module, self.multiplier, dim, alpha) + lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout) loras.append(lora) return loras, skipped diff --git a/train_network.py b/train_network.py index f8db030c..191e6dd1 100644 --- a/train_network.py +++ b/train_network.py @@ -25,12 +25,16 @@ from library.config_util import ( ) import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm # TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None): logs = {"loss/current": current_loss, "loss/average": avr_loss} + if args.scale_weight_norms: + logs["keys_scaled"] = keys_scaled + logs["average_key_norm"] = mean_norm + logs["max_key_norm"] = maximum_norm lrs = lr_scheduler.get_last_lr() @@ -196,13 +200,14 @@ def 
     if args.dim_from_weights:
         network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
-        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
+        network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
     if network is None:
         return
 
     if hasattr(network, "prepare_network"):
         network.prepare_network(args)
+
 
     train_unet = not args.network_train_text_encoder_only
     train_text_encoder = not args.network_train_unet_only
     network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -375,6 +380,8 @@ def train(args):
         "ss_face_crop_aug_range": args.face_crop_aug_range,
         "ss_prior_loss_weight": args.prior_loss_weight,
         "ss_min_snr_gamma": args.min_snr_gamma,
+        "ss_scale_weight_norms": args.scale_weight_norms,
+        "ss_dropout": args.dropout,
     }
 
     if use_user_config:
@@ -580,6 +587,7 @@ def train(args):
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(network):
                 on_step_start(text_encoder, unet)
 
@@ -651,6 +659,10 @@ def train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
+
+            if args.scale_weight_norms:
+                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
+                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -686,9 +698,12 @@ def train(args):
                 avr_loss = loss_total / len(loss_list)
                 logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
+                if args.scale_weight_norms:
+                    progress_bar.set_postfix(**max_mean_logs)
+
 
             if args.logging_dir is not None:
-                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                 accelerator.log(logs, step=global_step)
 
             if global_step >= args.max_train_steps:
@@ -787,6 +802,18 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
     )
+    parser.add_argument(
+        "--scale_weight_norms",
+        type=float,
+        default=None,
+        help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point)",
+    )
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=None,
+        help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
+    )
     parser.add_argument(
         "--base_weights",
         type=str,