From 9d7e2dd7c9690e345ae53b4ab9d0d90d303b96c1 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 10 Apr 2025 20:59:47 -0400
Subject: [PATCH] Fix LoRA dtype when saving PiSSA

Change lora_util to network_utils to match terms.
---
 library/{lora_util.py => network_utils.py}     |  2 +-
 networks/lora_flux.py                          | 18 +++++++++++++-----
 ...test_lora_util.py => test_network_utils.py} |  0
 3 files changed, 14 insertions(+), 6 deletions(-)
 rename library/{lora_util.py => network_utils.py} (98%)
 rename tests/library/{test_lora_util.py => test_network_utils.py} (100%)

diff --git a/library/lora_util.py b/library/network_utils.py
similarity index 98%
rename from library/lora_util.py
rename to library/network_utils.py
index 461dc13a..c6d03e55 100644
--- a/library/lora_util.py
+++ b/library/network_utils.py
@@ -105,7 +105,7 @@ def initialize_pissa(
         warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
 
     lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
-    lora_down.weight.data = down.to(dtype=lora_up.weight.dtype)
+    lora_down.weight.data = down.to(dtype=lora_down.weight.dtype)
 
     weight = weight.data - scale * (up @ down)
     org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 719451a8..3beb1dfb 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -17,7 +17,7 @@ from tqdm import tqdm
 import re
 from library.utils import setup_logging
 from library.device_utils import clean_memory_on_device
-from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
+from library.network_utils import initialize_lora, initialize_pissa, initialize_urae
 
 setup_logging()
 import logging
@@ -784,6 +784,15 @@ class LoRANetwork(torch.nn.Module):
         if ggpo_beta is not None and ggpo_sigma is not None:
             logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
 
+        if self.train_double_block_indices:
+            logger.info(f"train_double_block_indices={self.train_double_block_indices}")
+
+        if self.train_single_block_indices:
+            logger.info(f"train_single_block_indices={self.train_single_block_indices}")
+
+        if self.initialize:
+            logger.info(f"initialization={self.initialize}")
+
         if self.split_qkv:
             logger.info("split qkv for LoRA")
         if self.train_blocks is not None:
@@ -1318,10 +1327,9 @@ class LoRANetwork(torch.nn.Module):
                 lora_down_key = f"{lora.lora_name}.lora_down.weight"
                 lora_up = state_dict[lora_up_key]
                 lora_down = state_dict[lora_down_key]
-                with torch.autocast("cuda"):
-                    up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
-                    # TODO: Capture option if we should offload
-                    # offload to CPU
+                up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
+                # TODO: Capture option if we should offload
+                # offload to CPU
                 state_dict[lora_up_key] = up.detach()
                 state_dict[lora_down_key] = down.detach()
                 progress.update(1)
diff --git a/tests/library/test_lora_util.py b/tests/library/test_network_utils.py
similarity index 100%
rename from tests/library/test_lora_util.py
rename to tests/library/test_network_utils.py