Compare commits

...

35 Commits

Author SHA1 Message Date
rockerBOO
5391c4fbe9 Fix typo 2025-04-10 22:07:46 -04:00
rockerBOO
adb0e54093 Fix GGPO variables. Fix no _org lora values.
- Add pythonpath = . to pytest to get the current directory
- Fix device of LoRA after PiSSA initialization to return to proper
  device
2025-04-10 21:59:39 -04:00
rockerBOO
7dd00204eb Revert "Merge branch 'sd3' into flux-lora-init"
This reverts commit 67f8e17a46, reversing
changes made to 9d7e2dd7c9.
2025-04-10 21:38:05 -04:00
rockerBOO
67f8e17a46 Merge branch 'sd3' into flux-lora-init 2025-04-10 21:15:22 -04:00
rockerBOO
9d7e2dd7c9 Fix LoRA dtype when saving PiSSA
Change lora_util to network_utils to match terms.
2025-04-10 20:59:47 -04:00
Kohya S
5a18a03ffc Merge branch 'dev' into sd3 2025-04-07 21:55:17 +09:00
Kohya S
572cc3efb8 Merge branch 'main' into dev 2025-04-07 21:48:45 +09:00
Kohya S.
52c8dec953 Merge pull request #2015 from DKnight54/uncache_vae_batch
Using --vae_batch_size to set batch size for dynamic latent generation
2025-04-07 21:48:02 +09:00
Kohya S
4589262f8f README.md: Update recent updates section to include IP noise gamma feature for FLUX.1 2025-04-06 21:34:27 +09:00
Kohya S.
c56dc90b26 Merge pull request #1992 from rockerBOO/flux-ip-noise-gamma
Add IP noise gamma for Flux
2025-04-06 21:29:26 +09:00
Kohya S.
ee0f754b08 Merge pull request #2028 from rockerBOO/patch-5
Fix resize PR link
2025-04-05 20:15:13 +09:00
Kohya S.
606e6875d2 Merge pull request #2022 from LexSong/fix-resize-issue
Fix size parameter types and improve resize_image interpolation
2025-04-05 19:28:25 +09:00
Dave Lage
fd36fd1aa9 Fix resize PR link 2025-04-03 16:09:45 -04:00
Kohya S.
92845e8806 Merge pull request #2026 from kohya-ss/fix-finetune-dataset-resize-interpolation
fix: add resize_interpolation parameter to FineTuningDataset constructor
2025-04-03 21:52:14 +09:00
Kohya S
f1423a7229 fix: add resize_interpolation parameter to FineTuningDataset constructor 2025-04-03 21:48:51 +09:00
Lex Song
b822b7e60b Fix the interpolation logic error in resize_image()
The original code had a mistake. It used 'lanczos' when the image got smaller (width > resized_width and height > resized_height) and 'area' when it stayed the same or got bigger. This was the wrong way. 'area' is better for big shrinking.
2025-04-02 22:04:37 +08:00
Lex Song
ede3470260 Ensure all size parameters are integers to prevent type errors 2025-04-02 03:50:33 +08:00
DKnight54
381303d64f Update train_network.py 2025-03-29 02:26:18 +08:00
rockerBOO
89f0d27a59 Set sigmoid_scale to default 1.0 2025-03-20 15:10:33 -04:00
rockerBOO
d40f5b1e4e Revert "Scale sigmoid to default 1.0"
This reverts commit 8aa126582e.
2025-03-20 15:09:50 -04:00
rockerBOO
8aa126582e Scale sigmoid to default 1.0 2025-03-20 15:09:11 -04:00
rockerBOO
e8b3254858 Add flux_train_utils tests for get get_noisy_model_input_and_timesteps 2025-03-20 15:01:15 -04:00
rockerBOO
16cef81aea Refactor sigmas and timesteps 2025-03-20 14:32:56 -04:00
rockerBOO
f974c6b257 change order to match upstream 2025-03-19 14:27:43 -04:00
rockerBOO
5d5a7d2acf Fix IP noise calculation 2025-03-19 13:50:04 -04:00
rockerBOO
1eddac26b0 Separate random to a variable, and make sure on device 2025-03-19 00:49:42 -04:00
rockerBOO
8e6817b0c2 Remove double noise 2025-03-19 00:45:13 -04:00
rockerBOO
d93ad90a71 Add perturbation on noisy_model_input if needed 2025-03-19 00:37:27 -04:00
rockerBOO
7197266703 Perturbed noise should be separate of input noise 2025-03-19 00:25:51 -04:00
rockerBOO
b81bcd0b01 Move IP noise gamma to noise creation to remove complexity and align noise for target loss 2025-03-18 21:36:55 -04:00
rockerBOO
6f4d365775 zeros_like because we are adding 2025-03-18 18:53:34 -04:00
rockerBOO
a4f3a9fc1a Use ones_like 2025-03-18 18:44:21 -04:00
rockerBOO
b425466e7b Fix IP noise gamma to use random values 2025-03-18 18:42:35 -04:00
rockerBOO
c8be141ae0 Apply IP gamma to noise fix 2025-03-18 15:42:18 -04:00
rockerBOO
0b25a05e3c Add IP noise gamma for Flux 2025-03-18 15:40:40 -04:00
5 changed files with 37 additions and 20 deletions

View File

@@ -104,8 +104,8 @@ def initialize_pissa(
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
lora_down.weight.data = down.to(dtype=lora_up.weight.dtype)
lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype)
lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype)
weight = weight.data - scale * (up @ down)
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)

View File

@@ -7,6 +7,7 @@
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Type, Union
from diffusers import AutoencoderKL
@@ -17,7 +18,7 @@ from tqdm import tqdm
import re
from library.utils import setup_logging
from library.device_utils import clean_memory_on_device
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae
setup_logging()
import logging
@@ -86,10 +87,23 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self._org_lora_up = None
self._org_lora_down = None
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
"""
Inititalize the weights for the LoRA
Initialize the weights for the LoRA
org_module: original module we are applying the LoRA to
device: device to run initialization computation on
@@ -130,15 +144,6 @@ class LoRAModule(torch.nn.Module):
self._org_lora_up = self._org_lora_up.to("cpu")
self._org_lora_down = self._org_lora_down.to("cpu")
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self):
self.org_forward = self.org_module.forward
@@ -784,6 +789,15 @@ class LoRANetwork(torch.nn.Module):
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.train_double_block_indices:
logger.info(f"train_double_block_indices={self.train_double_block_indices}")
if self.train_single_block_indices:
logger.info(f"train_single_block_indices={self.train_single_block_indices}")
if self.initialize:
logger.info(f"initialization={self.initialize}")
if self.split_qkv:
logger.info("split qkv for LoRA")
if self.train_blocks is not None:
@@ -1318,10 +1332,9 @@ class LoRANetwork(torch.nn.Module):
lora_down_key = f"{lora.lora_name}.lora_down.weight"
lora_up = state_dict[lora_up_key]
lora_down = state_dict[lora_down_key]
with torch.autocast("cuda"):
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
# TODO: Capture option if we should offload
# offload to CPU
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
# TODO: Capture option if we should offload
# offload to CPU
state_dict[lora_up_key] = up.detach()
state_dict[lora_down_key] = down.detach()
progress.update(1)

View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

View File

@@ -1,6 +1,6 @@
import torch
import pytest
from library.lora_util import initialize_pissa
from library.network_utils import initialize_pissa
from library.test_util import generate_synthetic_weights

View File

@@ -62,6 +62,7 @@ def test_alpha_scaling():
def test_initialization_methods():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Test different initialization methods
org_module = nn.Linear(10, 20)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
@@ -73,7 +74,8 @@ def test_initialization_methods():
assert lora_module1.lora_up.weight.shape == (20, 4)
# URAE initialization
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4)
lora_module2.initialize_weights(org_module, "urae", device)
assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None
assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
@@ -81,7 +83,8 @@ def test_initialization_methods():
assert lora_module2.lora_up.weight.shape == (20, 4)
# PISSA initialization
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4)
lora_module3.initialize_weights(org_module, "pissa", device)
assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None
assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None