diff --git a/library/model_utils.py b/library/model_utils.py new file mode 100644 index 00000000..87c2d04a --- /dev/null +++ b/library/model_utils.py @@ -0,0 +1,18 @@ +import torch +from torch import nn, Tensor + + +class AID(nn.Module): + def __init__(self, p=0.9): + super().__init__() + self.p = p + + def forward(self, x: Tensor): + if self.training: + pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p) + neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) + return x * (pos_mask + neg_mask) + else: + pos_part = (x >= 0) * x * self.p + neg_part = (x < 0) * x * (1 - self.p) + return pos_part + neg_part diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 947733fe..875f4001 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -17,6 +17,7 @@ import numpy as np import torch from torch import Tensor import re +from library.model_utils import AID from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -30,6 +31,21 @@ NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 +def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75): + # Normalize the position to 0-1 range + normalized_pos = block_id / total_blocks + + # Shift the sine curve to only use the first 3/4 of the cycle + # This gives us: start at 0, peak in the middle, end around 0.7 + phase_shift = shift * math.pi + sine_value = math.sin(normalized_pos * phase_shift) + + # Scale to our desired peak of 0.9 + result = peak * sine_value + + return result + + class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. 
@@ -45,6 +61,7 @@ class LoRAModule(torch.nn.Module): dropout=None, rank_dropout=None, module_dropout=None, + aid_dropout=None, split_dims: Optional[List[int]] = None, ggpo_beta: Optional[float] = None, ggpo_sigma: Optional[float] = None, @@ -107,6 +124,13 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.aid = ( + AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity() + ) # AID activation + + if aid_dropout is not None: + self.register_buffer("aid_p", torch.tensor(aid_dropout)) + self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta @@ -158,6 +182,9 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) + # Activation by Interval-wise Dropout + lx = self.aid(lx) + # LoRA Gradient-Guided Perturbation Optimization if ( self.training @@ -554,6 +581,9 @@ def create_network( module_dropout = kwargs.get("module_dropout", None) if module_dropout is not None: module_dropout = float(module_dropout) + aid_dropout = kwargs.get("aid_dropout", None) + if aid_dropout is not None: + aid_dropout = float(aid_dropout) # single or double blocks train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" @@ -630,6 +660,7 @@ def create_network( dropout=neuron_dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + aid_dropout=aid_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, train_blocks=train_blocks, @@ -730,6 +761,7 @@ class LoRANetwork(torch.nn.Module): dropout: Optional[float] = None, rank_dropout: Optional[float] = None, module_dropout: Optional[float] = None, + aid_dropout: Optional[float] = None, conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, module_class: Type[object] = LoRAModule, @@ -758,6 +790,7 @@ class LoRANetwork(torch.nn.Module): self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.aid_dropout = aid_dropout self.train_blocks = train_blocks if 
train_blocks is not None else "all" self.split_qkv = split_qkv self.train_t5xxl = train_t5xxl @@ -834,6 +867,7 @@ class LoRANetwork(torch.nn.Module): dim = None alpha = None + aid_dropout_p = None if modules_dim is not None: # モジュール指定あり @@ -869,6 +903,21 @@ class LoRANetwork(torch.nn.Module): if d is not None and all([id in lora_name for id in identifier[i]]): dim = d # may be 0 for skip break + is_double = False + if "double" in lora_name: + is_double = True + is_single = False + if "single" in lora_name: + is_single = True + block_index = None + if is_flux and (is_double or is_single): # no `dim` guard: the train_*_block_indices lookup below needs block_index even when dim is 0 (skip) or None + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + + if block_index is not None and aid_dropout is not None: + all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS + aid_dropout_p = get_point_on_curve( + all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout) if ( is_flux @@ -879,8 +928,6 @@ class LoRANetwork(torch.nn.Module): ) and ("double" in lora_name or "single" in lora_name) ): - # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." 
- block_index = int(lora_name.split("_")[4]) # bit dirty if ( "double" in lora_name and self.train_double_block_indices is not None @@ -921,6 +968,7 @@ class LoRANetwork(torch.nn.Module): dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout, split_dims=split_dims, ggpo_beta=ggpo_beta, ggpo_sigma=ggpo_sigma, @@ -1117,6 +1165,7 @@ class LoRANetwork(torch.nn.Module): up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] alpha = state_dict.pop(f"{lora_name}.alpha") + aid_p = state_dict.pop(f"{lora_name}.aid_p", None) # absent for LoRAs trained without AID # merge down weight down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) @@ -1132,6 +1181,7 @@ class LoRANetwork(torch.nn.Module): new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight new_state_dict[f"{lora_name}.alpha"] = alpha + if aid_p is not None: new_state_dict[f"{lora_name}.aid_p"] = aid_p # skip key when AID was not used # print( # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" diff --git a/tests/library/test_model_utils.py b/tests/library/test_model_utils.py new file mode 100644 index 00000000..c432266e --- /dev/null +++ b/tests/library/test_model_utils.py @@ -0,0 +1,137 @@ +import pytest +import torch +from library.model_utils import AID +from torch import nn +import torch.nn.functional as F + +@pytest.fixture +def input_tensor(): + # Create a tensor with positive and negative values + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + return x + +def test_aid_forward_train_mode(input_tensor): + aid = AID(p=0.9) + aid.train() + + torch.manual_seed(0) # fixed seed: makes the "outputs vary" check below deterministic instead of ~1.5% flaky + results = [] + for _ in range(10): + output = aid(input_tensor) + results.append(output.detach().clone()) + + # Test that outputs vary (stochastic behavior) + 
all_equal = all(torch.allclose(results[0], results[i]) for i in range(1, 10)) + assert not all_equal, "All outputs are identical, expected variability in training mode" + + # Test shape preservation + assert results[0].shape == input_tensor.shape + +def test_aid_forward_eval_mode(input_tensor): + aid = AID(p=0.9) + aid.eval() + + output = aid(input_tensor) + + # Test deterministic behavior + output2 = aid(input_tensor) + assert torch.allclose(output, output2), "Expected deterministic behavior in eval mode" + + # Test correct transformation + expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1 + assert torch.allclose(output, expected), "Incorrect evaluation mode transformation" + +def test_aid_gradient_flow(input_tensor): + aid = AID(p=0.9) + aid.train() + + # Forward pass + output = aid(input_tensor) + + # Check gradient flow + assert output.requires_grad, "Output lost gradient tracking" + + # Compute loss and backpropagate + loss = output.sum() + loss.backward() + + # Verify gradients were computed + assert input_tensor.grad is not None, "No gradients were recorded for input tensor" + assert torch.any(input_tensor.grad != 0), "Gradients are all zeros" + +def test_aid_extreme_p_values(): + # Test with p=1.0 (only positive values pass through) + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + aid = AID(p=1.0) + aid.eval() + + output = aid(x) + expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]) + assert torch.allclose(output, expected), "Failed with p=1.0" + + # Test with p=0.0 (only negative values pass through) + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + aid = AID(p=0.0) + aid.eval() + + output = aid(x) + expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0]) + assert torch.allclose(output, expected), "Failed with p=0.0" + +def test_aid_with_all_positive_values(): + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True) + aid = AID(p=0.9) + aid.train() + + # Run forward passes and check that 
only positive values are affected + output = aid(x) + + # Backprop should work + loss = output.sum() + loss.backward() + assert x.grad is not None, "No gradients were recorded for all-positive input" + +def test_aid_with_all_negative_values(): + x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True) + aid = AID(p=0.9) + aid.train() + + # Run forward passes and check that only negative values are affected + output = aid(x) + + # Backprop should work + loss = output.sum() + loss.backward() + assert x.grad is not None, "No gradients were recorded for all-negative input" + +def test_aid_with_zero_values(): + x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True) + aid = AID(p=0.9) + + # Test training mode + aid.train() + output = aid(x) + assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input" + + # Test eval mode + aid.eval() + output = aid(x) + assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input" + +def test_aid_integration_with_linear_layer(): + # Test AID's compatibility with a linear layer + linear = nn.Linear(5, 2) + aid = AID(p=0.9) + + model = nn.Sequential(linear, aid) + model.train() + + x = torch.randn(3, 5, requires_grad=True) + output = model(x) + + # Check that gradients flow through the whole model + loss = output.sum() + loss.backward() + + assert linear.weight.grad is not None, "No gradients for linear layer weights" + assert x.grad is not None, "No gradients for input tensor"