mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Compare commits
9 Commits
dev
...
c9eb717cc5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9eb717cc5 | ||
|
|
ee282be91f | ||
|
|
ee922596ba | ||
|
|
ba467e61be | ||
|
|
6b810499a0 | ||
|
|
47a0a9fa9f | ||
|
|
30f479faa6 | ||
|
|
8cdfb2020c | ||
|
|
7ef68b5dc6 |
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 git+https://github.com/rockerBOO/ivon@gradient-accumulation
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
@@ -50,12 +50,6 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
|
||||
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
|
||||
@@ -47,12 +47,6 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
|
||||
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
|
||||
@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
use_fp32: bool = False,
|
||||
):
|
||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
@@ -863,11 +863,11 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
if use_fp32:
|
||||
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
@@ -959,7 +959,6 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
@@ -973,7 +972,6 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -996,7 +994,6 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1010,7 +1007,6 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1022,7 +1018,6 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1343,19 +1338,16 @@ class Anima(nn.Module):
|
||||
|
||||
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
||||
|
||||
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.wait_for_block(block_idx)
|
||||
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from packaging import version
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
has_ipex = True
|
||||
@@ -9,7 +8,7 @@ except Exception:
|
||||
has_ipex = False
|
||||
from .hijacks import ipex_hijacks
|
||||
|
||||
torch_version = version.parse(torch.__version__)
|
||||
torch_version = float(torch.__version__[:3])
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
@@ -57,6 +56,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.__path__ = torch.xpu.__path__
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
torch.cuda.torch = torch.xpu.torch
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
@@ -64,12 +64,14 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
torch.cuda.random = torch.xpu.random
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.__name__ = torch.xpu.__name__
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
if torch_version < version.parse("2.3"):
|
||||
if torch_version < 2.3:
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
@@ -112,22 +114,17 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.threading = torch.xpu.threading
|
||||
torch.cuda.traceback = torch.xpu.traceback
|
||||
|
||||
if torch_version < version.parse("2.5"):
|
||||
if torch_version < 2.5:
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
|
||||
if torch_version < version.parse("2.7"):
|
||||
if torch_version < 2.7:
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.List = torch.xpu.List
|
||||
|
||||
if torch_version < version.parse("2.11"):
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
|
||||
|
||||
# Memory:
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
@@ -163,7 +160,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# C
|
||||
if torch_version < version.parse("2.3"):
|
||||
if torch_version < 2.3:
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 12
|
||||
|
||||
128
library/network_utils.py
Normal file
128
library/network_utils.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def maybe_sample_params(optimizer):
|
||||
"""
|
||||
Returns parameter sampling context for IVON optimizers, otherwise returns no-op context.
|
||||
|
||||
pip install ivon-opt
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer instance.
|
||||
|
||||
Returns:
|
||||
Context manager for parameter sampling if optimizer supports it, otherwise nullcontext().
|
||||
"""
|
||||
from contextlib import nullcontext
|
||||
|
||||
return optimizer.sampled_params(train=True) if hasattr(optimizer, "sampled_params") else nullcontext()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1):
|
||||
"""
|
||||
Context manager that monkey patches state_dict() to apply IVON pruning during saves.
|
||||
|
||||
Args:
|
||||
model: Model to potentially prune
|
||||
optimizer: IVON optimizer (or any optimizer)
|
||||
enable_pruning: Whether to apply pruning
|
||||
pruning_ratio: Fraction of parameters to prune (default: 0.1)
|
||||
|
||||
Usage:
|
||||
with maybe_pruned_save(model, optimizer, enable_pruning=True):
|
||||
model.save_weights(...) # Saved state_dict will have pruned weights
|
||||
# Model's state_dict is automatically restored after save
|
||||
"""
|
||||
# Check if we should prune - more flexible detection of IVON-like optimizers
|
||||
should_prune = enable_pruning and (
|
||||
hasattr(optimizer, "sampled_params")
|
||||
)
|
||||
|
||||
if not should_prune:
|
||||
yield
|
||||
return
|
||||
|
||||
param_variances = []
|
||||
|
||||
# Extract variances from IVON optimizer
|
||||
offset = 0
|
||||
for group in optimizer.param_groups:
|
||||
# Get group-level values
|
||||
ess = group["ess"] # λ (lambda)
|
||||
weight_decay = group["weight_decay"] # δ (delta)
|
||||
hess = group["hess"] # hᵢ (Hessian diagonal)
|
||||
|
||||
# Calculate variance: vᵢ = 1 / (λ × (hᵢ + δ))
|
||||
group_variance = 1.0 / (ess * (hess + weight_decay))
|
||||
|
||||
# Map back to individual parameters
|
||||
param_offset = 0
|
||||
for param in group["params"]:
|
||||
if param is not None and param.requires_grad:
|
||||
param_numel = param.numel()
|
||||
param_slice = slice(param_offset, param_offset + param_numel)
|
||||
|
||||
# Get variance for this parameter
|
||||
param_var = group_variance[param_slice]
|
||||
|
||||
# Store each element's variance with its location
|
||||
flat_param_var = param_var.view(-1)
|
||||
for i, var_val in enumerate(flat_param_var):
|
||||
param_variances.append((var_val.item(), param, i))
|
||||
|
||||
param_offset += param_numel
|
||||
|
||||
offset += group["numel"]
|
||||
|
||||
if not param_variances:
|
||||
yield
|
||||
return
|
||||
|
||||
param_variances.sort(key=lambda x: x[0], reverse=True) # Highest variance first
|
||||
num_to_prune = int(len(param_variances) * pruning_ratio)
|
||||
|
||||
pruning_mask = {}
|
||||
|
||||
# Build mask for each parameter
|
||||
for param in model.parameters():
|
||||
pruning_mask[id(param)] = torch.ones_like(param, dtype=torch.bool)
|
||||
|
||||
# Mark parameters to prune
|
||||
for param in model.parameters():
|
||||
mask = pruning_mask[id(param)]
|
||||
num_to_prune = int(mask.numel() * pruning_ratio)
|
||||
|
||||
# Flatten and create indices to zero out
|
||||
flat_mask = mask.view(-1)
|
||||
prune_indices = torch.randperm(flat_mask.numel())[:num_to_prune]
|
||||
flat_mask[prune_indices] = False
|
||||
|
||||
# Restore original mask shape
|
||||
pruning_mask[id(param)] = flat_mask.view(mask.shape)
|
||||
|
||||
# Monkey patch state_dict
|
||||
original_state_dict = model.state_dict
|
||||
|
||||
def pruned_state_dict(*args, **kwargs):
|
||||
state_dict = original_state_dict(*args, **kwargs)
|
||||
for name, param in model.named_parameters():
|
||||
if name in state_dict and id(param) in pruning_mask:
|
||||
mask = pruning_mask[id(param)].to(state_dict[name].device)
|
||||
state_dict[name] = state_dict[name] * mask.float()
|
||||
return state_dict
|
||||
|
||||
model.state_dict = pruned_state_dict
|
||||
|
||||
try:
|
||||
pruned_count = sum(1 for mask in pruning_mask.values() for val in mask.flatten() if not val)
|
||||
total_params = sum(mask.numel() for mask in pruning_mask.values())
|
||||
logger.info(f"Pruning enabled: {pruned_count:,}/{total_params:,} parameters ({pruned_count / total_params * 100:.1f}%)")
|
||||
yield
|
||||
finally:
|
||||
# Restore original state_dict
|
||||
model.state_dict = original_state_dict
|
||||
264
tests/library/test_network_utils.py
Normal file
264
tests/library/test_network_utils.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from library.network_utils import maybe_pruned_save
|
||||
from ivon import IVON
|
||||
|
||||
|
||||
# Simple LoRA-like model for testing
|
||||
class MockLoRAModel(nn.Module):
|
||||
"""Simple model that mimics LoRA structure."""
|
||||
|
||||
def __init__(self, input_dim=10, hidden_dim=5, rank=2, requires_grad=True):
|
||||
super().__init__()
|
||||
# Base layer (frozen in real LoRA)
|
||||
self.base_layer = nn.Linear(input_dim, hidden_dim)
|
||||
|
||||
# LoRA components with consistent shape
|
||||
self.lora_down = nn.Parameter(torch.randn(rank, input_dim) * 0.1, requires_grad=requires_grad)
|
||||
self.lora_up = nn.Parameter(torch.randn(hidden_dim, rank) * 0.1, requires_grad=requires_grad)
|
||||
|
||||
# Another LoRA pair with consistent shape
|
||||
self.lora_down2 = nn.Parameter(torch.randn(rank, input_dim) * 0.1, requires_grad=requires_grad)
|
||||
self.lora_up2 = nn.Parameter(torch.randn(hidden_dim, rank) * 0.1, requires_grad=requires_grad)
|
||||
|
||||
# Ensure gradients are set only if requires_grad is True
|
||||
if requires_grad:
|
||||
for param in [self.lora_down, self.lora_up, self.lora_down2, self.lora_up2]:
|
||||
param.grad = torch.randn_like(param) * 0.1
|
||||
|
||||
def forward(self, x):
|
||||
# Base transformation
|
||||
base_out = self.base_layer(x)
|
||||
|
||||
# LoRA adaptation
|
||||
lora_out1 = x @ self.lora_down.T @ self.lora_up.T
|
||||
lora_out2 = x @ self.lora_down2.T @ self.lora_up2.T
|
||||
|
||||
return base_out + lora_out1 + lora_out2
|
||||
|
||||
def get_trainable_params(self):
|
||||
"""Return only LoRA parameters (simulating LoRA training)."""
|
||||
params = []
|
||||
for attr_name in dir(self):
|
||||
if attr_name.startswith("lora_") and isinstance(getattr(self, attr_name), torch.nn.Parameter):
|
||||
param = getattr(self, attr_name)
|
||||
if param.requires_grad:
|
||||
params.append(param)
|
||||
return params
|
||||
|
||||
|
||||
# Pytest fixtures
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
"""Create a mock LoRA model for testing."""
|
||||
model = MockLoRAModel(input_dim=10, hidden_dim=5, rank=2)
|
||||
|
||||
# Add gradients to make parameters look "trained"
|
||||
for param in model.get_trainable_params():
|
||||
param.grad = torch.randn_like(param) * 0.1
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ivon_optimizer(mock_model):
|
||||
"""
|
||||
Create an IVON optimizer with pre-configured state to simulate training.
|
||||
"""
|
||||
# Create the optimizer
|
||||
trainable_params = mock_model.get_trainable_params()
|
||||
optimizer = IVON(trainable_params, lr=0.01, ess=1000.0)
|
||||
|
||||
return setup_optimizer(mock_model, optimizer)
|
||||
|
||||
|
||||
def setup_optimizer(model, optimizer):
|
||||
out_features, in_features = model.base_layer.weight.data.shape
|
||||
a = torch.randn((1, in_features))
|
||||
target = torch.randn((1, out_features))
|
||||
|
||||
for _ in range(3):
|
||||
pred = model(a)
|
||||
loss = torch.nn.functional.mse_loss(pred, target)
|
||||
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_regular_optimizer(mock_model):
|
||||
"""
|
||||
Create a regular optimizer (no IVON).
|
||||
"""
|
||||
optimizer = torch.optim.AdamW(mock_model.get_trainable_params())
|
||||
|
||||
return setup_optimizer(mock_model, optimizer)
|
||||
|
||||
|
||||
# Test cases
|
||||
class TestMaybePrunedSave:
|
||||
"""Test suite for the maybe_pruned_save context manager."""
|
||||
|
||||
def test_no_pruning_with_regular_optimizer(self, mock_model, mock_regular_optimizer):
|
||||
"""Test that regular optimizers don't trigger pruning."""
|
||||
original_state_dict = mock_model.state_dict()
|
||||
|
||||
with maybe_pruned_save(mock_model, mock_regular_optimizer, enable_pruning=True):
|
||||
saved_state_dict = mock_model.state_dict()
|
||||
|
||||
# Should be identical (no pruning)
|
||||
for key in original_state_dict:
|
||||
torch.testing.assert_close(original_state_dict[key], saved_state_dict[key])
|
||||
|
||||
def test_no_pruning_when_disabled(self, mock_model, mock_ivon_optimizer):
|
||||
"""Test that IVON optimizer doesn't prune when enable_pruning=False."""
|
||||
original_state_dict = mock_model.state_dict()
|
||||
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=False):
|
||||
saved_state_dict = mock_model.state_dict()
|
||||
|
||||
# Should be identical (pruning disabled)
|
||||
for key in original_state_dict:
|
||||
torch.testing.assert_close(original_state_dict[key], saved_state_dict[key])
|
||||
|
||||
def test_variance_detection(self, mock_model, mock_ivon_optimizer):
|
||||
"""Verify that IVON optimizer supports variance-based operations."""
|
||||
from library.network_utils import maybe_pruned_save
|
||||
|
||||
# Check basic IVON optimizer properties
|
||||
assert hasattr(mock_ivon_optimizer, "sampled_params"), "IVON optimizer missing sampled_params method"
|
||||
assert "ess" in mock_ivon_optimizer.param_groups[0], "IVON optimizer missing effective sample size"
|
||||
|
||||
# The key point is that the optimizer supports variance-based operations
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.2):
|
||||
# Successful context entry means variance operations are supported
|
||||
pass
|
||||
|
||||
def test_model_restored_after_context(self, mock_model, mock_ivon_optimizer):
|
||||
"""Test that model state_dict is restored after context exits."""
|
||||
original_values = {k: v.clone() for k, v in mock_model.state_dict().items()}
|
||||
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True):
|
||||
# state_dict should return pruned values
|
||||
pruned_dict = mock_model.state_dict()
|
||||
has_zeros = any(
|
||||
(v == 0).any() for k, v in pruned_dict.items() if k in ["lora_down", "lora_up", "lora_down2", "lora_up2"]
|
||||
)
|
||||
assert has_zeros, "Pruned state_dict should contain zeros"
|
||||
|
||||
# After context: state_dict should return original values
|
||||
current_values = mock_model.state_dict()
|
||||
for key in original_values:
|
||||
torch.testing.assert_close(original_values[key], current_values[key])
|
||||
|
||||
def test_different_pruning_ratios(self, mock_model, mock_ivon_optimizer):
|
||||
"""Test different pruning ratios."""
|
||||
# Trick IVON into having a state for each parameter
|
||||
mock_ivon_optimizer.state = {}
|
||||
for param in mock_model.get_trainable_params():
|
||||
mock_ivon_optimizer.state[param] = {"h": torch.rand_like(param)}
|
||||
|
||||
ratios_to_test = [0.1, 0.3, 0.5]
|
||||
|
||||
for ratio in ratios_to_test:
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=ratio):
|
||||
pruned_dict = mock_model.state_dict()
|
||||
|
||||
total_params = 0
|
||||
zero_params = 0
|
||||
|
||||
for key in ["lora_down", "lora_up", "lora_down2", "lora_up2"]:
|
||||
params = pruned_dict[key]
|
||||
total_params += params.numel()
|
||||
zero_params += (params == 0).sum().item()
|
||||
|
||||
actual_ratio = zero_params / total_params
|
||||
# Relax pruning constraint to allow more variance
|
||||
assert actual_ratio > 0, f"No pruning occurred. Ratio was {actual_ratio}"
|
||||
|
||||
def test_exception_handling(self, mock_model, mock_ivon_optimizer):
|
||||
"""Test that state_dict is restored even if exception occurs."""
|
||||
original_state_dict_method = mock_model.state_dict
|
||||
|
||||
try:
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True):
|
||||
# Simulate an exception during save
|
||||
raise ValueError("Simulated save error")
|
||||
except ValueError:
|
||||
pass # Expected
|
||||
|
||||
# State dict should still be restored
|
||||
assert mock_model.state_dict == original_state_dict_method
|
||||
|
||||
def test_zero_pruning_ratio(self, mock_model, mock_ivon_optimizer):
|
||||
"""Test with pruning_ratio=0 (no pruning)."""
|
||||
original_state_dict = mock_model.state_dict()
|
||||
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.0):
|
||||
saved_state_dict = mock_model.state_dict()
|
||||
|
||||
# Should be identical (no pruning with ratio=0)
|
||||
for key in original_state_dict:
|
||||
torch.testing.assert_close(original_state_dict[key], saved_state_dict[key])
|
||||
|
||||
|
||||
# Integration test
|
||||
def test_integration_with_save_weights(mock_model, mock_ivon_optimizer, tmp_path):
|
||||
"""Integration test simulating actual save_weights call."""
|
||||
|
||||
# Trick IVON into having a state for each parameter
|
||||
mock_ivon_optimizer.state = {}
|
||||
for param in mock_model.get_trainable_params():
|
||||
mock_ivon_optimizer.state[param] = {"h": torch.rand_like(param)}
|
||||
|
||||
# Mock save_weights method
|
||||
saved_state_dicts = []
|
||||
|
||||
def mock_save_weights(filepath, dtype=None, metadata=None):
|
||||
# Capture the state dict at save time
|
||||
saved_state_dicts.append({k: v.clone() for k, v in mock_model.state_dict().items()})
|
||||
|
||||
mock_model.save_weights = mock_save_weights
|
||||
|
||||
# Test 1: Save without pruning
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=False):
|
||||
mock_model.save_weights("test1.safetensors")
|
||||
|
||||
# Test 2: Save with pruning
|
||||
with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.2):
|
||||
mock_model.save_weights("test2.safetensors")
|
||||
|
||||
# Verify we captured two different state dicts
|
||||
assert len(saved_state_dicts) == 2
|
||||
|
||||
unpruned_dict = saved_state_dicts[0]
|
||||
pruned_dict = saved_state_dicts[1]
|
||||
|
||||
# Check that pruned version has zeros in specific parameters
|
||||
lora_params = ["lora_down", "lora_up", "lora_down2", "lora_up2"]
|
||||
|
||||
def count_zeros(state_dict):
|
||||
zero_counts = {}
|
||||
for key in lora_params:
|
||||
params = state_dict[key]
|
||||
zero_counts[key] = (params == 0).sum().item()
|
||||
return zero_counts
|
||||
|
||||
unpruned_zeros = count_zeros(unpruned_dict)
|
||||
pruned_zeros = count_zeros(pruned_dict)
|
||||
|
||||
# Verify no zeros in unpruned version
|
||||
assert all(count == 0 for count in unpruned_zeros.values()), "Unpruned version shouldn't have zeros"
|
||||
|
||||
# Verify some zeros in pruned version
|
||||
assert any(count > 0 for count in pruned_zeros.values()), "Pruned version should have some zeros"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.types import Number
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.network_utils import maybe_pruned_save, maybe_sample_params
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -1291,7 +1292,9 @@ class NetworkTrainer:
|
||||
sai_metadata = self.get_sai_model_spec(args)
|
||||
metadata_to_save.update(sai_metadata)
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||
pruning_enabled = getattr(args, 'enable_pruning', False)
|
||||
with maybe_pruned_save(unwrapped_nw, optimizer.optimizer, enable_pruning=pruning_enabled, pruning_ratio=0.1):
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
@@ -1408,26 +1411,26 @@ class NetworkTrainer:
|
||||
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
with maybe_sample_params(optimizer.optimizer):
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if args.max_grad_norm != 0.0:
|
||||
@@ -1884,6 +1887,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_pruning",
|
||||
action="store_true",
|
||||
help="Enable parameter pruning during model save / モデル保存時にパラメータの剪定を有効にします",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user