mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Compare commits
8 Commits
88a4442b7e
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
308a0cc9fc | ||
|
|
7e60e163c1 | ||
|
|
a8f5c222e0 | ||
|
|
1d588d6cb6 | ||
|
|
a7d35701a0 | ||
|
|
8da05a10dc | ||
|
|
197b129284 | ||
|
|
51435f1718 |
@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
|||||||
|
|
||||||
### 更新履歴
|
### 更新履歴
|
||||||
|
|
||||||
|
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
|
||||||
|
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
|
||||||
|
|
||||||
- **Version 0.10.3 (2026-04-02):**
|
- **Version 0.10.3 (2026-04-02):**
|
||||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
|
|||||||
|
|
||||||
### Change History
|
### Change History
|
||||||
|
|
||||||
|
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
|
||||||
|
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
|
||||||
|
|
||||||
- **Version 0.10.3 (2026-04-02):**
|
- **Version 0.10.3 (2026-04-02):**
|
||||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
has_ipex = True
|
has_ipex = True
|
||||||
@@ -8,7 +9,7 @@ except Exception:
|
|||||||
has_ipex = False
|
has_ipex = False
|
||||||
from .hijacks import ipex_hijacks
|
from .hijacks import ipex_hijacks
|
||||||
|
|
||||||
torch_version = float(torch.__version__[:3])
|
torch_version = version.parse(torch.__version__)
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.__path__ = torch.xpu.__path__
|
torch.cuda.__path__ = torch.xpu.__path__
|
||||||
torch.cuda.set_stream = torch.xpu.set_stream
|
torch.cuda.set_stream = torch.xpu.set_stream
|
||||||
torch.cuda.torch = torch.xpu.torch
|
torch.cuda.torch = torch.xpu.torch
|
||||||
torch.cuda.Union = torch.xpu.Union
|
|
||||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||||
torch.cuda.__package__ = torch.xpu.__package__
|
torch.cuda.__package__ = torch.xpu.__package__
|
||||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||||
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||||
torch.cuda.random = torch.xpu.random
|
torch.cuda.random = torch.xpu.random
|
||||||
torch.cuda._device = torch.xpu._device
|
|
||||||
torch.cuda.__name__ = torch.xpu.__name__
|
torch.cuda.__name__ = torch.xpu.__name__
|
||||||
torch.cuda._device_t = torch.xpu._device_t
|
|
||||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||||
torch.cuda.__file__ = torch.xpu.__file__
|
torch.cuda.__file__ = torch.xpu.__file__
|
||||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||||
|
|
||||||
if torch_version < 2.3:
|
if torch_version < version.parse("2.3"):
|
||||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||||
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.threading = torch.xpu.threading
|
torch.cuda.threading = torch.xpu.threading
|
||||||
torch.cuda.traceback = torch.xpu.traceback
|
torch.cuda.traceback = torch.xpu.traceback
|
||||||
|
|
||||||
if torch_version < 2.5:
|
if torch_version < version.parse("2.5"):
|
||||||
torch.cuda.os = torch.xpu.os
|
torch.cuda.os = torch.xpu.os
|
||||||
torch.cuda.Device = torch.xpu.Device
|
torch.cuda.Device = torch.xpu.Device
|
||||||
torch.cuda.warnings = torch.xpu.warnings
|
torch.cuda.warnings = torch.xpu.warnings
|
||||||
torch.cuda.classproperty = torch.xpu.classproperty
|
torch.cuda.classproperty = torch.xpu.classproperty
|
||||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||||
|
|
||||||
if torch_version < 2.7:
|
if torch_version < version.parse("2.7"):
|
||||||
torch.cuda.Tuple = torch.xpu.Tuple
|
torch.cuda.Tuple = torch.xpu.Tuple
|
||||||
torch.cuda.List = torch.xpu.List
|
torch.cuda.List = torch.xpu.List
|
||||||
|
|
||||||
|
if torch_version < version.parse("2.11"):
|
||||||
|
torch.cuda._device_t = torch.xpu._device_t
|
||||||
|
torch.cuda._device = torch.xpu._device
|
||||||
|
torch.cuda.Union = torch.xpu.Union
|
||||||
|
|
||||||
|
|
||||||
# Memory:
|
# Memory:
|
||||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||||
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||||
|
|
||||||
# C
|
# C
|
||||||
if torch_version < 2.3:
|
if torch_version < version.parse("2.3"):
|
||||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||||
ipex._C._DeviceProperties.major = 12
|
ipex._C._DeviceProperties.major = 12
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn, Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class AID(nn.Module):
|
|
||||||
def __init__(self, p=0.9):
|
|
||||||
super().__init__()
|
|
||||||
self.p = p
|
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
|
||||||
if self.training:
|
|
||||||
pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
|
|
||||||
neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
|
|
||||||
return x * (pos_mask + neg_mask)
|
|
||||||
else:
|
|
||||||
pos_part = (x >= 0) * x * self.p
|
|
||||||
neg_part = (x < 0) * x * (1 - self.p)
|
|
||||||
return pos_part + neg_part
|
|
||||||
@@ -17,7 +17,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import re
|
import re
|
||||||
from library.model_utils import AID
|
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||||
|
|
||||||
@@ -31,21 +30,6 @@ NUM_DOUBLE_BLOCKS = 19
|
|||||||
NUM_SINGLE_BLOCKS = 38
|
NUM_SINGLE_BLOCKS = 38
|
||||||
|
|
||||||
|
|
||||||
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
|
|
||||||
# Normalize the position to 0-1 range
|
|
||||||
normalized_pos = block_id / total_blocks
|
|
||||||
|
|
||||||
# Shift the sine curve to only use the first 3/4 of the cycle
|
|
||||||
# This gives us: start at 0, peak in the middle, end around 0.7
|
|
||||||
phase_shift = shift * math.pi
|
|
||||||
sine_value = math.sin(normalized_pos * phase_shift)
|
|
||||||
|
|
||||||
# Scale to our desired peak of 0.9
|
|
||||||
result = peak * sine_value
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
@@ -61,7 +45,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
dropout=None,
|
dropout=None,
|
||||||
rank_dropout=None,
|
rank_dropout=None,
|
||||||
module_dropout=None,
|
module_dropout=None,
|
||||||
aid_dropout=None,
|
|
||||||
split_dims: Optional[List[int]] = None,
|
split_dims: Optional[List[int]] = None,
|
||||||
ggpo_beta: Optional[float] = None,
|
ggpo_beta: Optional[float] = None,
|
||||||
ggpo_sigma: Optional[float] = None,
|
ggpo_sigma: Optional[float] = None,
|
||||||
@@ -124,13 +107,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
|
||||||
self.aid = (
|
|
||||||
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
|
|
||||||
) # AID activation
|
|
||||||
|
|
||||||
if aid_dropout is not None:
|
|
||||||
self.register_buffer("aid_p", torch.tensor(aid_dropout))
|
|
||||||
|
|
||||||
self.ggpo_sigma = ggpo_sigma
|
self.ggpo_sigma = ggpo_sigma
|
||||||
self.ggpo_beta = ggpo_beta
|
self.ggpo_beta = ggpo_beta
|
||||||
|
|
||||||
@@ -182,9 +158,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
lx = self.lora_up(lx)
|
lx = self.lora_up(lx)
|
||||||
|
|
||||||
# Activation by Interval-wise Dropout
|
|
||||||
lx = self.aid(lx)
|
|
||||||
|
|
||||||
# LoRA Gradient-Guided Perturbation Optimization
|
# LoRA Gradient-Guided Perturbation Optimization
|
||||||
if (
|
if (
|
||||||
self.training
|
self.training
|
||||||
@@ -581,9 +554,6 @@ def create_network(
|
|||||||
module_dropout = kwargs.get("module_dropout", None)
|
module_dropout = kwargs.get("module_dropout", None)
|
||||||
if module_dropout is not None:
|
if module_dropout is not None:
|
||||||
module_dropout = float(module_dropout)
|
module_dropout = float(module_dropout)
|
||||||
aid_dropout = kwargs.get("aid_dropout", None)
|
|
||||||
if aid_dropout is not None:
|
|
||||||
aid_dropout = float(aid_dropout)
|
|
||||||
|
|
||||||
# single or double blocks
|
# single or double blocks
|
||||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||||
@@ -660,7 +630,6 @@ def create_network(
|
|||||||
dropout=neuron_dropout,
|
dropout=neuron_dropout,
|
||||||
rank_dropout=rank_dropout,
|
rank_dropout=rank_dropout,
|
||||||
module_dropout=module_dropout,
|
module_dropout=module_dropout,
|
||||||
aid_dropout=aid_dropout,
|
|
||||||
conv_lora_dim=conv_dim,
|
conv_lora_dim=conv_dim,
|
||||||
conv_alpha=conv_alpha,
|
conv_alpha=conv_alpha,
|
||||||
train_blocks=train_blocks,
|
train_blocks=train_blocks,
|
||||||
@@ -761,7 +730,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dropout: Optional[float] = None,
|
dropout: Optional[float] = None,
|
||||||
rank_dropout: Optional[float] = None,
|
rank_dropout: Optional[float] = None,
|
||||||
module_dropout: Optional[float] = None,
|
module_dropout: Optional[float] = None,
|
||||||
aid_dropout: Optional[float] = None,
|
|
||||||
conv_lora_dim: Optional[int] = None,
|
conv_lora_dim: Optional[int] = None,
|
||||||
conv_alpha: Optional[float] = None,
|
conv_alpha: Optional[float] = None,
|
||||||
module_class: Type[object] = LoRAModule,
|
module_class: Type[object] = LoRAModule,
|
||||||
@@ -790,7 +758,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
self.aid_dropout = aid_dropout
|
|
||||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||||
self.split_qkv = split_qkv
|
self.split_qkv = split_qkv
|
||||||
self.train_t5xxl = train_t5xxl
|
self.train_t5xxl = train_t5xxl
|
||||||
@@ -867,7 +834,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
aid_dropout_p = None
|
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
# モジュール指定あり
|
# モジュール指定あり
|
||||||
@@ -903,21 +869,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||||
dim = d # may be 0 for skip
|
dim = d # may be 0 for skip
|
||||||
break
|
break
|
||||||
is_double = False
|
|
||||||
if "double" in lora_name:
|
|
||||||
is_double = True
|
|
||||||
is_single = False
|
|
||||||
if "single" in lora_name:
|
|
||||||
is_single = True
|
|
||||||
block_index = None
|
|
||||||
if is_flux and dim and (is_double or is_single):
|
|
||||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
|
||||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
|
||||||
|
|
||||||
if block_index is not None and aid_dropout is not None:
|
|
||||||
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
|
|
||||||
aid_dropout_p = get_point_on_curve(
|
|
||||||
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
is_flux
|
is_flux
|
||||||
@@ -928,6 +879,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
and ("double" in lora_name or "single" in lora_name)
|
and ("double" in lora_name or "single" in lora_name)
|
||||||
):
|
):
|
||||||
|
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||||
|
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||||
if (
|
if (
|
||||||
"double" in lora_name
|
"double" in lora_name
|
||||||
and self.train_double_block_indices is not None
|
and self.train_double_block_indices is not None
|
||||||
@@ -968,7 +921,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
rank_dropout=rank_dropout,
|
rank_dropout=rank_dropout,
|
||||||
module_dropout=module_dropout,
|
module_dropout=module_dropout,
|
||||||
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
|
|
||||||
split_dims=split_dims,
|
split_dims=split_dims,
|
||||||
ggpo_beta=ggpo_beta,
|
ggpo_beta=ggpo_beta,
|
||||||
ggpo_sigma=ggpo_sigma,
|
ggpo_sigma=ggpo_sigma,
|
||||||
@@ -1165,7 +1117,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
|
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
|
||||||
|
|
||||||
alpha = state_dict.pop(f"{lora_name}.alpha")
|
alpha = state_dict.pop(f"{lora_name}.alpha")
|
||||||
aid_p = state_dict.pop(f"{lora_name}.aid_p")
|
|
||||||
|
|
||||||
# merge down weight
|
# merge down weight
|
||||||
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
||||||
@@ -1181,7 +1132,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
|
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
|
||||||
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
|
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
|
||||||
new_state_dict[f"{lora_name}.alpha"] = alpha
|
new_state_dict[f"{lora_name}.alpha"] = alpha
|
||||||
new_state_dict[f"{lora_name}.aid_p"] = aid_p
|
|
||||||
|
|
||||||
# print(
|
# print(
|
||||||
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
|
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
|
||||||
|
|||||||
@@ -1,137 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from library.model_utils import AID
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def input_tensor():
|
|
||||||
# Create a tensor with positive and negative values
|
|
||||||
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def test_aid_forward_train_mode(input_tensor):
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
aid.train()
|
|
||||||
|
|
||||||
# Run several forward passes to test stochastic behavior
|
|
||||||
results = []
|
|
||||||
for _ in range(10):
|
|
||||||
output = aid(input_tensor)
|
|
||||||
results.append(output.detach().clone())
|
|
||||||
|
|
||||||
# Test that outputs vary (stochastic behavior)
|
|
||||||
all_equal = all(torch.allclose(results[0], results[i]) for i in range(1, 10))
|
|
||||||
assert not all_equal, "All outputs are identical, expected variability in training mode"
|
|
||||||
|
|
||||||
# Test shape preservation
|
|
||||||
assert results[0].shape == input_tensor.shape
|
|
||||||
|
|
||||||
def test_aid_forward_eval_mode(input_tensor):
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
aid.eval()
|
|
||||||
|
|
||||||
output = aid(input_tensor)
|
|
||||||
|
|
||||||
# Test deterministic behavior
|
|
||||||
output2 = aid(input_tensor)
|
|
||||||
assert torch.allclose(output, output2), "Expected deterministic behavior in eval mode"
|
|
||||||
|
|
||||||
# Test correct transformation
|
|
||||||
expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
|
|
||||||
assert torch.allclose(output, expected), "Incorrect evaluation mode transformation"
|
|
||||||
|
|
||||||
def test_aid_gradient_flow(input_tensor):
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
aid.train()
|
|
||||||
|
|
||||||
# Forward pass
|
|
||||||
output = aid(input_tensor)
|
|
||||||
|
|
||||||
# Check gradient flow
|
|
||||||
assert output.requires_grad, "Output lost gradient tracking"
|
|
||||||
|
|
||||||
# Compute loss and backpropagate
|
|
||||||
loss = output.sum()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# Verify gradients were computed
|
|
||||||
assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
|
|
||||||
assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
|
|
||||||
|
|
||||||
def test_aid_extreme_p_values():
|
|
||||||
# Test with p=1.0 (only positive values pass through)
|
|
||||||
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
|
||||||
aid = AID(p=1.0)
|
|
||||||
aid.eval()
|
|
||||||
|
|
||||||
output = aid(x)
|
|
||||||
expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
|
|
||||||
assert torch.allclose(output, expected), "Failed with p=1.0"
|
|
||||||
|
|
||||||
# Test with p=0.0 (only negative values pass through)
|
|
||||||
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
|
|
||||||
aid = AID(p=0.0)
|
|
||||||
aid.eval()
|
|
||||||
|
|
||||||
output = aid(x)
|
|
||||||
expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])
|
|
||||||
assert torch.allclose(output, expected), "Failed with p=0.0"
|
|
||||||
|
|
||||||
def test_aid_with_all_positive_values():
|
|
||||||
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
aid.train()
|
|
||||||
|
|
||||||
# Run forward passes and check that only positive values are affected
|
|
||||||
output = aid(x)
|
|
||||||
|
|
||||||
# Backprop should work
|
|
||||||
loss = output.sum()
|
|
||||||
loss.backward()
|
|
||||||
assert x.grad is not None, "No gradients were recorded for all-positive input"
|
|
||||||
|
|
||||||
def test_aid_with_all_negative_values():
|
|
||||||
x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
aid.train()
|
|
||||||
|
|
||||||
# Run forward passes and check that only negative values are affected
|
|
||||||
output = aid(x)
|
|
||||||
|
|
||||||
# Backprop should work
|
|
||||||
loss = output.sum()
|
|
||||||
loss.backward()
|
|
||||||
assert x.grad is not None, "No gradients were recorded for all-negative input"
|
|
||||||
|
|
||||||
def test_aid_with_zero_values():
|
|
||||||
x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True)
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
|
|
||||||
# Test training mode
|
|
||||||
aid.train()
|
|
||||||
output = aid(x)
|
|
||||||
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
|
|
||||||
|
|
||||||
# Test eval mode
|
|
||||||
aid.eval()
|
|
||||||
output = aid(x)
|
|
||||||
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
|
|
||||||
|
|
||||||
def test_aid_integration_with_linear_layer():
|
|
||||||
# Test AID's compatibility with a linear layer
|
|
||||||
linear = nn.Linear(5, 2)
|
|
||||||
aid = AID(p=0.9)
|
|
||||||
|
|
||||||
model = nn.Sequential(linear, aid)
|
|
||||||
model.train()
|
|
||||||
|
|
||||||
x = torch.randn(3, 5, requires_grad=True)
|
|
||||||
output = model(x)
|
|
||||||
|
|
||||||
# Check that gradients flow through the whole model
|
|
||||||
loss = output.sum()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
assert linear.weight.grad is not None, "No gradients for linear layer weights"
|
|
||||||
assert x.grad is not None, "No gradients for input tensor"
|
|
||||||
Reference in New Issue
Block a user