mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Fix GGPO variables. Fix no _org lora values.
- Add pythonpath = . to pytest to get the current directory - Fix device of LoRA after PiSSA initialization to return to proper device
This commit is contained in:
@@ -104,8 +104,8 @@ def initialize_pissa(
|
||||
if up.shape != expected_up_shape:
|
||||
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
|
||||
|
||||
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
|
||||
lora_down.weight.data = down.to(dtype=lora_down.weight.dtype)
|
||||
lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype)
|
||||
lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype)
|
||||
|
||||
weight = weight.data - scale * (up @ down)
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
||||
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
@@ -86,6 +87,19 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self._org_lora_up = None
|
||||
self._org_lora_down = None
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
|
||||
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
|
||||
"""
|
||||
@@ -130,15 +144,6 @@ class LoRAModule(torch.nn.Module):
|
||||
self._org_lora_up = self._org_lora_up.to("cpu")
|
||||
self._org_lora_down = self._org_lora_down.to("cpu")
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
|
||||
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import pytest
|
||||
from library.lora_util import initialize_pissa
|
||||
from library.network_utils import initialize_pissa
|
||||
from library.test_util import generate_synthetic_weights
|
||||
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ def test_alpha_scaling():
|
||||
|
||||
|
||||
def test_initialization_methods():
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# Test different initialization methods
|
||||
org_module = nn.Linear(10, 20)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
@@ -73,7 +74,8 @@ def test_initialization_methods():
|
||||
assert lora_module1.lora_up.weight.shape == (20, 4)
|
||||
|
||||
# URAE initialization
|
||||
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
|
||||
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4)
|
||||
lora_module2.initialize_weights(org_module, "urae", device)
|
||||
assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None
|
||||
assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
|
||||
|
||||
@@ -81,7 +83,8 @@ def test_initialization_methods():
|
||||
assert lora_module2.lora_up.weight.shape == (20, 4)
|
||||
|
||||
# PISSA initialization
|
||||
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
|
||||
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4)
|
||||
lora_module3.initialize_weights(org_module, "pissa", device)
|
||||
assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None
|
||||
assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user