Files
Kohya-ss-sd-scripts/library/model_utils.py

64 lines
2.2 KiB
Python

import torch
from torch import nn, Tensor
import torch.nn.functional as F
class AID(nn.Module):
    """Sign-dependent dropout used as an activation function.

    During training each element is either passed through unchanged or set
    to zero: strictly positive elements are kept with probability ``p`` and
    all other elements (negative or zero) with probability ``1 - p``.
    During evaluation the module returns the expectation of that random
    process, i.e. ``p * x`` for positive elements and ``(1 - p) * x``
    otherwise, so training and evaluation agree on average.

    Args:
        dropout_prob: Despite the name, this is the *keep* probability ``p``
            for positive elements; non-positive elements are kept with
            probability ``1 - p``. Must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``dropout_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, dropout_prob=0.9):
        super().__init__()
        if not 0.0 <= dropout_prob <= 1.0:
            raise ValueError(
                f"dropout_prob must be in [0, 1], got {dropout_prob}"
            )
        self.p = dropout_prob

    def forward(self, x: Tensor):
        """Apply the stochastic (training) or expected (eval) activation.

        Args:
            x: Floating-point input tensor of any shape.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        keep_pos = self.p
        keep_neg = 1 - self.p
        is_pos = x > 0

        if not self.training:
            # Deterministic expectation of the training-time masking. The
            # negative branch keeps the sign of x, matching what training
            # produces for surviving negative elements.
            return torch.where(is_pos, keep_pos * x, keep_neg * x)

        # Sample the keep mask directly instead of going through F.dropout:
        # F.dropout rescales survivors by 1 / keep-probability, which would
        # make the training expectation differ from the eval output above.
        noise = torch.rand_like(x)
        keep = torch.where(is_pos, noise < keep_pos, noise < keep_neg)
        return torch.where(keep, x, torch.zeros_like(x))
class AID_GELU(nn.Module):
    """GELU followed by sign-dependent dropout.

    The input is first passed through ``nn.GELU``. During training each
    activated element is then either passed through unchanged or set to
    zero: strictly positive activations are kept with probability ``p`` and
    all others with probability ``1 - p``. During evaluation the module
    returns the expectation of that random process, i.e. ``p * y`` for
    positive activations ``y`` and ``(1 - p) * y`` otherwise.

    Args:
        dropout_prob: Despite the name, this is the *keep* probability ``p``
            for positive activations; non-positive activations are kept
            with probability ``1 - p``. Must lie in ``[0, 1]``.
        approximate: Forwarded unchanged to ``nn.GELU(approximate=...)``.

    Raises:
        ValueError: If ``dropout_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, dropout_prob=0.9, approximate="none"):
        super().__init__()
        if not 0.0 <= dropout_prob <= 1.0:
            raise ValueError(
                f"dropout_prob must be in [0, 1], got {dropout_prob}"
            )
        self.p = dropout_prob
        self.gelu = nn.GELU(approximate=approximate)

    def forward(self, x):
        """Apply GELU, then the stochastic (training) or expected (eval) mask.

        Args:
            x: Floating-point input tensor of any shape.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        activated = self.gelu(x)
        keep_pos = self.p
        keep_neg = 1 - self.p
        is_pos = activated > 0

        if not self.training:
            # Deterministic expectation of the training-time masking.
            return torch.where(
                is_pos, keep_pos * activated, keep_neg * activated
            )

        # Sample the keep mask directly instead of going through F.dropout:
        # F.dropout rescales survivors by 1 / keep-probability, which would
        # make the training expectation differ from the eval output above.
        noise = torch.rand_like(activated)
        keep = torch.where(is_pos, noise < keep_pos, noise < keep_neg)
        return torch.where(keep, activated, torch.zeros_like(activated))