Improve performance. Add curve for AID probabilities

This commit is contained in:
rockerBOO
2025-04-14 01:56:37 -04:00
parent 956275f295
commit 584ea4ee34
2 changed files with 31 additions and 37 deletions

View File

@@ -1,3 +1,4 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F
@@ -9,27 +10,25 @@ class AID(nn.Module):
def forward(self, x: Tensor):
    """Activation by Interval-wise Dropout (AID).

    Training: keep positive inputs with probability ``self.p`` and
    negative inputs with probability ``1 - self.p`` via independent
    dropout masks, so the expected activation equals the deterministic
    test-time output.

    Test: the expectation of the training behavior — a leaky-ReLU-like
    map ``p * x`` for ``x > 0`` and ``(1 - p) * x`` otherwise.
    """
    if self.training:
        zeros = torch.zeros_like(x)
        pos_mask = x > 0
        # Split x into its positive and non-positive parts.
        pos_vals = torch.where(pos_mask, x, zeros)
        neg_vals = torch.where(pos_mask, zeros, x)
        # F.dropout performs *inverted* dropout: kept values are divided
        # by the keep probability. Multiply by the keep probability to
        # undo that scaling, so kept values pass through unchanged and
        # E[output] matches the test-time expression below. (Dividing by
        # the keep probability here, as before, compounded F.dropout's
        # own scaling and inflated kept values by 1 / keep_prob**2.)
        pos_part = F.dropout(pos_vals, p=1 - self.p, training=True) * self.p
        neg_part = F.dropout(neg_vals, p=self.p, training=True) * (1 - self.p)
        return pos_part + neg_part
    # Deterministic expectation. NOTE: the negative branch must be
    # (1 - p) * x, not (1 - p) * (-x) — the latter flips the sign of
    # negative activations and disagrees with both the training path
    # and the sibling AID_GELU implementation.
    return torch.where(x > 0, self.p * x, (1 - self.p) * x)
class AID_GELU(nn.Module):
@@ -40,32 +39,25 @@ class AID_GELU(nn.Module):
def forward(self, x):
    """AID (Activation by Interval-wise Dropout) applied after GELU.

    Training: keep positive GELU outputs with probability ``self.p`` and
    non-positive ones with probability ``1 - self.p`` via independent
    dropout masks, so the expected activation equals the deterministic
    test-time output.

    Test: the expectation — ``p * g`` for ``g > 0`` and ``(1 - p) * g``
    otherwise, where ``g = gelu(x)``.
    """
    # Apply GELU first; AID operates on the activated values.
    x = self.gelu(x)
    if self.training:
        zeros = torch.zeros_like(x)
        pos_mask = x > 0
        # Split the activation into its positive and non-positive parts.
        pos_vals = torch.where(pos_mask, x, zeros)
        neg_vals = torch.where(pos_mask, zeros, x)
        # F.dropout performs *inverted* dropout: kept values are divided
        # by the keep probability. Multiply by the keep probability to
        # undo that scaling, so kept values pass through unchanged and
        # E[output] matches the test-time expression below. (Dividing by
        # the keep probability here, as before, compounded F.dropout's
        # own scaling and inflated kept values by 1 / keep_prob**2.)
        pos_part = F.dropout(pos_vals, p=1 - self.p, training=True) * self.p
        neg_part = F.dropout(neg_vals, p=self.p, training=True) * (1 - self.p)
        return pos_part + neg_part
    # Deterministic expectation of the training behavior.
    return torch.where(x > 0, self.p * x, (1 - self.p) * x)

View File

@@ -30,18 +30,19 @@ logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
    """Return a per-block probability taken from a shifted sine curve.

    The block position is normalized to [0, 1] and mapped onto the first
    ``shift`` fraction of a half sine cycle, so the curve starts at 0 for
    block 0, reaches ``peak`` at position ``1 / (2 * shift)`` through the
    stack, and falls back to ``peak * sin(shift * pi)`` at the last block
    (about 0.7 * peak with the default ``shift``).
    """
    position = block_id / total_blocks
    angle = position * shift * math.pi
    return peak * math.sin(angle)
@@ -123,7 +124,9 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.aid = AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() # AID activation
self.aid = (
AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity()
) # AID activation
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -173,7 +176,7 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
# Activation by Interval-wise Dropout
# Activation by Interval-wise Dropout
lx = self.aid(lx)
# LoRA Gradient-Guided Perturbation Optimization
@@ -863,11 +866,10 @@ class LoRANetwork(torch.nn.Module):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if block_index is not None:
if block_index is not None and aid_dropout is not None:
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
aid_dropout_p = get_point_on_curve(all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS)
aid_dropout_p = get_point_on_curve(
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
if (
is_flux