From dfae3a486ca4d5ed28adcbba07645a0f3831290a Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 13 Apr 2025 13:57:52 -0400 Subject: [PATCH 1/9] Add AID activation interval-wise dropout --- library/model_utils.py | 26 ++++++++++++++++++++++++++ networks/lora_flux.py | 14 ++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 library/model_utils.py diff --git a/library/model_utils.py b/library/model_utils.py new file mode 100644 index 00000000..9578b310 --- /dev/null +++ b/library/model_utils.py @@ -0,0 +1,26 @@ +from torch import nn +import torch.nn.functional as F +import torch + + +class AID(nn.Module): + def __init__(self, dropout_prob=0.9): + super(AID, self).__init__() + self.p = dropout_prob + self.training = True + + def forward(self, x): + if self.training: + # Generate masks for positive and negative values + pos_mask = torch.bernoulli(torch.full_like(x, self.p)) + neg_mask = torch.bernoulli(torch.full_like(x, 1 - self.p)) + + # Apply masks to positive and negative parts + pos_part = F.relu(x) * pos_mask + neg_part = F.relu(-x) * neg_mask * -1 + + return pos_part + neg_part + else: + # During testing, use modified leaky ReLU with coefficient p + return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1 + diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 92b3979a..baa33134 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -17,6 +17,7 @@ import numpy as np import torch from torch import Tensor import re +from library.model_utils import AID from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -45,6 +46,7 @@ class LoRAModule(torch.nn.Module): dropout=None, rank_dropout=None, module_dropout=None, + aid_dropout=None, split_dims: Optional[List[int]] = None, ggpo_beta: Optional[float] = None, ggpo_sigma: Optional[float] = None, @@ -107,6 +109,8 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.aid = AID(dropout_prob=aid_dropout) # AID activation + self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta @@ -155,6 +159,9 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) + if self.aid_dropout is not None and self.training: + lx = self.aid(lx) + # LoRA Gradient-Guided Perturbation Optimization if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: with torch.no_grad(): @@ -544,6 +551,9 @@ def create_network( module_dropout = kwargs.get("module_dropout", None) if module_dropout is not None: module_dropout = float(module_dropout) + aid_dropout = kwargs.get("aid_dropout", None) + if aid_dropout is not None: + aid_dropout = float(aid_dropout) # single or double blocks train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" @@ -585,6 +595,7 @@ def create_network( dropout=neuron_dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + aid_dropout=aid_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, train_blocks=train_blocks, @@ -696,6 +707,7 @@ class LoRANetwork(torch.nn.Module): dropout: Optional[float] = None, rank_dropout: Optional[float] = None, module_dropout: Optional[float] = None, + aid_dropout: Optional[float] = None, conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, module_class: Type[object] = LoRAModule, @@ -722,6 +734,7 @@ class LoRANetwork(torch.nn.Module): self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.aid_dropout = aid_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" self.split_qkv = split_qkv self.train_t5xxl = train_t5xxl @@ -876,6 +889,7 @@ class LoRANetwork(torch.nn.Module): dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + aid_dropout=aid_dropout, split_dims=split_dims, ggpo_beta=ggpo_beta, ggpo_sigma=ggpo_sigma, From 61af45ef3c8044ce221a1397ba4c89dbe8754249 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 13 Apr 2025 21:30:35 -0400 Subject: [PATCH 2/9] Add AID_GELU. Add dropout curve for AID --- library/model_utils.py | 75 +++++++++++++++++++++++++++++++++--------- networks/lora_flux.py | 43 ++++++++++++++++++++---- 2 files changed, 96 insertions(+), 22 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index 9578b310..fb3fc794 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -1,26 +1,71 @@ -from torch import nn +from torch import nn, Tensor import torch.nn.functional as F -import torch class AID(nn.Module): def __init__(self, dropout_prob=0.9): - super(AID, self).__init__() + super().__init__() self.p = dropout_prob - self.training = True - - def forward(self, x): + + def forward(self, x: Tensor): if self.training: - # Generate masks for positive and negative values - pos_mask = torch.bernoulli(torch.full_like(x, self.p)) - neg_mask = torch.bernoulli(torch.full_like(x, 1 - self.p)) - - # Apply masks to positive and negative parts - pos_part = F.relu(x) * pos_mask - neg_part = F.relu(-x) * neg_mask * -1 - - return pos_part + neg_part + # Create separate tensors for positive and negative components + pos_mask = (x > 0).float() + neg_mask = (x <= 0).float() + + pos_vals = x * pos_mask + neg_vals = x * neg_mask + + # Apply dropout directly with PyTorch's F.dropout + pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) + if self.p > 0: + pos_dropped = pos_dropped / self.p # Scale to maintain expectation + + neg_dropped = F.dropout(neg_vals, p=self.p, training=True) + if (1 - self.p) > 0: + neg_dropped = neg_dropped / (1 - self.p) # Scale to maintain expectation + + # Combine results + return pos_dropped + neg_dropped else: # During testing, use modified leaky ReLU with coefficient p return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1 + +class AID_GELU(nn.Module): + def __init__(self, dropout_prob=0.9, approximate="none"): + super().__init__() + self.p = dropout_prob + self.gelu = nn.GELU(approximate=approximate) + + def forward(self, x): + # Apply GELU first + gelu_output = self.gelu(x) + + if self.training: + # Separate positive and negative components using masks + pos_mask = (gelu_output > 0).float() + neg_mask = (gelu_output <= 0).float() + + pos_vals = gelu_output * pos_mask + neg_vals = gelu_output * neg_mask + + # Apply dropout with different probabilities + pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) + if self.p > 0: + pos_dropped = pos_dropped / self.p + + neg_dropped = F.dropout(neg_vals, p=self.p, training=True) + if (1 - self.p) > 0: + neg_dropped = neg_dropped / (1 - self.p) + + return pos_dropped + neg_dropped + else: + # Test time behavior + pos_mask = (gelu_output > 0).float() + neg_mask = (gelu_output <= 0).float() + + pos_vals = gelu_output * pos_mask + neg_vals = gelu_output * neg_mask + + return self.p * pos_vals + (1 - self.p) * neg_vals diff --git a/networks/lora_flux.py b/networks/lora_flux.py index baa33134..620ee14e 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -17,7 +17,7 @@ import numpy as np import torch from torch import Tensor import re -from library.model_utils import AID +from library.model_utils import AID_GELU from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -30,6 +30,20 @@ logger = logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 +def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75): + # Normalize the position to 0-1 range + normalized_pos = block_id / total_blocks + + # Shift the sine curve to only use the first 3/4 of the cycle + # This gives us: start at 0, peak in the middle, end around 0.7 + phase_shift = shift * math.pi + sine_value = math.sin(normalized_pos * phase_shift) + + # Scale to our desired peak of 0.9 + result = peak * sine_value + + return result + class LoRAModule(torch.nn.Module): """ @@ -109,7 +123,7 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout - self.aid = AID(dropout_prob=aid_dropout) # AID activation + self.aid = AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() # AID activation self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta @@ -159,8 +173,8 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) - if self.aid_dropout is not None and self.training: - lx = self.aid(lx) + # Activation by Interval-wise Dropout + lx = self.aid(lx) # LoRA Gradient-Guided Perturbation Optimization if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: @@ -810,6 +824,7 @@ class LoRANetwork(torch.nn.Module): dim = None alpha = None + aid_dropout_p = None if modules_dim is not None: # モジュール指定あり @@ -837,6 +852,22 @@ class LoRANetwork(torch.nn.Module): if d is not None and all([id in lora_name for id in identifier[i]]): dim = d # may be 0 for skip break + is_double = False + if "double" in lora_name: + is_double = True + is_single = False + if "single" in lora_name: + is_single = True + block_index = None + if is_flux and dim and (is_double or is_single): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + + if block_index is not None: + + all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS + aid_dropout_p = get_point_on_curve(all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS) + if ( is_flux @@ -847,8 +878,6 @@ class LoRANetwork(torch.nn.Module): ) and ("double" in lora_name or "single" in lora_name) ): - # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." - block_index = int(lora_name.split("_")[4]) # bit dirty if ( "double" in lora_name and self.train_double_block_indices is not None @@ -889,7 +918,7 @@ class LoRANetwork(torch.nn.Module): dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, - aid_dropout=aid_dropout, + aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout, split_dims=split_dims, ggpo_beta=ggpo_beta, ggpo_sigma=ggpo_sigma, From 956275f295655f57e3b53c64ab49559847db6559 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 13 Apr 2025 21:31:35 -0400 Subject: [PATCH 3/9] Add pythonpath to pytest.ini --- pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . From 584ea4ee341d4a7047f3b7b147c7cfb6a99441d9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 14 Apr 2025 01:56:37 -0400 Subject: [PATCH 4/9] Improve performance. Add curve for AID probabilities --- library/model_utils.py | 48 ++++++++++++++++++------------------------ networks/lora_flux.py | 20 ++++++++++-------- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index fb3fc794..03cd73aa 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -1,3 +1,4 @@ +import torch from torch import nn, Tensor import torch.nn.functional as F @@ -9,27 +10,25 @@ class AID(nn.Module): def forward(self, x: Tensor): if self.training: - # Create separate tensors for positive and negative components - pos_mask = (x > 0).float() - neg_mask = (x <= 0).float() + # Use boolean masks and torch.where for better efficiency + pos_mask = x > 0 - pos_vals = x * pos_mask - neg_vals = x * neg_mask - - # Apply dropout directly with PyTorch's F.dropout + # Process positive values (keep with probability p) + pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) if self.p > 0: - pos_dropped = pos_dropped / self.p # Scale to maintain expectation + pos_dropped = pos_dropped / self.p + # Process negative values (keep with probability 1-p) + neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) neg_dropped = F.dropout(neg_vals, p=self.p, training=True) if (1 - self.p) > 0: - neg_dropped = neg_dropped / (1 - self.p) # Scale to maintain expectation + neg_dropped = neg_dropped / (1 - self.p) - # Combine results return pos_dropped + neg_dropped else: - # During testing, use modified leaky ReLU with coefficient p - return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1 + # Simplified test-time behavior + return torch.where(x > 0, self.p * x, (1 - self.p) * (-x)) class AID_GELU(nn.Module): @@ -40,32 +39,25 @@ class AID_GELU(nn.Module): def forward(self, x): # Apply GELU first - gelu_output = self.gelu(x) + x = self.gelu(x) if self.training: - # Separate positive and negative components using masks - pos_mask = (gelu_output > 0).float() - neg_mask = (gelu_output <= 0).float() + # Create masks once and reuse + pos_mask = x > 0 - pos_vals = gelu_output * pos_mask - neg_vals = gelu_output * neg_mask - - # Apply dropout with different probabilities + # Process positive values (keep with probability p) + pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) if self.p > 0: pos_dropped = pos_dropped / self.p + # Process negative values (keep with probability 1-p) + neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) neg_dropped = F.dropout(neg_vals, p=self.p, training=True) if (1 - self.p) > 0: neg_dropped = neg_dropped / (1 - self.p) return pos_dropped + neg_dropped else: - # Test time behavior - pos_mask = (gelu_output > 0).float() - neg_mask = (gelu_output <= 0).float() - - pos_vals = gelu_output * pos_mask - neg_vals = gelu_output * neg_mask - - return self.p * pos_vals + (1 - self.p) * neg_vals + # Test time behavior - simplify with direct where operations + return torch.where(x > 0, self.p * x, (1 - self.p) * x) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 620ee14e..ff553bab 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -30,18 +30,19 @@ logger = logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 + def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75): # Normalize the position to 0-1 range normalized_pos = block_id / total_blocks - + # Shift the sine curve to only use the first 3/4 of the cycle # This gives us: start at 0, peak in the middle, end around 0.7 phase_shift = shift * math.pi sine_value = math.sin(normalized_pos * phase_shift) - + # Scale to our desired peak of 0.9 result = peak * sine_value - + return result @@ -123,7 +124,9 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout - self.aid = AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() # AID activation + self.aid = ( + AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() + ) # AID activation self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta @@ -173,7 +176,7 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) - # Activation by Interval-wise Dropout + # Activation by Interval-wise Dropout lx = self.aid(lx) # LoRA Gradient-Guided Perturbation Optimization @@ -863,11 +866,10 @@ class LoRANetwork(torch.nn.Module): # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." block_index = int(lora_name.split("_")[4]) # bit dirty - if block_index is not None: - + if block_index is not None and aid_dropout is not None: all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS - aid_dropout_p = get_point_on_curve(all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS) - + aid_dropout_p = get_point_on_curve( + all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout) if ( is_flux From 9bc392001c44310537d7d487d21a0c3d3a39cb6c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 14 Apr 2025 15:50:54 -0400 Subject: [PATCH 5/9] Simplify AID implementation. AID(B*A) instead of AID(B)*A. --- library/model_utils.py | 112 +++++++++++++++++++++++------------------ networks/lora_flux.py | 4 +- 2 files changed, 66 insertions(+), 50 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index 03cd73aa..381c20c9 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -4,60 +4,76 @@ import torch.nn.functional as F class AID(nn.Module): - def __init__(self, dropout_prob=0.9): + def __init__(self, p=0.9): super().__init__() - self.p = dropout_prob + self.p = p def forward(self, x: Tensor): if self.training: - # Use boolean masks and torch.where for better efficiency - pos_mask = x > 0 - - # Process positive values (keep with probability p) - pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) - pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) - if self.p > 0: - pos_dropped = pos_dropped / self.p - - # Process negative values (keep with probability 1-p) - neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) - neg_dropped = F.dropout(neg_vals, p=self.p, training=True) - if (1 - self.p) > 0: - neg_dropped = neg_dropped / (1 - self.p) - - return pos_dropped + neg_dropped + pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p) + neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) + return x * (pos_mask + neg_mask) else: - # Simplified test-time behavior - return torch.where(x > 0, self.p * x, (1 - self.p) * (-x)) + pos_part = (x >= 0).float() * x * self.p + neg_part = (x < 0).float() * x * (1 - self.p) + return pos_part + neg_part -class AID_GELU(nn.Module): - def __init__(self, dropout_prob=0.9, approximate="none"): - super().__init__() - self.p = dropout_prob - self.gelu = nn.GELU(approximate=approximate) +# class AID(nn.Module): +# def __init__(self, dropout_prob=0.9): +# super().__init__() +# self.p = dropout_prob +# +# def forward(self, x: Tensor): +# if self.training: +# # Use boolean masks and torch.where for better efficiency +# pos_mask = x > 0 +# +# # Process positive values (keep with probability p) +# pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) +# pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) +# if self.p > 0: +# pos_dropped = pos_dropped / self.p +# +# # Process negative values (keep with probability 1-p) +# neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) +# neg_dropped = F.dropout(neg_vals, p=self.p, training=True) +# if (1 - self.p) > 0: +# neg_dropped = neg_dropped / (1 - self.p) +# +# return pos_dropped + neg_dropped +# else: +# # Simplified test-time behavior +# return torch.where(x > 0, self.p * x, (1 - self.p) * (-x)) - def forward(self, x): - # Apply GELU first - x = self.gelu(x) - if self.training: - # Create masks once and reuse - pos_mask = x > 0 - - # Process positive values (keep with probability p) - pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) - pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) - if self.p > 0: - pos_dropped = pos_dropped / self.p - - # Process negative values (keep with probability 1-p) - neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) - neg_dropped = F.dropout(neg_vals, p=self.p, training=True) - if (1 - self.p) > 0: - neg_dropped = neg_dropped / (1 - self.p) - - return pos_dropped + neg_dropped - else: - # Test time behavior - simplify with direct where operations - return torch.where(x > 0, self.p * x, (1 - self.p) * x) +# class AID_GELU(nn.Module): +# def __init__(self, p=0.9, approximate="none"): +# super().__init__() +# self.p = p +# self.gelu = nn.GELU(approximate=approximate) +# +# def forward(self, x): +# # Apply GELU first +# x = self.gelu(x) +# +# if self.training: +# # Create masks once and reuse +# pos_mask = x > 0 +# +# # Process positive values (keep with probability p) +# pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) +# pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) +# if self.p > 0: +# pos_dropped = pos_dropped / self.p +# +# # Process negative values (keep with probability 1-p) +# neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) +# neg_dropped = F.dropout(neg_vals, p=self.p, training=True) +# if (1 - self.p) > 0: +# neg_dropped = neg_dropped / (1 - self.p) +# +# return pos_dropped + neg_dropped +# else: +# # Test time behavior - simplify with direct where operations +# return torch.where(x > 0, self.p * x, (1 - self.p) * x) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ff553bab..158bb136 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -17,7 +17,7 @@ import numpy as np import torch from torch import Tensor import re -from library.model_utils import AID_GELU +from library.model_utils import AID from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -125,7 +125,7 @@ class LoRAModule(torch.nn.Module): self.module_dropout = module_dropout self.aid = ( - AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() + AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity() ) # AID activation self.ggpo_sigma = ggpo_sigma From 4d005cdf3d7b33629ec18d6ff0812a0f2ca05122 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 14 Apr 2025 15:52:25 -0400 Subject: [PATCH 6/9] Remove old code --- library/model_utils.py | 61 ------------------------------------------ 1 file changed, 61 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index 381c20c9..3494e6f4 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -1,6 +1,5 @@ import torch from torch import nn, Tensor -import torch.nn.functional as F class AID(nn.Module): @@ -17,63 +16,3 @@ class AID(nn.Module): pos_part = (x >= 0).float() * x * self.p neg_part = (x < 0).float() * x * (1 - self.p) return pos_part + neg_part - - -# class AID(nn.Module): -# def __init__(self, dropout_prob=0.9): -# super().__init__() -# self.p = dropout_prob -# -# def forward(self, x: Tensor): -# if self.training: -# # Use boolean masks and torch.where for better efficiency -# pos_mask = x > 0 -# -# # Process positive values (keep with probability p) -# pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) -# pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) -# if self.p > 0: -# pos_dropped = pos_dropped / self.p -# -# # Process negative values (keep with probability 1-p) -# neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) -# neg_dropped = F.dropout(neg_vals, p=self.p, training=True) -# if (1 - self.p) > 0: -# neg_dropped = neg_dropped / (1 - self.p) -# -# return pos_dropped + neg_dropped -# else: -# # Simplified test-time behavior -# return torch.where(x > 0, self.p * x, (1 - self.p) * (-x)) - - -# class AID_GELU(nn.Module): -# def __init__(self, p=0.9, approximate="none"): -# super().__init__() -# self.p = p -# self.gelu = nn.GELU(approximate=approximate) -# -# def forward(self, x): -# # Apply GELU first -# x = self.gelu(x) -# -# if self.training: -# # Create masks once and reuse -# pos_mask = x > 0 -# -# # Process positive values (keep with probability p) -# pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) -# pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) -# if self.p > 0: -# pos_dropped = pos_dropped / self.p -# -# # Process negative values (keep with probability 1-p) -# neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) -# neg_dropped = F.dropout(neg_vals, p=self.p, training=True) -# if (1 - self.p) > 0: -# neg_dropped = neg_dropped / (1 - self.p) -# -# return pos_dropped + neg_dropped -# else: -# # Test time behavior - simplify with direct where operations -# return torch.where(x > 0, self.p * x, (1 - self.p) * x) From c2f75f43a41f9d1abefe90e5f458d262be37ef53 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 15 Apr 2025 04:41:59 -0400 Subject: [PATCH 7/9] Do not convert to float --- library/model_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index 3494e6f4..87c2d04a 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -9,10 +9,10 @@ class AID(nn.Module): def forward(self, x: Tensor): if self.training: - pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p) - neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) + pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p) + neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) return x * (pos_mask + neg_mask) else: - pos_part = (x >= 0).float() * x * self.p - neg_part = (x < 0).float() * x * (1 - self.p) + pos_part = (x >= 0) * x * self.p + neg_part = (x < 0) * x * (1 - self.p) return pos_part + neg_part From aefea026a7a795a5ed9daf47e0e5ea727bc8c8f6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 15 Apr 2025 15:38:57 -0400 Subject: [PATCH 8/9] Add AID tests --- tests/library/test_model_utils.py | 137 ++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 tests/library/test_model_utils.py diff --git a/tests/library/test_model_utils.py b/tests/library/test_model_utils.py new file mode 100644 index 00000000..c432266e --- /dev/null +++ b/tests/library/test_model_utils.py @@ -0,0 +1,137 @@ +import pytest +import torch +from library.model_utils import AID +from torch import nn +import torch.nn.functional as F + +@pytest.fixture +def input_tensor(): + # Create a tensor with positive and negative values + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + return x + +def test_aid_forward_train_mode(input_tensor): + aid = AID(p=0.9) + aid.train() + + # Run several forward passes to test stochastic behavior + results = [] + for _ in range(10): + output = aid(input_tensor) + results.append(output.detach().clone()) + + # Test that outputs vary (stochastic behavior) + all_equal = all(torch.allclose(results[0], results[i]) for i in range(1, 10)) + assert not all_equal, "All outputs are identical, expected variability in training mode" + + # Test shape preservation + assert results[0].shape == input_tensor.shape + +def test_aid_forward_eval_mode(input_tensor): + aid = AID(p=0.9) + aid.eval() + + output = aid(input_tensor) + + # Test deterministic behavior + output2 = aid(input_tensor) + assert torch.allclose(output, output2), "Expected deterministic behavior in eval mode" + + # Test correct transformation + expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1 + assert torch.allclose(output, expected), "Incorrect evaluation mode transformation" + +def test_aid_gradient_flow(input_tensor): + aid = AID(p=0.9) + aid.train() + + # Forward pass + output = aid(input_tensor) + + # Check gradient flow + assert output.requires_grad, "Output lost gradient tracking" + + # Compute loss and backpropagate + loss = output.sum() + loss.backward() + + # Verify gradients were computed + assert input_tensor.grad is not None, "No gradients were recorded for input tensor" + assert torch.any(input_tensor.grad != 0), "Gradients are all zeros" + +def test_aid_extreme_p_values(): + # Test with p=1.0 (only positive values pass through) + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + aid = AID(p=1.0) + aid.eval() + + output = aid(x) + expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]) + assert torch.allclose(output, expected), "Failed with p=1.0" + + # Test with p=0.0 (only negative values pass through) + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + aid = AID(p=0.0) + aid.eval() + + output = aid(x) + expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0]) + assert torch.allclose(output, expected), "Failed with p=0.0" + +def test_aid_with_all_positive_values(): + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True) + aid = AID(p=0.9) + aid.train() + + # Run forward passes and check that only positive values are affected + output = aid(x) + + # Backprop should work + loss = output.sum() + loss.backward() + assert x.grad is not None, "No gradients were recorded for all-positive input" + +def test_aid_with_all_negative_values(): + x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True) + aid = AID(p=0.9) + aid.train() + + # Run forward passes and check that only negative values are affected + output = aid(x) + + # Backprop should work + loss = output.sum() + loss.backward() + assert x.grad is not None, "No gradients were recorded for all-negative input" + +def test_aid_with_zero_values(): + x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True) + aid = AID(p=0.9) + + # Test training mode + aid.train() + output = aid(x) + assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input" + + # Test eval mode + aid.eval() + output = aid(x) + assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input" + +def test_aid_integration_with_linear_layer(): + # Test AID's compatibility with a linear layer + linear = nn.Linear(5, 2) + aid = AID(p=0.9) + + model = nn.Sequential(linear, aid) + model.train() + + x = torch.randn(3, 5, requires_grad=True) + output = model(x) + + # Check that gradients flow through the whole model + loss = output.sum() + loss.backward() + + assert linear.weight.grad is not None, "No gradients for linear layer weights" + assert x.grad is not None, "No gradients for input tensor" From ac120e68efd483e26a54f09956387e01a3e6aefa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 18 Apr 2025 02:02:47 -0400 Subject: [PATCH 9/9] Add aid_p buffer to save AID --- networks/lora_flux.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 158bb136..a027c70f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -128,6 +128,9 @@ class LoRAModule(torch.nn.Module): AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity() ) # AID activation + if aid_dropout is not None: + self.register_buffer("aid_p", torch.tensor(aid_dropout)) + self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta @@ -1115,6 +1118,7 @@ class LoRANetwork(torch.nn.Module): up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] alpha = state_dict.pop(f"{lora_name}.alpha") + aid_p = state_dict.pop(f"{lora_name}.aid_p") # merge down weight down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) @@ -1130,6 +1134,7 @@ class LoRANetwork(torch.nn.Module): new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight new_state_dict[f"{lora_name}.alpha"] = alpha + new_state_dict[f"{lora_name}.aid_p"] = aid_p # print( # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"