Fix GGPO variable initialization. Fix missing _org LoRA values.

- Add `pythonpath = .` to the pytest configuration so tests resolve imports from the project root
- Fix the device of LoRA weights after PiSSA initialization so they are
  moved back to their original device
This commit is contained in:
rockerBOO
2025-04-10 21:59:39 -04:00
parent 7dd00204eb
commit adb0e54093
5 changed files with 23 additions and 14 deletions

View File

@@ -104,8 +104,8 @@ def initialize_pissa(
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
lora_down.weight.data = down.to(dtype=lora_down.weight.dtype)
lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype)
lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype)
weight = weight.data - scale * (up @ down)
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)

View File

@@ -7,6 +7,7 @@
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Type, Union
from diffusers import AutoencoderKL
@@ -86,6 +87,19 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self._org_lora_up = None
self._org_lora_down = None
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
"""
@@ -130,15 +144,6 @@ class LoRAModule(torch.nn.Module):
self._org_lora_up = self._org_lora_up.to("cpu")
self._org_lora_down = self._org_lora_down.to("cpu")
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self):
self.org_forward = self.org_module.forward

View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

View File

@@ -1,6 +1,6 @@
import torch
import pytest
from library.lora_util import initialize_pissa
from library.network_utils import initialize_pissa
from library.test_util import generate_synthetic_weights

View File

@@ -62,6 +62,7 @@ def test_alpha_scaling():
def test_initialization_methods():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Test different initialization methods
org_module = nn.Linear(10, 20)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
@@ -73,7 +74,8 @@ def test_initialization_methods():
assert lora_module1.lora_up.weight.shape == (20, 4)
# URAE initialization
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4)
lora_module2.initialize_weights(org_module, "urae", device)
assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None
assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
@@ -81,7 +83,8 @@ def test_initialization_methods():
assert lora_module2.lora_up.weight.shape == (20, 4)
# PISSA initialization
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4)
lora_module3.initialize_weights(org_module, "pissa", device)
assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None
assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None