Simplify AID implementation: apply AID to the combined LoRA output, AID(B*A), instead of applying AID to B's output and then multiplying by A, AID(B)*A.

This commit is contained in:
rockerBOO
2025-04-14 15:50:54 -04:00
parent 584ea4ee34
commit 9bc392001c
2 changed files with 66 additions and 50 deletions

View File

@@ -4,60 +4,76 @@ import torch.nn.functional as F
class AID(nn.Module):
def __init__(self, dropout_prob=0.9):
def __init__(self, p=0.9):
super().__init__()
self.p = dropout_prob
self.p = p
def forward(self, x: Tensor):
if self.training:
# Use boolean masks and torch.where for better efficiency
pos_mask = x > 0
# Process positive values (keep with probability p)
pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
if self.p > 0:
pos_dropped = pos_dropped / self.p
# Process negative values (keep with probability 1-p)
neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
if (1 - self.p) > 0:
neg_dropped = neg_dropped / (1 - self.p)
return pos_dropped + neg_dropped
pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p)
neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
return x * (pos_mask + neg_mask)
else:
# Simplified test-time behavior
return torch.where(x > 0, self.p * x, (1 - self.p) * (-x))
pos_part = (x >= 0).float() * x * self.p
neg_part = (x < 0).float() * x * (1 - self.p)
return pos_part + neg_part
class AID_GELU(nn.Module):
def __init__(self, dropout_prob=0.9, approximate="none"):
super().__init__()
self.p = dropout_prob
self.gelu = nn.GELU(approximate=approximate)
# class AID(nn.Module):
# def __init__(self, dropout_prob=0.9):
# super().__init__()
# self.p = dropout_prob
#
# def forward(self, x: Tensor):
# if self.training:
# # Use boolean masks and torch.where for better efficiency
# pos_mask = x > 0
#
# # Process positive values (keep with probability p)
# pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
# pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
# if self.p > 0:
# pos_dropped = pos_dropped / self.p
#
# # Process negative values (keep with probability 1-p)
# neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
# neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
# if (1 - self.p) > 0:
# neg_dropped = neg_dropped / (1 - self.p)
#
# return pos_dropped + neg_dropped
# else:
# # Simplified test-time behavior
# return torch.where(x > 0, self.p * x, (1 - self.p) * (-x))
def forward(self, x):
# Apply GELU first
x = self.gelu(x)
if self.training:
# Create masks once and reuse
pos_mask = x > 0
# Process positive values (keep with probability p)
pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
if self.p > 0:
pos_dropped = pos_dropped / self.p
# Process negative values (keep with probability 1-p)
neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
if (1 - self.p) > 0:
neg_dropped = neg_dropped / (1 - self.p)
return pos_dropped + neg_dropped
else:
# Test time behavior - simplify with direct where operations
return torch.where(x > 0, self.p * x, (1 - self.p) * x)
# class AID_GELU(nn.Module):
# def __init__(self, p=0.9, approximate="none"):
# super().__init__()
# self.p = p
# self.gelu = nn.GELU(approximate=approximate)
#
# def forward(self, x):
# # Apply GELU first
# x = self.gelu(x)
#
# if self.training:
# # Create masks once and reuse
# pos_mask = x > 0
#
# # Process positive values (keep with probability p)
# pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
# pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
# if self.p > 0:
# pos_dropped = pos_dropped / self.p
#
# # Process negative values (keep with probability 1-p)
# neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
# neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
# if (1 - self.p) > 0:
# neg_dropped = neg_dropped / (1 - self.p)
#
# return pos_dropped + neg_dropped
# else:
# # Test time behavior - simplify with direct where operations
# return torch.where(x > 0, self.p * x, (1 - self.p) * x)

View File

@@ -17,7 +17,7 @@ import numpy as np
import torch
from torch import Tensor
import re
from library.model_utils import AID_GELU
from library.model_utils import AID
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -125,7 +125,7 @@ class LoRAModule(torch.nn.Module):
self.module_dropout = module_dropout
self.aid = (
AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity()
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
) # AID activation
self.ggpo_sigma = ggpo_sigma