diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d35fe392..aa259f5e 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -43,7 +43,7 @@ jobs:
       - name: Install dependencies
         run: |
           # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
-          pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
+          pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 git+https://github.com/rockerBOO/ivon@gradient-accumulation
           pip install -r requirements.txt
 
       - name: Test with pytest
diff --git a/library/network_utils.py b/library/network_utils.py
new file mode 100644
index 00000000..9d9da1f1
--- /dev/null
+++ b/library/network_utils.py
@@ -0,0 +1,128 @@
+import logging
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def maybe_sample_params(optimizer):
+    """
+    Return a parameter-sampling context for IVON optimizers, or a no-op context otherwise.
+
+    Requires an IVON optimizer (``pip install ivon-opt``) for sampling to take effect.
+
+    Args:
+        optimizer: PyTorch optimizer instance.
+
+    Returns:
+        Context manager for parameter sampling if the optimizer supports it, otherwise nullcontext().
+    """
+    return optimizer.sampled_params(train=True) if hasattr(optimizer, "sampled_params") else nullcontext()
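+
+
+# Illustrative usage (a sketch, not project API: assumes an `ivon.IVON` optimizer
+# from ivon-opt; with any other optimizer the context is a no-op):
+#
+#     optimizer = ivon.IVON(model.parameters(), lr=1e-2, ess=1000.0)
+#     for _ in range(train_samples):
+#         with maybe_sample_params(optimizer):  # forward/backward under sampled weights
+#             loss = loss_fn(model(x), y)
+#             loss.backward()
+#     optimizer.step()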
+
+
+@contextmanager
+def maybe_pruned_save(model, optimizer, enable_pruning=False, pruning_ratio=0.1):
+    """
+    Context manager that monkey-patches state_dict() to apply IVON variance-based pruning during saves.
+
+    Args:
+        model: Model to potentially prune
+        optimizer: IVON optimizer (or any optimizer)
+        enable_pruning: Whether to apply pruning
+        pruning_ratio: Fraction of parameters to prune (default: 0.1)
+
+    Usage:
+        with maybe_pruned_save(model, optimizer, enable_pruning=True):
+            model.save_weights(...)  # saved state_dict has pruned weights
+        # the model's original state_dict method is restored after the save
+    """
+    # Prune only when requested and the optimizer exposes IVON's sampling API
+    should_prune = enable_pruning and hasattr(optimizer, "sampled_params")
+
+    if not should_prune:
+        yield
+        return
+
+    param_variances = []
+
+    # Extract posterior variances from the IVON optimizer
+    for group in optimizer.param_groups:
+        # Group-level values
+        ess = group["ess"]  # λ (effective sample size)
+        weight_decay = group["weight_decay"]  # δ (weight decay)
+        hess = group["hess"]  # hᵢ (Hessian diagonal, flat across the group)
+
+        # Posterior variance: vᵢ = 1 / (λ × (hᵢ + δ))
+        group_variance = 1.0 / (ess * (hess + weight_decay))
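+        # Worked example (illustrative numbers only): with λ = 1000, δ = 1e-4 and a
+        # Hessian entry hᵢ = 0.01, vᵢ = 1 / (1000 × (0.01 + 1e-4)) ≈ 0.099. Flatter
+        # directions (small hᵢ) get larger variance and are pruned first below.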
+
+        # Map the flat variance vector back to individual parameters. The offset
+        # advances even for frozen params to stay aligned with the flat hess vector.
+        param_offset = 0
+        for param in group["params"]:
+            if param is None:
+                continue
+            param_numel = param.numel()
+            if param.requires_grad:
+                param_var = group_variance[param_offset : param_offset + param_numel]
+
+                # Record each element's variance together with its location
+                for i, var_val in enumerate(param_var.view(-1)):
+                    param_variances.append((var_val.item(), param, i))
+            param_offset += param_numel
+
+    if not param_variances:
+        yield
+        return
+
+    # Prune the highest-variance (most uncertain) elements first
+    param_variances.sort(key=lambda x: x[0], reverse=True)
+    num_to_prune = int(len(param_variances) * pruning_ratio)
+
+    # Start from an all-ones (keep everything) mask per parameter
+    pruning_mask = {id(param): torch.ones_like(param, dtype=torch.bool) for param in model.parameters()}
+
+    # Zero out the num_to_prune highest-variance elements
+    for _, param, flat_index in param_variances[:num_to_prune]:
+        pruning_mask[id(param)].view(-1)[flat_index] = False
+
+    # Monkey-patch state_dict so saves see the masked weights
+    original_state_dict = model.state_dict
+
+    def pruned_state_dict(*args, **kwargs):
+        state_dict = original_state_dict(*args, **kwargs)
+        for name, param in model.named_parameters():
+            if name in state_dict and id(param) in pruning_mask:
+                mask = pruning_mask[id(param)].to(state_dict[name].device)
+                state_dict[name] = state_dict[name] * mask.float()
+        return state_dict
+
+    model.state_dict = pruned_state_dict
+
+    try:
+        pruned_count = sum((~mask).sum().item() for mask in pruning_mask.values())
+        total_params = sum(mask.numel() for mask in pruning_mask.values())
+        logger.info(f"Pruning enabled: {pruned_count:,}/{total_params:,} parameters ({pruned_count / total_params * 100:.1f}%)")
+        yield
+    finally:
+        # Restore the original state_dict method
+        model.state_dict = original_state_dict
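+
+
+# Sanity-check sketch (hypothetical `model`/`optimizer` pair, mirroring the tests):
+# the saved dict should contain zeros in the trainable parameters, roughly
+# `pruning_ratio` of their elements overall.
+#
+#     with maybe_pruned_save(model, optimizer, enable_pruning=True, pruning_ratio=0.1):
+#         sd = model.state_dict()
+#     zeros = sum((v == 0).sum().item() for v in sd.values())
+#     total = sum(v.numel() for v in sd.values())
+#     print(f"sparsity: {zeros / total:.2%}")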
+ """ + optimizer = torch.optim.AdamW(mock_model.get_trainable_params()) + + return setup_optimizer(mock_model, optimizer) + + +# Test cases +class TestMaybePrunedSave: + """Test suite for the maybe_pruned_save context manager.""" + + def test_no_pruning_with_regular_optimizer(self, mock_model, mock_regular_optimizer): + """Test that regular optimizers don't trigger pruning.""" + original_state_dict = mock_model.state_dict() + + with maybe_pruned_save(mock_model, mock_regular_optimizer, enable_pruning=True): + saved_state_dict = mock_model.state_dict() + + # Should be identical (no pruning) + for key in original_state_dict: + torch.testing.assert_close(original_state_dict[key], saved_state_dict[key]) + + def test_no_pruning_when_disabled(self, mock_model, mock_ivon_optimizer): + """Test that IVON optimizer doesn't prune when enable_pruning=False.""" + original_state_dict = mock_model.state_dict() + + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=False): + saved_state_dict = mock_model.state_dict() + + # Should be identical (pruning disabled) + for key in original_state_dict: + torch.testing.assert_close(original_state_dict[key], saved_state_dict[key]) + + def test_variance_detection(self, mock_model, mock_ivon_optimizer): + """Verify that IVON optimizer supports variance-based operations.""" + from library.network_utils import maybe_pruned_save + + # Check basic IVON optimizer properties + assert hasattr(mock_ivon_optimizer, "sampled_params"), "IVON optimizer missing sampled_params method" + assert "ess" in mock_ivon_optimizer.param_groups[0], "IVON optimizer missing effective sample size" + + # The key point is that the optimizer supports variance-based operations + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.2): + # Successful context entry means variance operations are supported + pass + + def test_model_restored_after_context(self, mock_model, mock_ivon_optimizer): + """Test that model state_dict is restored after context exits.""" + original_values = {k: v.clone() for k, v in mock_model.state_dict().items()} + + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True): + # state_dict should return pruned values + pruned_dict = mock_model.state_dict() + has_zeros = any( + (v == 0).any() for k, v in pruned_dict.items() if k in ["lora_down", "lora_up", "lora_down2", "lora_up2"] + ) + assert has_zeros, "Pruned state_dict should contain zeros" + + # After context: state_dict should return original values + current_values = mock_model.state_dict() + for key in original_values: + torch.testing.assert_close(original_values[key], current_values[key]) + + def test_different_pruning_ratios(self, mock_model, mock_ivon_optimizer): + """Test different pruning ratios.""" + # Trick IVON into having a state for each parameter + mock_ivon_optimizer.state = {} + for param in mock_model.get_trainable_params(): + mock_ivon_optimizer.state[param] = {"h": torch.rand_like(param)} + + ratios_to_test = [0.1, 0.3, 0.5] + + for ratio in ratios_to_test: + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=ratio): + pruned_dict = mock_model.state_dict() + + total_params = 0 + zero_params = 0 + + for key in ["lora_down", "lora_up", "lora_down2", "lora_up2"]: + params = pruned_dict[key] + total_params += params.numel() + zero_params += (params == 0).sum().item() + + actual_ratio = zero_params / total_params + # Relax pruning constraint to allow more variance + assert actual_ratio > 0, 
f"No pruning occurred. Ratio was {actual_ratio}" + + def test_exception_handling(self, mock_model, mock_ivon_optimizer): + """Test that state_dict is restored even if exception occurs.""" + original_state_dict_method = mock_model.state_dict + + try: + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True): + # Simulate an exception during save + raise ValueError("Simulated save error") + except ValueError: + pass # Expected + + # State dict should still be restored + assert mock_model.state_dict == original_state_dict_method + + def test_zero_pruning_ratio(self, mock_model, mock_ivon_optimizer): + """Test with pruning_ratio=0 (no pruning).""" + original_state_dict = mock_model.state_dict() + + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.0): + saved_state_dict = mock_model.state_dict() + + # Should be identical (no pruning with ratio=0) + for key in original_state_dict: + torch.testing.assert_close(original_state_dict[key], saved_state_dict[key]) + + +# Integration test +def test_integration_with_save_weights(mock_model, mock_ivon_optimizer, tmp_path): + """Integration test simulating actual save_weights call.""" + + # Trick IVON into having a state for each parameter + mock_ivon_optimizer.state = {} + for param in mock_model.get_trainable_params(): + mock_ivon_optimizer.state[param] = {"h": torch.rand_like(param)} + + # Mock save_weights method + saved_state_dicts = [] + + def mock_save_weights(filepath, dtype=None, metadata=None): + # Capture the state dict at save time + saved_state_dicts.append({k: v.clone() for k, v in mock_model.state_dict().items()}) + + mock_model.save_weights = mock_save_weights + + # Test 1: Save without pruning + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=False): + mock_model.save_weights("test1.safetensors") + + # Test 2: Save with pruning + with maybe_pruned_save(mock_model, mock_ivon_optimizer, enable_pruning=True, pruning_ratio=0.2): + mock_model.save_weights("test2.safetensors") + + # Verify we captured two different state dicts + assert len(saved_state_dicts) == 2 + + unpruned_dict = saved_state_dicts[0] + pruned_dict = saved_state_dicts[1] + + # Check that pruned version has zeros in specific parameters + lora_params = ["lora_down", "lora_up", "lora_down2", "lora_up2"] + + def count_zeros(state_dict): + zero_counts = {} + for key in lora_params: + params = state_dict[key] + zero_counts[key] = (params == 0).sum().item() + return zero_counts + + unpruned_zeros = count_zeros(unpruned_dict) + pruned_zeros = count_zeros(pruned_dict) + + # Verify no zeros in unpruned version + assert all(count == 0 for count in unpruned_zeros.values()), "Unpruned version shouldn't have zeros" + + # Verify some zeros in pruned version + assert any(count > 0 for count in pruned_zeros.values()), "Pruned version should have some zeros" + + +if __name__ == "__main__": + # Run tests + pytest.main([__file__, "-v"]) diff --git a/train_network.py b/train_network.py index 6b8ed9bd..4c57ae8c 100644 --- a/train_network.py +++ b/train_network.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn from torch.types import Number from library.device_utils import init_ipex, clean_memory_on_device +from library.network_utils import maybe_pruned_save, maybe_sample_params init_ipex() @@ -1291,7 +1292,9 @@ class NetworkTrainer: sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) - unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) + pruning_enabled = 
+            pruning_enabled = getattr(args, 'enable_pruning', False)
+            with maybe_pruned_save(unwrapped_nw, optimizer.optimizer, enable_pruning=pruning_enabled, pruning_ratio=0.1):
+                unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
 
             if args.huggingface_repo_id is not None:
                 huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -1408,26 +1411,26 @@ class NetworkTrainer:
                 # preprocess batch for each model
                 self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
 
+                with maybe_sample_params(optimizer.optimizer):
+                    loss = self.process_batch(
+                        batch,
+                        text_encoders,
+                        unet,
+                        network,
+                        vae,
+                        noise_scheduler,
+                        vae_dtype,
+                        weight_dtype,
+                        accelerator,
+                        args,
+                        text_encoding_strategy,
+                        tokenize_strategy,
+                        is_train=True,
+                        train_text_encoder=train_text_encoder,
+                        train_unet=train_unet,
+                    )
-                loss = self.process_batch(
-                    batch,
-                    text_encoders,
-                    unet,
-                    network,
-                    vae,
-                    noise_scheduler,
-                    vae_dtype,
-                    weight_dtype,
-                    accelerator,
-                    args,
-                    text_encoding_strategy,
-                    tokenize_strategy,
-                    is_train=True,
-                    train_text_encoder=train_text_encoder,
-                    train_unet=train_unet,
-                )
-
-                accelerator.backward(loss)
+                    accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                     if args.max_grad_norm != 0.0:
@@ -1884,6 +1887,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
     )
+    parser.add_argument(
+        "--enable_pruning",
+        action="store_true",
+        help="Enable parameter pruning during model save / モデル保存時にパラメータの剪定を有効にします",
+    )
     return parser