diff --git a/library/network_utils.py b/library/network_utils.py
new file mode 100644
index 00000000..09cc8f31
--- /dev/null
+++ b/library/network_utils.py
@@ -0,0 +1,23 @@
+import torch
+from torch import Tensor
+
+# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
+def lora_dropout_down(down: Tensor, x: Tensor, dropout_prob=0.5):
+    """A = A · diag(mA), mA ∼ Bern(1 − p)"""
+    mask = torch.bernoulli(
+        torch.ones(down.shape[1], device=down.device) * (1 - dropout_prob)
+    )
+
+    # Apply input dimension mask (columns of down-projection)
+    lx = x @ (down * mask.view(1, -1)).t()
+    return lx
+
+def lora_dropout_up(up: Tensor, x: Tensor, dropout_prob=0.5):
+    """B = B⊤ · diag(mB)⊤, mB ∼ Bern(1 − p)"""
+    mask = torch.bernoulli(
+        torch.ones(up.shape[0], device=up.device) * (1 - dropout_prob)
+    )
+
+    # Apply output dimension mask (rows of up-projection)
+    lx = x @ (up * mask.view(-1, 1)).t()
+    return lx
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 92b3979a..554755e1 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -18,7 +18,7 @@ import torch
 from torch import Tensor
 import re
 from library.utils import setup_logging
-from library.sdxl_original_unet import SdxlUNet2DConditionModel
+from library.network_utils import lora_dropout_down, lora_dropout_up
 
 setup_logging()
 import logging
@@ -45,6 +45,7 @@ class LoRAModule(torch.nn.Module):
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        lora_dropout=None,
         split_dims: Optional[List[int]] = None,
         ggpo_beta: Optional[float] = None,
         ggpo_sigma: Optional[float] = None,
@@ -106,6 +107,7 @@ class LoRAModule(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.lora_dropout = lora_dropout
         self.ggpo_sigma = ggpo_sigma
         self.ggpo_beta = ggpo_beta
 
@@ -132,7 +134,11 @@ class LoRAModule(torch.nn.Module):
                 return org_forwarded
 
         if self.split_dims is None:
-            lx = self.lora_down(x)
+            # LoRA Dropout as a Sparsity Regularizer for Overfitting Control
+            if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
+                lx = lora_dropout_down(self.lora_down.weight, x, dropout_prob=self.lora_dropout)
+            else:
+                lx = self.lora_down(x)
 
             # normal dropout
             if self.dropout is not None and self.training:
@@ -153,14 +159,26 @@ class LoRAModule(torch.nn.Module):
         else:
             scale = self.scale
 
-        lx = self.lora_up(lx)
+        # LoRA Dropout as a Sparsity Regularizer for Overfitting Control
+        if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
+            lx = lora_dropout_up(self.lora_up.weight, lx, dropout_prob=self.lora_dropout)
+        else:
+            lx = self.lora_up(lx)
 
         # LoRA Gradient-Guided Perturbation Optimization
-        if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
+        if (
+            self.training
+            and self.ggpo_sigma is not None
+            and self.ggpo_beta is not None
+            and self.combined_weight_norms is not None
+            and self.grad_norms is not None
+        ):
             with torch.no_grad():
-                perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
+                perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
+                    self.ggpo_beta * (self.grad_norms**2)
+                )
                 perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
-                perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
+                perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
                 perturbation.mul_(perturbation_scale_factor)
                 perturbation_output = x @ perturbation.T  # Result: (batch × n)
         return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
@@ -197,24 +215,24 @@ class LoRAModule(torch.nn.Module):
         # Choose a reasonable sample size
         n_rows = org_module_weight.shape[0]
         sample_size = min(1000, n_rows)  # Cap at 1000 samples or use all if smaller
-        
+
         # Sample random indices across all rows
         indices = torch.randperm(n_rows)[:sample_size]
-        
+
         # Convert to a supported data type first, then index
         # Use float32 for indexing operations
         weights_float32 = org_module_weight.to(dtype=torch.float32)
         sampled_weights = weights_float32[indices].to(device=self.device)
-        
+
         # Calculate sampled norms
         sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
-        
+
         # Store the mean norm as our estimate
         self.org_weight_norm_estimate = sampled_norms.mean()
-        
+
         # Optional: store standard deviation for confidence intervals
         self.org_weight_norm_std = sampled_norms.std()
-        
+
         # Free memory
         del sampled_weights, weights_float32
 
@@ -223,37 +241,36 @@ class LoRAModule(torch.nn.Module):
         # Calculate the true norm (this will be slow but it's just for validation)
         true_norms = []
         chunk_size = 1024  # Process in chunks to avoid OOM
-        
+
         for i in range(0, org_module_weight.shape[0], chunk_size):
             end_idx = min(i + chunk_size, org_module_weight.shape[0])
             chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
             chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
             true_norms.append(chunk_norms.cpu())
             del chunk
-        
+
         true_norms = torch.cat(true_norms, dim=0)
         true_mean_norm = true_norms.mean().item()
-        
+
         # Compare with our estimate
         estimated_norm = self.org_weight_norm_estimate.item()
-        
+
         # Calculate error metrics
         absolute_error = abs(true_mean_norm - estimated_norm)
         relative_error = absolute_error / true_mean_norm * 100  # as percentage
-        
+
         if verbose:
             logger.info(f"True mean norm: {true_mean_norm:.6f}")
             logger.info(f"Estimated norm: {estimated_norm:.6f}")
             logger.info(f"Absolute error: {absolute_error:.6f}")
             logger.info(f"Relative error: {relative_error:.2f}%")
 
-        
-        return {
-            'true_mean_norm': true_mean_norm,
-            'estimated_norm': estimated_norm,
-            'absolute_error': absolute_error,
-            'relative_error': relative_error
-        }
+        return {
+            "true_mean_norm": true_mean_norm,
+            "estimated_norm": estimated_norm,
+            "absolute_error": absolute_error,
+            "relative_error": relative_error,
+        }
 
     @torch.no_grad()
     def update_norms(self):
@@ -261,7 +278,7 @@ class LoRAModule(torch.nn.Module):
         if self.ggpo_beta is None or self.ggpo_sigma is None:
             return
 
-        # only update norms when we are training 
+        # only update norms when we are training
         if self.training is False:
             return
 
@@ -269,8 +286,9 @@ class LoRAModule(torch.nn.Module):
         module_weights.mul(self.scale)
 
         self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
-        self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
-                                                torch.sum(module_weights**2, dim=1, keepdim=True))
+        self.combined_weight_norms = torch.sqrt(
+            (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
+        )
 
     @torch.no_grad()
     def update_grad_norms(self):
@@ -293,7 +311,6 @@ class LoRAModule(torch.nn.Module):
         approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
         self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
 
-
     @property
     def device(self):
         return next(self.parameters()).device
@@ -544,6 +561,9 @@ def create_network(
     module_dropout = kwargs.get("module_dropout", None)
     if module_dropout is not None:
         module_dropout = float(module_dropout)
+    lora_dropout = kwargs.get("lora_dropout", None)
+    if lora_dropout is not None:
+        lora_dropout = float(lora_dropout)
 
     # single or double blocks
     train_blocks = kwargs.get("train_blocks", None)  # None (default), "all" (same as None), "single", "double"
@@ -564,7 +584,6 @@ def create_network(
     if ggpo_sigma is not None:
         ggpo_sigma = float(ggpo_sigma)
 
-
     # train T5XXL
     train_t5xxl = kwargs.get("train_t5xxl", False)
     if train_t5xxl is not None:
@@ -585,6 +604,7 @@ def create_network(
         dropout=neuron_dropout,
         rank_dropout=rank_dropout,
         module_dropout=module_dropout,
+        lora_dropout=lora_dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         train_blocks=train_blocks,
@@ -696,6 +716,7 @@ class LoRANetwork(torch.nn.Module):
         dropout: Optional[float] = None,
         rank_dropout: Optional[float] = None,
         module_dropout: Optional[float] = None,
+        lora_dropout: Optional[float] = None,
         conv_lora_dim: Optional[int] = None,
         conv_alpha: Optional[float] = None,
         module_class: Type[object] = LoRAModule,
@@ -722,6 +743,7 @@ class LoRANetwork(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.lora_dropout = lora_dropout
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
@@ -757,7 +779,6 @@ class LoRANetwork(torch.nn.Module):
         if self.train_blocks is not None:
             logger.info(f"train {self.train_blocks} blocks only")
 
-
         if train_t5xxl:
             logger.info(f"train T5XXL as well")
 
@@ -876,6 +897,7 @@ class LoRANetwork(torch.nn.Module):
                     dropout=dropout,
                     rank_dropout=rank_dropout,
                    module_dropout=module_dropout,
+                    lora_dropout=lora_dropout,
                    split_dims=split_dims,
                    ggpo_beta=ggpo_beta,
                    ggpo_sigma=ggpo_sigma,
@@ -895,7 +917,7 @@ class LoRANetwork(torch.nn.Module):
                if not train_t5xxl and index > 0:  # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
                    break
 
-                logger.info(f"create LoRA for Text Encoder {index+1}:")
+                logger.info(f"create LoRA for Text Encoder {index + 1}:")
 
                text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
@@ -976,7 +998,6 @@ class LoRANetwork(torch.nn.Module):
                combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
        return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
 
-
    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file
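
Note (not part of the diff): a minimal sketch of how the two new helpers in library/network_utils.py behave, useful as a quick shape sanity check. The rank, feature sizes, and batch size below are made up for illustration; only the function names, argument order, and the 0.5 default come from the patch. Also note that, unlike torch.nn.functional.dropout, neither helper rescales the surviving entries by 1/(1 - p).

    import torch
    from library.network_utils import lora_dropout_down, lora_dropout_up

    rank, in_dim, out_dim, batch = 4, 16, 32, 2          # hypothetical sizes
    down = torch.randn(rank, in_dim)                     # like lora_down.weight: (rank, in_dim)
    up = torch.randn(out_dim, rank)                      # like lora_up.weight: (out_dim, rank)
    x = torch.randn(batch, in_dim)

    # Bernoulli mask over the input columns of the down-projection
    lx = lora_dropout_down(down, x, dropout_prob=0.5)    # -> (batch, rank)
    # Bernoulli mask over the output rows of the up-projection
    out = lora_dropout_up(up, lx, dropout_prob=0.5)      # -> (batch, out_dim)
    print(lx.shape, out.shape)                           # torch.Size([2, 4]) torch.Size([2, 32])

Assuming the usual sd-scripts convention that create_network kwargs are passed through --network_args, the new option would be enabled with something like --network_args "lora_dropout=0.1"; treat that spelling as an assumption rather than a documented interface.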