mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge ac120e68ef into d633b51126
This commit is contained in:
18
library/model_utils.py
Normal file
18
library/model_utils.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class AID(nn.Module):
    """Activation by Interval-wise Dropout.

    In training mode, non-negative activations are kept with probability
    ``p`` and negative activations with probability ``1 - p`` (dropped
    entries become zero). In eval mode the stochastic masks are replaced
    by their expected values: non-negative entries are scaled by ``p``
    and negative entries by ``1 - p``.
    """

    def __init__(self, p=0.9):
        super().__init__()
        # Keep-probability for the non-negative interval.
        self.p = p

    def forward(self, x: Tensor):
        is_nonneg = x >= 0
        is_neg = x < 0
        if self.training:
            # Independent Bernoulli keep-masks, one per sign interval.
            keep_nonneg = is_nonneg * torch.bernoulli(torch.ones_like(x) * self.p)
            keep_neg = is_neg * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
            return x * (keep_nonneg + keep_neg)
        else:
            # Deterministic expectation of the training-time mask.
            scaled_nonneg = is_nonneg * x * self.p
            scaled_neg = is_neg * x * (1 - self.p)
            return scaled_nonneg + scaled_neg
|
||||
@@ -17,6 +17,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.model_utils import AID
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
@@ -30,6 +31,21 @@ NUM_DOUBLE_BLOCKS = 19
|
||||
NUM_SINGLE_BLOCKS = 38
|
||||
|
||||
|
||||
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
    """Map a block index to a value along a truncated sine curve.

    The block position is normalized to [0, 1] and swept through the
    first ``shift`` fraction of a sine half-cycle, so the curve starts
    at 0, rises to ``peak`` partway through, and tapers off (without
    returning to 0) at the last block. With the defaults, only the
    first 3/4 of the cycle is used, ending around ``0.7 * peak``.
    """
    fraction = block_id / total_blocks
    # Truncate the cycle: the usable angular range is shift * pi.
    half_cycle = shift * math.pi
    # Scale the sine value so its maximum equals the requested peak.
    return peak * math.sin(fraction * half_cycle)
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
@@ -45,6 +61,7 @@ class LoRAModule(torch.nn.Module):
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
aid_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
@@ -107,6 +124,13 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.aid = (
|
||||
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
|
||||
) # AID activation
|
||||
|
||||
if aid_dropout is not None:
|
||||
self.register_buffer("aid_p", torch.tensor(aid_dropout))
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
@@ -158,6 +182,9 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# Activation by Interval-wise Dropout
|
||||
lx = self.aid(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if (
|
||||
self.training
|
||||
@@ -554,6 +581,9 @@ def create_network(
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
aid_dropout = kwargs.get("aid_dropout", None)
|
||||
if aid_dropout is not None:
|
||||
aid_dropout = float(aid_dropout)
|
||||
|
||||
# single or double blocks
|
||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||
@@ -630,6 +660,7 @@ def create_network(
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
aid_dropout=aid_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
@@ -730,6 +761,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
aid_dropout: Optional[float] = None,
|
||||
conv_lora_dim: Optional[int] = None,
|
||||
conv_alpha: Optional[float] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
@@ -758,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.aid_dropout = aid_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
self.split_qkv = split_qkv
|
||||
self.train_t5xxl = train_t5xxl
|
||||
@@ -834,6 +867,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
aid_dropout_p = None
|
||||
|
||||
if modules_dim is not None:
|
||||
# モジュール指定あり
|
||||
@@ -869,6 +903,21 @@ class LoRANetwork(torch.nn.Module):
|
||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||
dim = d # may be 0 for skip
|
||||
break
|
||||
is_double = False
|
||||
if "double" in lora_name:
|
||||
is_double = True
|
||||
is_single = False
|
||||
if "single" in lora_name:
|
||||
is_single = True
|
||||
block_index = None
|
||||
if is_flux and dim and (is_double or is_single):
|
||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||
|
||||
if block_index is not None and aid_dropout is not None:
|
||||
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
|
||||
aid_dropout_p = get_point_on_curve(
|
||||
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
|
||||
|
||||
if (
|
||||
is_flux
|
||||
@@ -879,8 +928,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
)
|
||||
and ("double" in lora_name or "single" in lora_name)
|
||||
):
|
||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||
if (
|
||||
"double" in lora_name
|
||||
and self.train_double_block_indices is not None
|
||||
@@ -921,6 +968,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
@@ -1117,6 +1165,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
|
||||
|
||||
alpha = state_dict.pop(f"{lora_name}.alpha")
|
||||
aid_p = state_dict.pop(f"{lora_name}.aid_p")
|
||||
|
||||
# merge down weight
|
||||
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
||||
@@ -1132,6 +1181,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
|
||||
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
|
||||
new_state_dict[f"{lora_name}.alpha"] = alpha
|
||||
new_state_dict[f"{lora_name}.aid_p"] = aid_p
|
||||
|
||||
# print(
|
||||
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
|
||||
|
||||
137
tests/library/test_model_utils.py
Normal file
137
tests/library/test_model_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import pytest
|
||||
import torch
|
||||
from library.model_utils import AID
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@pytest.fixture
def input_tensor():
    """Gradient-tracking tensor mixing negative, zero, and positive values."""
    return torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
||||
|
||||
def test_aid_forward_train_mode(input_tensor):
    """Training mode must be stochastic and preserve the input shape."""
    module = AID(p=0.9)
    module.train()

    # Collect several stochastic forward passes.
    outputs = [module(input_tensor).detach().clone() for _ in range(10)]

    # At least one pass should differ from the first.
    identical = all(torch.allclose(outputs[0], out) for out in outputs[1:])
    assert not identical, "All outputs are identical, expected variability in training mode"

    assert outputs[0].shape == input_tensor.shape
|
||||
|
||||
def test_aid_forward_eval_mode(input_tensor):
    """Eval mode is deterministic and scales each sign interval by its keep-rate."""
    module = AID(p=0.9)
    module.eval()

    first = module(input_tensor)
    second = module(input_tensor)
    assert torch.allclose(first, second), "Expected deterministic behavior in eval mode"

    # Positives scaled by p, negatives by 1 - p.
    expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
    assert torch.allclose(first, expected), "Incorrect evaluation mode transformation"
|
||||
|
||||
def test_aid_gradient_flow(input_tensor):
    """Gradients must propagate through the stochastic training-mode mask."""
    module = AID(p=0.9)
    module.train()

    out = module(input_tensor)
    assert out.requires_grad, "Output lost gradient tracking"

    # Backpropagate a scalar loss through the module.
    out.sum().backward()

    assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
    assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
|
||||
|
||||
def test_aid_extreme_p_values():
    """In eval mode, p=1.0 keeps only positives and p=0.0 only negatives."""
    values = [-2.0, -1.0, 0.0, 1.0, 2.0]

    # p=1.0: negatives are scaled by 0, positives pass unchanged.
    x = torch.tensor(values, requires_grad=True)
    module = AID(p=1.0)
    module.eval()
    out = module(x)
    assert torch.allclose(out, torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])), "Failed with p=1.0"

    # p=0.0: positives are scaled by 0, negatives pass unchanged.
    x = torch.tensor(values, requires_grad=True)
    module = AID(p=0.0)
    module.eval()
    out = module(x)
    assert torch.allclose(out, torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])), "Failed with p=0.0"
|
||||
|
||||
def test_aid_with_all_positive_values():
    """Backpropagation works when every input element is positive."""
    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
    module = AID(p=0.9)
    module.train()

    # Forward then backprop a scalar loss.
    module(x).sum().backward()

    assert x.grad is not None, "No gradients were recorded for all-positive input"
|
||||
|
||||
def test_aid_with_all_negative_values():
    """Backpropagation works when every input element is negative."""
    x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
    module = AID(p=0.9)
    module.train()

    # Forward then backprop a scalar loss.
    module(x).sum().backward()

    assert x.grad is not None, "No gradients were recorded for all-negative input"
|
||||
|
||||
def test_aid_with_zero_values():
    """A zero tensor must map to zeros in both training and eval modes."""
    x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True)
    module = AID(p=0.9)

    # Training mode: masking zeros still yields zeros.
    module.train()
    out = module(x)
    assert torch.allclose(out, torch.zeros_like(out)), "Expected zeros out for zero input"

    # Eval mode: scaling zeros still yields zeros.
    module.eval()
    out = module(x)
    assert torch.allclose(out, torch.zeros_like(out)), "Expected zeros out for zero input"
|
||||
|
||||
def test_aid_integration_with_linear_layer():
    """AID composes with nn.Linear; gradients reach weights and input."""
    linear = nn.Linear(5, 2)
    model = nn.Sequential(linear, AID(p=0.9))
    model.train()

    x = torch.randn(3, 5, requires_grad=True)

    # Backprop a scalar loss through the whole stack.
    model(x).sum().backward()

    assert linear.weight.grad is not None, "No gradients for linear layer weights"
    assert x.grad is not None, "No gradients for input tensor"
|
||||
Reference in New Issue
Block a user