mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Compare commits
10 Commits
sd3
...
88a4442b7e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88a4442b7e | ||
|
|
ac120e68ef | ||
|
|
aefea026a7 | ||
|
|
c2f75f43a4 | ||
|
|
4d005cdf3d | ||
|
|
9bc392001c | ||
|
|
584ea4ee34 | ||
|
|
956275f295 | ||
|
|
61af45ef3c | ||
|
|
dfae3a486c |
18
library/model_utils.py
Normal file
18
library/model_utils.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class AID(nn.Module):
|
||||||
|
def __init__(self, p=0.9):
|
||||||
|
super().__init__()
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
if self.training:
|
||||||
|
pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
|
||||||
|
neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
|
||||||
|
return x * (pos_mask + neg_mask)
|
||||||
|
else:
|
||||||
|
pos_part = (x >= 0) * x * self.p
|
||||||
|
neg_part = (x < 0) * x * (1 - self.p)
|
||||||
|
return pos_part + neg_part
|
||||||
@@ -17,6 +17,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import re
|
import re
|
||||||
|
from library.model_utils import AID
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||||
|
|
||||||
@@ -30,6 +31,21 @@ NUM_DOUBLE_BLOCKS = 19
|
|||||||
NUM_SINGLE_BLOCKS = 38
|
NUM_SINGLE_BLOCKS = 38
|
||||||
|
|
||||||
|
|
||||||
|
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
|
||||||
|
# Normalize the position to 0-1 range
|
||||||
|
normalized_pos = block_id / total_blocks
|
||||||
|
|
||||||
|
# Shift the sine curve to only use the first 3/4 of the cycle
|
||||||
|
# This gives us: start at 0, peak in the middle, end around 0.7
|
||||||
|
phase_shift = shift * math.pi
|
||||||
|
sine_value = math.sin(normalized_pos * phase_shift)
|
||||||
|
|
||||||
|
# Scale to our desired peak of 0.9
|
||||||
|
result = peak * sine_value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
@@ -45,6 +61,7 @@ class LoRAModule(torch.nn.Module):
|
|||||||
dropout=None,
|
dropout=None,
|
||||||
rank_dropout=None,
|
rank_dropout=None,
|
||||||
module_dropout=None,
|
module_dropout=None,
|
||||||
|
aid_dropout=None,
|
||||||
split_dims: Optional[List[int]] = None,
|
split_dims: Optional[List[int]] = None,
|
||||||
ggpo_beta: Optional[float] = None,
|
ggpo_beta: Optional[float] = None,
|
||||||
ggpo_sigma: Optional[float] = None,
|
ggpo_sigma: Optional[float] = None,
|
||||||
@@ -107,6 +124,13 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
|
||||||
|
self.aid = (
|
||||||
|
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
|
||||||
|
) # AID activation
|
||||||
|
|
||||||
|
if aid_dropout is not None:
|
||||||
|
self.register_buffer("aid_p", torch.tensor(aid_dropout))
|
||||||
|
|
||||||
self.ggpo_sigma = ggpo_sigma
|
self.ggpo_sigma = ggpo_sigma
|
||||||
self.ggpo_beta = ggpo_beta
|
self.ggpo_beta = ggpo_beta
|
||||||
|
|
||||||
@@ -158,6 +182,9 @@ class LoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
lx = self.lora_up(lx)
|
lx = self.lora_up(lx)
|
||||||
|
|
||||||
|
# Activation by Interval-wise Dropout
|
||||||
|
lx = self.aid(lx)
|
||||||
|
|
||||||
# LoRA Gradient-Guided Perturbation Optimization
|
# LoRA Gradient-Guided Perturbation Optimization
|
||||||
if (
|
if (
|
||||||
self.training
|
self.training
|
||||||
@@ -554,6 +581,9 @@ def create_network(
|
|||||||
module_dropout = kwargs.get("module_dropout", None)
|
module_dropout = kwargs.get("module_dropout", None)
|
||||||
if module_dropout is not None:
|
if module_dropout is not None:
|
||||||
module_dropout = float(module_dropout)
|
module_dropout = float(module_dropout)
|
||||||
|
aid_dropout = kwargs.get("aid_dropout", None)
|
||||||
|
if aid_dropout is not None:
|
||||||
|
aid_dropout = float(aid_dropout)
|
||||||
|
|
||||||
# single or double blocks
|
# single or double blocks
|
||||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||||
@@ -630,6 +660,7 @@ def create_network(
|
|||||||
dropout=neuron_dropout,
|
dropout=neuron_dropout,
|
||||||
rank_dropout=rank_dropout,
|
rank_dropout=rank_dropout,
|
||||||
module_dropout=module_dropout,
|
module_dropout=module_dropout,
|
||||||
|
aid_dropout=aid_dropout,
|
||||||
conv_lora_dim=conv_dim,
|
conv_lora_dim=conv_dim,
|
||||||
conv_alpha=conv_alpha,
|
conv_alpha=conv_alpha,
|
||||||
train_blocks=train_blocks,
|
train_blocks=train_blocks,
|
||||||
@@ -730,6 +761,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dropout: Optional[float] = None,
|
dropout: Optional[float] = None,
|
||||||
rank_dropout: Optional[float] = None,
|
rank_dropout: Optional[float] = None,
|
||||||
module_dropout: Optional[float] = None,
|
module_dropout: Optional[float] = None,
|
||||||
|
aid_dropout: Optional[float] = None,
|
||||||
conv_lora_dim: Optional[int] = None,
|
conv_lora_dim: Optional[int] = None,
|
||||||
conv_alpha: Optional[float] = None,
|
conv_alpha: Optional[float] = None,
|
||||||
module_class: Type[object] = LoRAModule,
|
module_class: Type[object] = LoRAModule,
|
||||||
@@ -758,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
self.aid_dropout = aid_dropout
|
||||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||||
self.split_qkv = split_qkv
|
self.split_qkv = split_qkv
|
||||||
self.train_t5xxl = train_t5xxl
|
self.train_t5xxl = train_t5xxl
|
||||||
@@ -834,6 +867,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
|
aid_dropout_p = None
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
# モジュール指定あり
|
# モジュール指定あり
|
||||||
@@ -869,6 +903,21 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||||
dim = d # may be 0 for skip
|
dim = d # may be 0 for skip
|
||||||
break
|
break
|
||||||
|
is_double = False
|
||||||
|
if "double" in lora_name:
|
||||||
|
is_double = True
|
||||||
|
is_single = False
|
||||||
|
if "single" in lora_name:
|
||||||
|
is_single = True
|
||||||
|
block_index = None
|
||||||
|
if is_flux and dim and (is_double or is_single):
|
||||||
|
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||||
|
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||||
|
|
||||||
|
if block_index is not None and aid_dropout is not None:
|
||||||
|
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
|
||||||
|
aid_dropout_p = get_point_on_curve(
|
||||||
|
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
is_flux
|
is_flux
|
||||||
@@ -879,8 +928,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
and ("double" in lora_name or "single" in lora_name)
|
and ("double" in lora_name or "single" in lora_name)
|
||||||
):
|
):
|
||||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
|
||||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
|
||||||
if (
|
if (
|
||||||
"double" in lora_name
|
"double" in lora_name
|
||||||
and self.train_double_block_indices is not None
|
and self.train_double_block_indices is not None
|
||||||
@@ -921,6 +968,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
rank_dropout=rank_dropout,
|
rank_dropout=rank_dropout,
|
||||||
module_dropout=module_dropout,
|
module_dropout=module_dropout,
|
||||||
|
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
|
||||||
split_dims=split_dims,
|
split_dims=split_dims,
|
||||||
ggpo_beta=ggpo_beta,
|
ggpo_beta=ggpo_beta,
|
||||||
ggpo_sigma=ggpo_sigma,
|
ggpo_sigma=ggpo_sigma,
|
||||||
@@ -1117,6 +1165,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
|
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
|
||||||
|
|
||||||
alpha = state_dict.pop(f"{lora_name}.alpha")
|
alpha = state_dict.pop(f"{lora_name}.alpha")
|
||||||
|
aid_p = state_dict.pop(f"{lora_name}.aid_p")
|
||||||
|
|
||||||
# merge down weight
|
# merge down weight
|
||||||
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
||||||
@@ -1132,6 +1181,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
|
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
|
||||||
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
|
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
|
||||||
new_state_dict[f"{lora_name}.alpha"] = alpha
|
new_state_dict[f"{lora_name}.alpha"] = alpha
|
||||||
|
new_state_dict[f"{lora_name}.aid_p"] = aid_p
|
||||||
|
|
||||||
# print(
|
# print(
|
||||||
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
|
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
|
||||||
|
|||||||
137
tests/library/test_model_utils.py
Normal file
137
tests/library/test_model_utils.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from library.model_utils import AID
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def input_tensor():
|
||||||
|
# Create a tensor with positive and negative values
|
||||||
|
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_aid_forward_train_mode(input_tensor):
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
aid.train()
|
||||||
|
|
||||||
|
# Run several forward passes to test stochastic behavior
|
||||||
|
results = []
|
||||||
|
for _ in range(10):
|
||||||
|
output = aid(input_tensor)
|
||||||
|
results.append(output.detach().clone())
|
||||||
|
|
||||||
|
# Test that outputs vary (stochastic behavior)
|
||||||
|
all_equal = all(torch.allclose(results[0], results[i]) for i in range(1, 10))
|
||||||
|
assert not all_equal, "All outputs are identical, expected variability in training mode"
|
||||||
|
|
||||||
|
# Test shape preservation
|
||||||
|
assert results[0].shape == input_tensor.shape
|
||||||
|
|
||||||
|
def test_aid_forward_eval_mode(input_tensor):
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
aid.eval()
|
||||||
|
|
||||||
|
output = aid(input_tensor)
|
||||||
|
|
||||||
|
# Test deterministic behavior
|
||||||
|
output2 = aid(input_tensor)
|
||||||
|
assert torch.allclose(output, output2), "Expected deterministic behavior in eval mode"
|
||||||
|
|
||||||
|
# Test correct transformation
|
||||||
|
expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
|
||||||
|
assert torch.allclose(output, expected), "Incorrect evaluation mode transformation"
|
||||||
|
|
||||||
|
def test_aid_gradient_flow(input_tensor):
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
aid.train()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
output = aid(input_tensor)
|
||||||
|
|
||||||
|
# Check gradient flow
|
||||||
|
assert output.requires_grad, "Output lost gradient tracking"
|
||||||
|
|
||||||
|
# Compute loss and backpropagate
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Verify gradients were computed
|
||||||
|
assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
|
||||||
|
assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
|
||||||
|
|
||||||
|
def test_aid_extreme_p_values():
|
||||||
|
# Test with p=1.0 (only positive values pass through)
|
||||||
|
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
||||||
|
aid = AID(p=1.0)
|
||||||
|
aid.eval()
|
||||||
|
|
||||||
|
output = aid(x)
|
||||||
|
expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
|
||||||
|
assert torch.allclose(output, expected), "Failed with p=1.0"
|
||||||
|
|
||||||
|
# Test with p=0.0 (only negative values pass through)
|
||||||
|
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
||||||
|
aid = AID(p=0.0)
|
||||||
|
aid.eval()
|
||||||
|
|
||||||
|
output = aid(x)
|
||||||
|
expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])
|
||||||
|
assert torch.allclose(output, expected), "Failed with p=0.0"
|
||||||
|
|
||||||
|
def test_aid_with_all_positive_values():
|
||||||
|
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
aid.train()
|
||||||
|
|
||||||
|
# Run forward passes and check that only positive values are affected
|
||||||
|
output = aid(x)
|
||||||
|
|
||||||
|
# Backprop should work
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None, "No gradients were recorded for all-positive input"
|
||||||
|
|
||||||
|
def test_aid_with_all_negative_values():
|
||||||
|
x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
aid.train()
|
||||||
|
|
||||||
|
# Run forward passes and check that only negative values are affected
|
||||||
|
output = aid(x)
|
||||||
|
|
||||||
|
# Backprop should work
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None, "No gradients were recorded for all-negative input"
|
||||||
|
|
||||||
|
def test_aid_with_zero_values():
|
||||||
|
x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True)
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
|
||||||
|
# Test training mode
|
||||||
|
aid.train()
|
||||||
|
output = aid(x)
|
||||||
|
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
|
||||||
|
|
||||||
|
# Test eval mode
|
||||||
|
aid.eval()
|
||||||
|
output = aid(x)
|
||||||
|
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
|
||||||
|
|
||||||
|
def test_aid_integration_with_linear_layer():
|
||||||
|
# Test AID's compatibility with a linear layer
|
||||||
|
linear = nn.Linear(5, 2)
|
||||||
|
aid = AID(p=0.9)
|
||||||
|
|
||||||
|
model = nn.Sequential(linear, aid)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
x = torch.randn(3, 5, requires_grad=True)
|
||||||
|
output = model(x)
|
||||||
|
|
||||||
|
# Check that gradients flow through the whole model
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
assert linear.weight.grad is not None, "No gradients for linear layer weights"
|
||||||
|
assert x.grad is not None, "No gradients for input tensor"
|
||||||
Reference in New Issue
Block a user