diff --git a/library/model_utils.py b/library/model_utils.py
new file mode 100644
index 00000000..9578b310
--- /dev/null
+++ b/library/model_utils.py
@@ -0,0 +1,26 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+
+
+class AID(nn.Module):
+    def __init__(self, dropout_prob=0.9):
+        super().__init__()
+        self.p = dropout_prob
+
+    def forward(self, x):
+        if self.training:
+            # Keep positive activations with probability p and negative
+            # activations with probability (1 - p)
+            pos_mask = torch.bernoulli(torch.full_like(x, self.p))
+            neg_mask = torch.bernoulli(torch.full_like(x, 1 - self.p))
+
+            # Recombine the independently masked positive and negative parts
+            pos_part = F.relu(x) * pos_mask
+            neg_part = -F.relu(-x) * neg_mask
+
+            return pos_part + neg_part
+        else:
+            # At inference, use the expected value: a leaky-ReLU-like activation
+            # with slope p on the positive side and (1 - p) on the negative side
+            return self.p * F.relu(x) - (1 - self.p) * F.relu(-x)
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 92b3979a..baa33134 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -17,6 +17,7 @@ import numpy as np
 import torch
 from torch import Tensor
 import re
+from library.model_utils import AID
 from library.utils import setup_logging
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 
@@ -45,6 +46,7 @@ class LoRAModule(torch.nn.Module):
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        aid_dropout=None,
         split_dims: Optional[List[int]] = None,
         ggpo_beta: Optional[float] = None,
         ggpo_sigma: Optional[float] = None,
@@ -107,6 +109,8 @@ class LoRAModule(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
+        self.aid = AID(dropout_prob=aid_dropout) if aid_dropout is not None else None  # AID activation
+
         self.ggpo_sigma = ggpo_sigma
         self.ggpo_beta = ggpo_beta
 
@@ -155,6 +159,9 @@
 
         lx = self.lora_up(lx)
 
+        if self.aid is not None:  # AID handles train/eval modes internally
+            lx = self.aid(lx)
+
         # LoRA Gradient-Guided Perturbation Optimization
         if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
             with torch.no_grad():
@@ -544,6 +551,9 @@
     module_dropout = kwargs.get("module_dropout", None)
     if module_dropout is not None:
         module_dropout = float(module_dropout)
+    aid_dropout = kwargs.get("aid_dropout", None)
+    if aid_dropout is not None:
+        aid_dropout = float(aid_dropout)
 
     # single or double blocks
     train_blocks = kwargs.get("train_blocks", None)  # None (default), "all" (same as None), "single", "double"
@@ -585,6 +595,7 @@
         dropout=neuron_dropout,
         rank_dropout=rank_dropout,
         module_dropout=module_dropout,
+        aid_dropout=aid_dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         train_blocks=train_blocks,
@@ -696,6 +707,7 @@ class LoRANetwork(torch.nn.Module):
         dropout: Optional[float] = None,
         rank_dropout: Optional[float] = None,
         module_dropout: Optional[float] = None,
+        aid_dropout: Optional[float] = None,
         conv_lora_dim: Optional[int] = None,
         conv_alpha: Optional[float] = None,
         module_class: Type[object] = LoRAModule,
@@ -722,6 +734,7 @@
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.aid_dropout = aid_dropout
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
@@ -876,6 +889,7 @@
                     dropout=dropout,
                     rank_dropout=rank_dropout,
                     module_dropout=module_dropout,
+                    aid_dropout=aid_dropout,
                     split_dims=split_dims,
                     ggpo_beta=ggpo_beta,
                     ggpo_sigma=ggpo_sigma,
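
As a quick sanity check of the new module (a sketch, not part of the patch; it only assumes the `AID` class added in `library/model_utils.py` above):

```python
# Verify AID's train/eval behavior: training applies asymmetric Bernoulli
# masking, eval applies its elementwise expectation deterministically.
import torch
from library.model_utils import AID

aid = AID(dropout_prob=0.9)
x = torch.randn(4, 8)

aid.train()
y_train = aid(x)  # positives survive with prob 0.9, negatives with prob 0.1

aid.eval()
y_eval = aid(x)  # deterministic: p * relu(x) - (1 - p) * relu(-x)

# Eval output should equal the expectation of the training-time masking
expected = 0.9 * torch.relu(x) - 0.1 * torch.relu(-x)
assert torch.allclose(y_eval, expected)
```

Since `create_network` reads `aid_dropout` from `kwargs` like the other dropout options, it should be reachable from the training command via `--network_args` (e.g. `--network_args "aid_dropout=0.9"`), assuming the usual sd-scripts key=value plumbing.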