From 584ea4ee341d4a7047f3b7b147c7cfb6a99441d9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 14 Apr 2025 01:56:37 -0400 Subject: [PATCH] Improve performance. Add curve for AID probabilities --- library/model_utils.py | 48 ++++++++++++++++++------------------------ networks/lora_flux.py | 20 ++++++++++-------- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index fb3fc794..03cd73aa 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -1,3 +1,4 @@ +import torch from torch import nn, Tensor import torch.nn.functional as F @@ -9,27 +10,25 @@ class AID(nn.Module): def forward(self, x: Tensor): if self.training: - # Create separate tensors for positive and negative components - pos_mask = (x > 0).float() - neg_mask = (x <= 0).float() + # Use boolean masks and torch.where for better efficiency + pos_mask = x > 0 - pos_vals = x * pos_mask - neg_vals = x * neg_mask - - # Apply dropout directly with PyTorch's F.dropout + # Process positive values (keep with probability p) + pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) if self.p > 0: - pos_dropped = pos_dropped / self.p # Scale to maintain expectation + pos_dropped = pos_dropped / self.p + # Process negative values (keep with probability 1-p) + neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) neg_dropped = F.dropout(neg_vals, p=self.p, training=True) if (1 - self.p) > 0: - neg_dropped = neg_dropped / (1 - self.p) # Scale to maintain expectation + neg_dropped = neg_dropped / (1 - self.p) - # Combine results return pos_dropped + neg_dropped else: - # During testing, use modified leaky ReLU with coefficient p - return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1 + # Simplified test-time behavior + return torch.where(x > 0, self.p * x, (1 - self.p) * x) class AID_GELU(nn.Module): @@ -40,32 +39,25 @@ class AID_GELU(nn.Module): def forward(self, x): # Apply 
GELU first - gelu_output = self.gelu(x) + x = self.gelu(x) if self.training: - # Separate positive and negative components using masks - pos_mask = (gelu_output > 0).float() - neg_mask = (gelu_output <= 0).float() + # Create masks once and reuse + pos_mask = x > 0 - pos_vals = gelu_output * pos_mask - neg_vals = gelu_output * neg_mask - - # Apply dropout with different probabilities + # Process positive values (keep with probability p) + pos_vals = torch.where(pos_mask, x, torch.zeros_like(x)) pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True) if self.p > 0: pos_dropped = pos_dropped / self.p + # Process negative values (keep with probability 1-p) + neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x)) neg_dropped = F.dropout(neg_vals, p=self.p, training=True) if (1 - self.p) > 0: neg_dropped = neg_dropped / (1 - self.p) return pos_dropped + neg_dropped else: - # Test time behavior - pos_mask = (gelu_output > 0).float() - neg_mask = (gelu_output <= 0).float() - - pos_vals = gelu_output * pos_mask - neg_vals = gelu_output * neg_mask - - return self.p * pos_vals + (1 - self.p) * neg_vals + # Test time behavior - simplify with direct where operations + return torch.where(x > 0, self.p * x, (1 - self.p) * x) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 620ee14e..ff553bab 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -30,18 +30,19 @@ logger = logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 + def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75): # Normalize the position to 0-1 range normalized_pos = block_id / total_blocks - + # Shift the sine curve to only use the first 3/4 of the cycle # This gives us: start at 0, peak in the middle, end around 0.7 phase_shift = shift * math.pi sine_value = math.sin(normalized_pos * phase_shift) - + # Scale to our desired peak of 0.9 result = peak * sine_value - + return result @@ -123,7 +124,9 @@ class LoRAModule(torch.nn.Module): 
self.rank_dropout = rank_dropout self.module_dropout = module_dropout - self.aid = AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() # AID activation + self.aid = ( + AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() + ) # AID activation self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta @@ -173,7 +176,7 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) - # Activation by Interval-wise Dropout + # Activation by Interval-wise Dropout lx = self.aid(lx) # LoRA Gradient-Guided Perturbation Optimization @@ -863,11 +866,10 @@ class LoRANetwork(torch.nn.Module): # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." block_index = int(lora_name.split("_")[4]) # bit dirty - if block_index is not None: - + if block_index is not None and aid_dropout is not None: all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS - aid_dropout_p = get_point_on_curve(all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS) - + aid_dropout_p = get_point_on_curve( + all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout) if ( is_flux