diff --git a/library/model_utils.py b/library/model_utils.py
index 03cd73aa..381c20c9 100644
--- a/library/model_utils.py
+++ b/library/model_utils.py
@@ -4,60 +4,76 @@ import torch.nn.functional as F
 
 
 class AID(nn.Module):
-    def __init__(self, dropout_prob=0.9):
+    def __init__(self, p=0.9):
         super().__init__()
-        self.p = dropout_prob
+        self.p = p
 
     def forward(self, x: Tensor):
         if self.training:
-            # Use boolean masks and torch.where for better efficiency
-            pos_mask = x > 0
-
-            # Process positive values (keep with probability p)
-            pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
-            pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
-            if self.p > 0:
-                pos_dropped = pos_dropped / self.p
-
-            # Process negative values (keep with probability 1-p)
-            neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
-            neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
-            if (1 - self.p) > 0:
-                neg_dropped = neg_dropped / (1 - self.p)
-
-            return pos_dropped + neg_dropped
+            pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p)
+            neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
+            return x * (pos_mask + neg_mask)
         else:
-            # Simplified test-time behavior
-            return torch.where(x > 0, self.p * x, (1 - self.p) * (-x))
+            pos_part = (x >= 0).float() * x * self.p
+            neg_part = (x < 0).float() * x * (1 - self.p)
+            return pos_part + neg_part
 
 
-class AID_GELU(nn.Module):
-    def __init__(self, dropout_prob=0.9, approximate="none"):
-        super().__init__()
-        self.p = dropout_prob
-        self.gelu = nn.GELU(approximate=approximate)
+# class AID(nn.Module):
+#     def __init__(self, dropout_prob=0.9):
+#         super().__init__()
+#         self.p = dropout_prob
+#
+#     def forward(self, x: Tensor):
+#         if self.training:
+#             # Use boolean masks and torch.where for better efficiency
+#             pos_mask = x > 0
+#
+#             # Process positive values (keep with probability p)
+#             pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
+#             pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
+#             if self.p > 0:
+#                 pos_dropped = pos_dropped / self.p
+#
+#             # Process negative values (keep with probability 1-p)
+#             neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
+#             neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
+#             if (1 - self.p) > 0:
+#                 neg_dropped = neg_dropped / (1 - self.p)
+#
+#             return pos_dropped + neg_dropped
+#         else:
+#             # Simplified test-time behavior
+#             return torch.where(x > 0, self.p * x, (1 - self.p) * (-x))
+
 
-    def forward(self, x):
-        # Apply GELU first
-        x = self.gelu(x)
-
-        if self.training:
-            # Create masks once and reuse
-            pos_mask = x > 0
-
-            # Process positive values (keep with probability p)
-            pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
-            pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
-            if self.p > 0:
-                pos_dropped = pos_dropped / self.p
-
-            # Process negative values (keep with probability 1-p)
-            neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
-            neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
-            if (1 - self.p) > 0:
-                neg_dropped = neg_dropped / (1 - self.p)
-
-            return pos_dropped + neg_dropped
-        else:
-            # Test time behavior - simplify with direct where operations
-            return torch.where(x > 0, self.p * x, (1 - self.p) * x)
+# class AID_GELU(nn.Module):
+#     def __init__(self, p=0.9, approximate="none"):
+#         super().__init__()
+#         self.p = p
+#         self.gelu = nn.GELU(approximate=approximate)
+#
+#     def forward(self, x):
+#         # Apply GELU first
+#         x = self.gelu(x)
+#
+#         if self.training:
+#             # Create masks once and reuse
+#             pos_mask = x > 0
+#
+#             # Process positive values (keep with probability p)
+#             pos_vals = torch.where(pos_mask, x, torch.zeros_like(x))
+#             pos_dropped = F.dropout(pos_vals, p=1 - self.p, training=True)
+#             if self.p > 0:
+#                 pos_dropped = pos_dropped / self.p
+#
+#             # Process negative values (keep with probability 1-p)
+#             neg_vals = torch.where(~pos_mask, x, torch.zeros_like(x))
+#             neg_dropped = F.dropout(neg_vals, p=self.p, training=True)
+#             if (1 - self.p) > 0:
+#                 neg_dropped = neg_dropped / (1 - self.p)
+#
+#             return pos_dropped + neg_dropped
+#         else:
+#             # Test time behavior - simplify with direct where operations
+#             return torch.where(x > 0, self.p * x, (1 - self.p) * x)
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index ff553bab..158bb136 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch import Tensor
 import re
-from library.model_utils import AID_GELU
+from library.model_utils import AID
 from library.utils import setup_logging
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 
@@ -125,7 +125,7 @@ class LoRAModule(torch.nn.Module):
         self.module_dropout = module_dropout
 
         self.aid = (
-            AID_GELU(dropout_prob=aid_dropout, approximate="tanh") if aid_dropout is not None else torch.nn.Identity()
+            AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
         )  # AID activation
 
         self.ggpo_sigma = ggpo_sigma