mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
35 Commits
5f927444d0
...
5391c4fbe9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5391c4fbe9 | ||
|
|
adb0e54093 | ||
|
|
7dd00204eb | ||
|
|
67f8e17a46 | ||
|
|
9d7e2dd7c9 | ||
|
|
5a18a03ffc | ||
|
|
572cc3efb8 | ||
|
|
52c8dec953 | ||
|
|
4589262f8f | ||
|
|
c56dc90b26 | ||
|
|
ee0f754b08 | ||
|
|
606e6875d2 | ||
|
|
fd36fd1aa9 | ||
|
|
92845e8806 | ||
|
|
f1423a7229 | ||
|
|
b822b7e60b | ||
|
|
ede3470260 | ||
|
|
381303d64f | ||
|
|
89f0d27a59 | ||
|
|
d40f5b1e4e | ||
|
|
8aa126582e | ||
|
|
e8b3254858 | ||
|
|
16cef81aea | ||
|
|
f974c6b257 | ||
|
|
5d5a7d2acf | ||
|
|
1eddac26b0 | ||
|
|
8e6817b0c2 | ||
|
|
d93ad90a71 | ||
|
|
7197266703 | ||
|
|
b81bcd0b01 | ||
|
|
6f4d365775 | ||
|
|
a4f3a9fc1a | ||
|
|
b425466e7b | ||
|
|
c8be141ae0 | ||
|
|
0b25a05e3c |
@@ -104,8 +104,8 @@ def initialize_pissa(
|
||||
if up.shape != expected_up_shape:
|
||||
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
|
||||
|
||||
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
|
||||
lora_down.weight.data = down.to(dtype=lora_up.weight.dtype)
|
||||
lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype)
|
||||
lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype)
|
||||
|
||||
weight = weight.data - scale * (up @ down)
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
@@ -7,6 +7,7 @@
|
||||
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
||||
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
@@ -17,7 +18,7 @@ from tqdm import tqdm
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -86,10 +87,23 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self._org_lora_up = None
|
||||
self._org_lora_down = None
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
|
||||
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
|
||||
"""
|
||||
Inititalize the weights for the LoRA
|
||||
Initialize the weights for the LoRA
|
||||
|
||||
org_module: original module we are applying the LoRA to
|
||||
device: device to run initialization computation on
|
||||
@@ -130,15 +144,6 @@ class LoRAModule(torch.nn.Module):
|
||||
self._org_lora_up = self._org_lora_up.to("cpu")
|
||||
self._org_lora_down = self._org_lora_down.to("cpu")
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
@@ -784,6 +789,15 @@ class LoRANetwork(torch.nn.Module):
|
||||
if ggpo_beta is not None and ggpo_sigma is not None:
|
||||
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
|
||||
|
||||
if self.train_double_block_indices:
|
||||
logger.info(f"train_double_block_indices={self.train_double_block_indices}")
|
||||
|
||||
if self.train_single_block_indices:
|
||||
logger.info(f"train_single_block_indices={self.train_single_block_indices}")
|
||||
|
||||
if self.initialize:
|
||||
logger.info(f"initialization={self.initialize}")
|
||||
|
||||
if self.split_qkv:
|
||||
logger.info("split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
@@ -1318,10 +1332,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
||||
lora_up = state_dict[lora_up_key]
|
||||
lora_down = state_dict[lora_down_key]
|
||||
with torch.autocast("cuda"):
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
||||
# TODO: Capture option if we should offload
|
||||
# offload to CPU
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
||||
# TODO: Capture option if we should offload
|
||||
# offload to CPU
|
||||
state_dict[lora_up_key] = up.detach()
|
||||
state_dict[lora_down_key] = down.detach()
|
||||
progress.update(1)
|
||||
|
||||
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import pytest
|
||||
from library.lora_util import initialize_pissa
|
||||
from library.network_utils import initialize_pissa
|
||||
from library.test_util import generate_synthetic_weights
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ def test_alpha_scaling():
|
||||
|
||||
|
||||
def test_initialization_methods():
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# Test different initialization methods
|
||||
org_module = nn.Linear(10, 20)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
@@ -73,7 +74,8 @@ def test_initialization_methods():
|
||||
assert lora_module1.lora_up.weight.shape == (20, 4)
|
||||
|
||||
# URAE initialization
|
||||
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
|
||||
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4)
|
||||
lora_module2.initialize_weights(org_module, "urae", device)
|
||||
assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None
|
||||
assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
|
||||
|
||||
@@ -81,7 +83,8 @@ def test_initialization_methods():
|
||||
assert lora_module2.lora_up.weight.shape == (20, 4)
|
||||
|
||||
# PISSA initialization
|
||||
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
|
||||
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4)
|
||||
lora_module3.initialize_weights(org_module, "pissa", device)
|
||||
assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None
|
||||
assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user