mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Compare commits
11 Commits
fb8e97fb71
...
88a4442b7e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88a4442b7e | ||
|
|
fa53f71ec0 | ||
|
|
ac120e68ef | ||
|
|
aefea026a7 | ||
|
|
c2f75f43a4 | ||
|
|
4d005cdf3d | ||
|
|
9bc392001c | ||
|
|
584ea4ee34 | ||
|
|
956275f295 | ||
|
|
61af45ef3c | ||
|
|
dfae3a486c |
@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
|
||||
@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
|
||||
@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
use_fp32: bool = False,
|
||||
):
|
||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
@@ -863,11 +863,11 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
if use_fp32:
|
||||
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
@@ -959,6 +959,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
@@ -972,6 +973,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -994,6 +996,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
|
||||
|
||||
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
||||
|
||||
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.wait_for_block(block_idx)
|
||||
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
|
||||
18
library/model_utils.py
Normal file
18
library/model_utils.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class AID(nn.Module):
    """Activation by Interval-wise Dropout.

    During training, each non-negative element is kept with probability
    ``p`` and each negative element with probability ``1 - p`` (kept
    elements pass through unscaled, dropped ones become 0). In eval mode
    the layer is deterministic: each sign half is scaled by its keep
    probability, matching the training-time expectation.
    """

    def __init__(self, p=0.9):
        super().__init__()
        # Keep probability for the non-negative half of the input.
        self.p = p

    def forward(self, x: Tensor):
        if self.training:
            # Draw one keep-mask per half; each element then uses the mask
            # matching its sign.
            keep_pos = torch.bernoulli(torch.ones_like(x) * self.p)
            keep_neg = torch.bernoulli(torch.ones_like(x) * (1 - self.p))
            mask = torch.where(x >= 0, keep_pos, keep_neg)
            return x * mask
        # Deterministic eval path: scale each half by its expected keep rate.
        scaled_pos = (x >= 0) * x * self.p
        scaled_neg = (x < 0) * x * (1 - self.p)
        return scaled_pos + scaled_neg
|
||||
@@ -17,6 +17,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.model_utils import AID
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
@@ -30,6 +31,21 @@ NUM_DOUBLE_BLOCKS = 19
|
||||
NUM_SINGLE_BLOCKS = 38
|
||||
|
||||
|
||||
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
    """Map a block index onto a truncated sine arc.

    The index is normalized to [0, 1] and evaluated over the first
    ``shift`` fraction of a sine half-cycle, so the curve starts at 0,
    rises to ``peak`` (at normalized position ``0.5 / shift``), and ends
    partway back down instead of returning to 0.
    """
    fraction = block_id / total_blocks
    # Sweep only the first `shift` portion of the half-cycle [0, pi].
    half_arc = shift * math.pi
    return peak * math.sin(fraction * half_arc)
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
@@ -45,6 +61,7 @@ class LoRAModule(torch.nn.Module):
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
aid_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
@@ -107,6 +124,13 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.aid = (
|
||||
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
|
||||
) # AID activation
|
||||
|
||||
if aid_dropout is not None:
|
||||
self.register_buffer("aid_p", torch.tensor(aid_dropout))
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
@@ -158,6 +182,9 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# Activation by Interval-wise Dropout
|
||||
lx = self.aid(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if (
|
||||
self.training
|
||||
@@ -554,6 +581,9 @@ def create_network(
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
aid_dropout = kwargs.get("aid_dropout", None)
|
||||
if aid_dropout is not None:
|
||||
aid_dropout = float(aid_dropout)
|
||||
|
||||
# single or double blocks
|
||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||
@@ -630,6 +660,7 @@ def create_network(
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
aid_dropout=aid_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
@@ -730,6 +761,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
aid_dropout: Optional[float] = None,
|
||||
conv_lora_dim: Optional[int] = None,
|
||||
conv_alpha: Optional[float] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
@@ -758,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.aid_dropout = aid_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
self.split_qkv = split_qkv
|
||||
self.train_t5xxl = train_t5xxl
|
||||
@@ -834,6 +867,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
aid_dropout_p = None
|
||||
|
||||
if modules_dim is not None:
|
||||
# モジュール指定あり
|
||||
@@ -869,6 +903,21 @@ class LoRANetwork(torch.nn.Module):
|
||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||
dim = d # may be 0 for skip
|
||||
break
|
||||
is_double = False
|
||||
if "double" in lora_name:
|
||||
is_double = True
|
||||
is_single = False
|
||||
if "single" in lora_name:
|
||||
is_single = True
|
||||
block_index = None
|
||||
if is_flux and dim and (is_double or is_single):
|
||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||
|
||||
if block_index is not None and aid_dropout is not None:
|
||||
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
|
||||
aid_dropout_p = get_point_on_curve(
|
||||
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
|
||||
|
||||
if (
|
||||
is_flux
|
||||
@@ -879,8 +928,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
)
|
||||
and ("double" in lora_name or "single" in lora_name)
|
||||
):
|
||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||
if (
|
||||
"double" in lora_name
|
||||
and self.train_double_block_indices is not None
|
||||
@@ -921,6 +968,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
@@ -1117,6 +1165,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
|
||||
|
||||
alpha = state_dict.pop(f"{lora_name}.alpha")
|
||||
aid_p = state_dict.pop(f"{lora_name}.aid_p")
|
||||
|
||||
# merge down weight
|
||||
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
||||
@@ -1132,6 +1181,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
|
||||
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
|
||||
new_state_dict[f"{lora_name}.alpha"] = alpha
|
||||
new_state_dict[f"{lora_name}.aid_p"] = aid_p
|
||||
|
||||
# print(
|
||||
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
|
||||
|
||||
137
tests/library/test_model_utils.py
Normal file
137
tests/library/test_model_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import pytest
|
||||
import torch
|
||||
from library.model_utils import AID
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@pytest.fixture
def input_tensor():
    """Mixed-sign probe tensor (negatives, zero, positives) with grad enabled."""
    return torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
||||
|
||||
def test_aid_forward_train_mode(input_tensor):
    """Training mode must be stochastic and preserve the input shape."""
    module = AID(p=0.9)
    module.train()

    # Collect several forward passes to observe the stochastic masking.
    outputs = [module(input_tensor).detach().clone() for _ in range(10)]

    # At least one pass should differ from the first.
    identical = all(torch.allclose(outputs[0], out) for out in outputs[1:])
    assert not identical, "All outputs are identical, expected variability in training mode"

    assert outputs[0].shape == input_tensor.shape
|
||||
|
||||
def test_aid_forward_eval_mode(input_tensor):
    """Eval mode must be deterministic and scale each sign half by its rate."""
    module = AID(p=0.9)
    module.eval()

    first = module(input_tensor)
    second = module(input_tensor)
    assert torch.allclose(first, second), "Expected deterministic behavior in eval mode"

    # Positive half scaled by p, negative half by 1 - p.
    expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
    assert torch.allclose(first, expected), "Incorrect evaluation mode transformation"
|
||||
|
||||
def test_aid_gradient_flow(input_tensor):
    """Gradients must propagate through AID's stochastic training-mode path."""
    aid = AID(p=0.9)
    aid.train()

    # Seed the RNG: with p=0.9 there is a small (~0.08%) chance the random
    # masks zero out every element, which would make the all-zero-gradient
    # assertion below fail spuriously. Seeding keeps the test deterministic.
    torch.manual_seed(0)

    # Forward pass
    output = aid(input_tensor)

    # The stochastic mask must not detach the graph.
    assert output.requires_grad, "Output lost gradient tracking"

    # Compute loss and backpropagate
    loss = output.sum()
    loss.backward()

    # Verify gradients were computed
    assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
    assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
|
||||
|
||||
def test_aid_extreme_p_values():
    """p=1.0 keeps only positives in eval mode; p=0.0 keeps only negatives."""
    probe = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
    only_positive = AID(p=1.0)
    only_positive.eval()

    result = only_positive(probe)
    assert torch.allclose(result, torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])), "Failed with p=1.0"

    probe = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
    only_negative = AID(p=0.0)
    only_negative.eval()

    result = only_negative(probe)
    assert torch.allclose(result, torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])), "Failed with p=0.0"
|
||||
|
||||
def test_aid_with_all_positive_values():
    """Backprop must work in training mode on an all-positive input."""
    probe = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
    module = AID(p=0.9)
    module.train()

    # Forward, then backprop through the stochastic mask.
    total = module(probe).sum()
    total.backward()
    assert probe.grad is not None, "No gradients were recorded for all-positive input"
|
||||
|
||||
def test_aid_with_all_negative_values():
    """Backprop must work in training mode on an all-negative input."""
    probe = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
    module = AID(p=0.9)
    module.train()

    # Forward, then backprop through the stochastic mask.
    total = module(probe).sum()
    total.backward()
    assert probe.grad is not None, "No gradients were recorded for all-negative input"
|
||||
|
||||
def test_aid_with_zero_values():
    """Zero input must map to zero output in both train and eval modes."""
    probe = torch.zeros(5, requires_grad=True)
    module = AID(p=0.9)

    module.train()
    result = module(probe)
    assert torch.allclose(result, torch.zeros_like(result)), "Expected zeros out for zero input"

    module.eval()
    result = module(probe)
    assert torch.allclose(result, torch.zeros_like(result)), "Expected zeros out for zero input"
|
||||
|
||||
def test_aid_integration_with_linear_layer():
    """AID must compose with nn.Linear and let gradients reach its weights."""
    linear = nn.Linear(5, 2)
    model = nn.Sequential(linear, AID(p=0.9))
    model.train()

    batch = torch.randn(3, 5, requires_grad=True)

    # Backprop through the whole stack.
    model(batch).sum().backward()

    assert linear.weight.grad is not None, "No gradients for linear layer weights"
    assert batch.grad is not None, "No gradients for input tensor"
|
||||
Reference in New Issue
Block a user