mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Improve performance. Add curve for AID probabilities
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -9,27 +10,25 @@ class AID(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor):
    """Apply Activation-by-Interval-wise Dropout (AID) to ``x``.

    Training: the input is split by sign; positive elements are kept with
    probability ``self.p`` and negative elements with probability
    ``1 - self.p`` via two ``F.dropout`` calls, then each surviving half is
    divided by its keep probability before the halves are recombined.
    Evaluation: a deterministic leaky-ReLU-like transform is returned.

    Args:
        x: Input tensor (any shape; operations are elementwise).

    Returns:
        Tensor of the same shape as ``x``.
    """
    if self.training:
        # Boolean mask built once and reused for both halves (the
        # negative half is selected with `~pos_mask`).
        pos_mask = x > 0

        # Positive values: F.dropout zeroes with probability 1 - p, i.e.
        # positives survive with probability p.
        pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
        pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
        if self.p > 0:
            # NOTE(review): F.dropout already rescales survivors by
            # 1 / keep_prob; this extra division compounds that scaling.
            # Kept as written -- confirm intended expectation against the
            # eval branch below.
            pos_dropped = pos_dropped / self.p

        # Negative values: zeroed with probability p, i.e. they survive
        # with probability 1 - p.
        neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
        neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
        if (1 - self.p) > 0:
            neg_dropped = neg_dropped / (1 - self.p)

        # Recombine: the two halves have disjoint support, so addition
        # merges them without overlap.
        return pos_dropped + neg_dropped
    else:
        # Simplified deterministic test-time behavior.
        # NOTE(review): `(1 - self.p) * (-x)` flips the sign of the
        # negative branch relative to the previous
        # `(1 - p) * F.relu(-x) * -1` form (which evaluated to
        # (1 - p) * x) -- confirm the negation is intentional.
        return torch.where(x > 0, self.p * x, (1 - self.p) * (-x))
|
||||
|
||||
|
||||
class AID_GELU(nn.Module):
|
||||
@@ -40,32 +39,25 @@ class AID_GELU(nn.Module):
|
||||
|
||||
def forward(self, x):
    """Apply GELU followed by AID-style sign-dependent dropout.

    The input first passes through ``self.gelu``. Training then keeps
    positive activations with probability ``self.p`` and non-positive
    ones with probability ``1 - self.p`` (each half rescaled by its keep
    probability); evaluation scales the two sign ranges deterministically.

    Args:
        x: Input tensor (any shape; operations are elementwise).

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Apply GELU first; `x` is reused to avoid holding a second tensor.
    x = self.gelu(x)

    if self.training:
        # Create the mask once and reuse it for both halves.
        pos_mask = x > 0

        # Positive activations: survive F.dropout with probability p.
        pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
        pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
        if self.p > 0:
            # NOTE(review): F.dropout already rescales survivors by
            # 1 / keep_prob; this division compounds that scaling.
            # Kept as written -- confirm intended expectation.
            pos_dropped = pos_dropped / self.p

        # Non-positive activations: survive with probability 1 - p.
        neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
        neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
        if (1 - self.p) > 0:
            neg_dropped = neg_dropped / (1 - self.p)

        # Halves have disjoint support, so addition recombines them.
        return pos_dropped + neg_dropped
    else:
        # Test-time behavior - simplify with a direct where operation:
        # scale positives by p and non-positives by 1 - p.
        return torch.where(x > 0, self.p * x, (1 - self.p) * x)
|
||||
|
||||
@@ -30,18 +30,19 @@ logger = logging.getLogger(__name__)
|
||||
NUM_DOUBLE_BLOCKS = 19
|
||||
NUM_SINGLE_BLOCKS = 38
|
||||
|
||||
|
||||
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
    """Map a block index to a value on a truncated sine curve.

    The curve starts at 0 for ``block_id == 0``, rises to ``peak``
    (reached where ``block_id / total_blocks == 0.5 / shift``), and ends
    at ``peak * sin(shift * pi)`` for ``block_id == total_blocks`` -- with
    the default ``shift=0.75`` that is roughly ``0.707 * peak``.

    Args:
        block_id: Index of the block along the curve.
        total_blocks: Number of blocks used to normalize the position.
        peak: Maximum value the curve can reach.
        shift: Fraction of the half-cycle (of pi radians) actually used;
            values below 1.0 keep the curve from returning to 0 at the end.

    Returns:
        The curve value for ``block_id`` as a float.
    """
    # Normalize the position to the 0-1 range.
    normalized_pos = block_id / total_blocks

    # Use only the first `shift` portion of the sine half-cycle so the
    # curve starts at 0, peaks past the middle, and ends above 0.
    phase_shift = shift * math.pi
    sine_value = math.sin(normalized_pos * phase_shift)

    # Scale to the desired peak.
    return peak * sine_value
|
||||
|
||||
|
||||
@@ -123,7 +124,9 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.aid = AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity() # AID activation
|
||||
self.aid = (
|
||||
AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity()
|
||||
) # AID activation
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
@@ -173,7 +176,7 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# Activation by Interval-wise Dropout
|
||||
# Activation by Interval-wise Dropout
|
||||
lx = self.aid(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
@@ -863,11 +866,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||
|
||||
if block_index is not None:
|
||||
|
||||
if block_index is not None and aid_dropout is not None:
|
||||
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
|
||||
aid_dropout_p = get_point_on_curve(all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS)
|
||||
|
||||
aid_dropout_p = get_point_on_curve(
|
||||
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
|
||||
|
||||
if (
|
||||
is_flux
|
||||
|
||||
Reference in New Issue
Block a user