Compare commits

...

10 Commits

Author SHA1 Message Date
Dave Lage
88a4442b7e Merge ac120e68ef into fa53f71ec0 2026-04-02 06:31:26 +00:00
rockerBOO
ac120e68ef Add aid_p buffer to save AID 2025-04-18 02:02:47 -04:00
rockerBOO
aefea026a7 Add AID tests 2025-04-15 15:38:57 -04:00
rockerBOO
c2f75f43a4 Do not convert to float 2025-04-15 04:41:59 -04:00
rockerBOO
4d005cdf3d Remove old code 2025-04-14 15:52:25 -04:00
rockerBOO
9bc392001c Simplify AID implementation. AID(B*A) instead of AID(B)*A. 2025-04-14 15:50:54 -04:00
rockerBOO
584ea4ee34 Improve performance. Add curve for AID probabilities 2025-04-14 01:56:37 -04:00
rockerBOO
956275f295 Add pythonpath to pytest.ini 2025-04-13 21:31:35 -04:00
rockerBOO
61af45ef3c Add AID_GELU. Add dropout curve for AID 2025-04-13 21:30:35 -04:00
rockerBOO
dfae3a486c Add AID activation interval-wise dropout 2025-04-13 13:57:52 -04:00
3 changed files with 207 additions and 2 deletions

18
library/model_utils.py Normal file
View File

@@ -0,0 +1,18 @@
import torch
from torch import nn, Tensor
class AID(nn.Module):
    """Activation Interval-wise Dropout (AID).

    Training: non-negative entries are kept with probability ``p`` and
    negative entries with probability ``1 - p``; dropped entries become zero.
    Evaluation: the stochastic masks are replaced by their expectations, so
    non-negative entries are scaled by ``p`` and negative ones by ``1 - p``.
    """

    def __init__(self, p=0.9):
        super().__init__()
        # Keep-probability for the non-negative part of the activation.
        self.p = p

    def forward(self, x: Tensor):
        keep_pos = x >= 0
        if self.training:
            # Independent Bernoulli keep-mask per element and per interval.
            mask_pos = keep_pos * torch.bernoulli(torch.ones_like(x) * self.p)
            mask_neg = (~keep_pos) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
            return x * (mask_pos + mask_neg)
        # Deterministic expectation of the training-time behaviour.
        return keep_pos * x * self.p + (~keep_pos) * x * (1 - self.p)

View File

@@ -17,6 +17,7 @@ import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
import re import re
from library.model_utils import AID
from library.utils import setup_logging from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -30,6 +31,21 @@ NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38 NUM_SINGLE_BLOCKS = 38
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
    """Map a block index onto a scaled, truncated sine curve.

    The index is normalised to ``[0, 1]`` and evaluated over the first
    ``shift`` fraction of a half sine period: the curve starts at 0,
    rises to ``peak`` where ``block_id / total_blocks == 0.5 / shift``,
    and falls back towards ``peak * sin(shift * pi)`` at the last block.

    Args:
        block_id: Index of the block.
        total_blocks: Number of blocks used for normalisation.
        peak: Maximum value the curve reaches.
        shift: Fraction of the half sine period spanned by the block range.

    Returns:
        ``peak * sin(shift * pi * block_id / total_blocks)``.
    """
    phase = math.pi * shift * (block_id / total_blocks)
    return peak * math.sin(phase)
class LoRAModule(torch.nn.Module): class LoRAModule(torch.nn.Module):
""" """
replaces forward method of the original Linear, instead of replacing the original Linear module. replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -45,6 +61,7 @@ class LoRAModule(torch.nn.Module):
dropout=None, dropout=None,
rank_dropout=None, rank_dropout=None,
module_dropout=None, module_dropout=None,
aid_dropout=None,
split_dims: Optional[List[int]] = None, split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None, ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None, ggpo_sigma: Optional[float] = None,
@@ -107,6 +124,13 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.aid = (
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
) # AID activation
if aid_dropout is not None:
self.register_buffer("aid_p", torch.tensor(aid_dropout))
self.ggpo_sigma = ggpo_sigma self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta self.ggpo_beta = ggpo_beta
@@ -158,6 +182,9 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx) lx = self.lora_up(lx)
# Activation by Interval-wise Dropout
lx = self.aid(lx)
# LoRA Gradient-Guided Perturbation Optimization # LoRA Gradient-Guided Perturbation Optimization
if ( if (
self.training self.training
@@ -554,6 +581,9 @@ def create_network(
module_dropout = kwargs.get("module_dropout", None) module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None: if module_dropout is not None:
module_dropout = float(module_dropout) module_dropout = float(module_dropout)
aid_dropout = kwargs.get("aid_dropout", None)
if aid_dropout is not None:
aid_dropout = float(aid_dropout)
# single or double blocks # single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
@@ -630,6 +660,7 @@ def create_network(
dropout=neuron_dropout, dropout=neuron_dropout,
rank_dropout=rank_dropout, rank_dropout=rank_dropout,
module_dropout=module_dropout, module_dropout=module_dropout,
aid_dropout=aid_dropout,
conv_lora_dim=conv_dim, conv_lora_dim=conv_dim,
conv_alpha=conv_alpha, conv_alpha=conv_alpha,
train_blocks=train_blocks, train_blocks=train_blocks,
@@ -730,6 +761,7 @@ class LoRANetwork(torch.nn.Module):
dropout: Optional[float] = None, dropout: Optional[float] = None,
rank_dropout: Optional[float] = None, rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None, module_dropout: Optional[float] = None,
aid_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None, conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None, conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule, module_class: Type[object] = LoRAModule,
@@ -758,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
self.dropout = dropout self.dropout = dropout
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.aid_dropout = aid_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all" self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl self.train_t5xxl = train_t5xxl
@@ -834,6 +867,7 @@ class LoRANetwork(torch.nn.Module):
dim = None dim = None
alpha = None alpha = None
aid_dropout_p = None
if modules_dim is not None: if modules_dim is not None:
# モジュール指定あり # モジュール指定あり
@@ -869,6 +903,21 @@ class LoRANetwork(torch.nn.Module):
if d is not None and all([id in lora_name for id in identifier[i]]): if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip dim = d # may be 0 for skip
break break
is_double = False
if "double" in lora_name:
is_double = True
is_single = False
if "single" in lora_name:
is_single = True
block_index = None
if is_flux and dim and (is_double or is_single):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if block_index is not None and aid_dropout is not None:
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
aid_dropout_p = get_point_on_curve(
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
if ( if (
is_flux is_flux
@@ -879,8 +928,6 @@ class LoRANetwork(torch.nn.Module):
) )
and ("double" in lora_name or "single" in lora_name) and ("double" in lora_name or "single" in lora_name)
): ):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if ( if (
"double" in lora_name "double" in lora_name
and self.train_double_block_indices is not None and self.train_double_block_indices is not None
@@ -921,6 +968,7 @@ class LoRANetwork(torch.nn.Module):
dropout=dropout, dropout=dropout,
rank_dropout=rank_dropout, rank_dropout=rank_dropout,
module_dropout=module_dropout, module_dropout=module_dropout,
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
split_dims=split_dims, split_dims=split_dims,
ggpo_beta=ggpo_beta, ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma, ggpo_sigma=ggpo_sigma,
@@ -1117,6 +1165,7 @@ class LoRANetwork(torch.nn.Module):
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
alpha = state_dict.pop(f"{lora_name}.alpha") alpha = state_dict.pop(f"{lora_name}.alpha")
aid_p = state_dict.pop(f"{lora_name}.aid_p")
# merge down weight # merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
@@ -1132,6 +1181,7 @@ class LoRANetwork(torch.nn.Module):
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha new_state_dict[f"{lora_name}.alpha"] = alpha
new_state_dict[f"{lora_name}.aid_p"] = aid_p
# print( # print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"

View File

@@ -0,0 +1,137 @@
import pytest
import torch
from library.model_utils import AID
from torch import nn
import torch.nn.functional as F
@pytest.fixture
def input_tensor():
    """Fixture: 1-D tensor spanning negative, zero and positive values."""
    return torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
def test_aid_forward_train_mode(input_tensor):
    """Training mode must be stochastic and preserve the input shape."""
    aid = AID(p=0.9)
    aid.train()
    # Collect several stochastic forward passes.
    outputs = [aid(input_tensor).detach().clone() for _ in range(10)]
    # At least one pass must differ from the first, otherwise the mask
    # sampling is not actually stochastic.
    varied = any(not torch.allclose(outputs[0], out) for out in outputs[1:])
    assert varied, "All outputs are identical, expected variability in training mode"
    assert outputs[0].shape == input_tensor.shape
def test_aid_forward_eval_mode(input_tensor):
    """Eval mode must be deterministic: positives scaled by p, negatives by 1 - p."""
    aid = AID(p=0.9)
    aid.eval()
    first = aid(input_tensor)
    second = aid(input_tensor)
    # Two passes over the same input must agree exactly.
    assert torch.allclose(first, second), "Expected deterministic behavior in eval mode"
    # Positive part scaled by 0.9, negative part by 0.1.
    expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
    assert torch.allclose(first, expected), "Incorrect evaluation mode transformation"
def test_aid_gradient_flow(input_tensor):
    """Gradients must flow through the stochastic masking in training mode.

    The Bernoulli masks can — with small but non-zero probability — zero out
    every element of a single forward pass, which would make the non-zero
    gradient assertion fail spuriously. Seeding the RNG and accumulating a
    few passes makes the test deterministic and the all-zero outcome
    practically impossible.
    """
    torch.manual_seed(42)  # make mask sampling reproducible across runs
    aid = AID(p=0.9)
    aid.train()
    # Accumulate several passes so at least one mask element is non-zero.
    output = sum(aid(input_tensor) for _ in range(4))
    # Check gradient flow
    assert output.requires_grad, "Output lost gradient tracking"
    # Compute loss and backpropagate
    loss = output.sum()
    loss.backward()
    # Verify gradients were computed
    assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
    assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
def test_aid_extreme_p_values():
    """At p=1.0 only positives survive eval; at p=0.0 only negatives do."""
    values = [-2.0, -1.0, 0.0, 1.0, 2.0]
    # p=1.0: the negative part is scaled by (1 - p) == 0.
    x = torch.tensor(values, requires_grad=True)
    aid = AID(p=1.0)
    aid.eval()
    expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
    assert torch.allclose(aid(x), expected), "Failed with p=1.0"
    # p=0.0: the positive part is scaled by p == 0.
    x = torch.tensor(values, requires_grad=True)
    aid = AID(p=0.0)
    aid.eval()
    expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])
    assert torch.allclose(aid(x), expected), "Failed with p=0.0"
def test_aid_with_all_positive_values():
    """Backprop must work when every input element is positive."""
    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
    aid = AID(p=0.9)
    aid.train()
    # Forward + backward through the stochastic mask.
    aid(x).sum().backward()
    assert x.grad is not None, "No gradients were recorded for all-positive input"
def test_aid_with_all_negative_values():
    """Backprop must work when every input element is negative."""
    x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
    aid = AID(p=0.9)
    aid.train()
    # Forward + backward through the stochastic mask.
    aid(x).sum().backward()
    assert x.grad is not None, "No gradients were recorded for all-negative input"
def test_aid_with_zero_values():
    """An all-zero input must map to an all-zero output in both modes."""
    x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True)
    aid = AID(p=0.9)
    # Check train mode first, then eval mode (same order as before).
    for set_mode in (aid.train, aid.eval):
        set_mode()
        output = aid(x)
        assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
def test_aid_integration_with_linear_layer():
    """AID must compose with nn.Linear and pass gradients end to end."""
    linear = nn.Linear(5, 2)
    model = nn.Sequential(linear, AID(p=0.9))
    model.train()
    x = torch.randn(3, 5, requires_grad=True)
    # Gradients must reach both the layer weights and the input.
    model(x).sum().backward()
    assert linear.weight.grad is not None, "No gradients for linear layer weights"
    assert x.grad is not None, "No gradients for input tensor"