Fix LoRA dtype when saving PiSSA

Rename lora_util to network_utils to match the project's naming terminology.
This commit is contained in:
rockerBOO
2025-04-10 20:59:47 -04:00
parent 5f927444d0
commit 9d7e2dd7c9
3 changed files with 14 additions and 6 deletions

View File

@@ -105,7 +105,7 @@ def initialize_pissa(
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
-lora_down.weight.data = down.to(dtype=lora_up.weight.dtype)
+lora_down.weight.data = down.to(dtype=lora_down.weight.dtype)
weight = weight.data - scale * (up @ down)
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)

View File

@@ -17,7 +17,7 @@ from tqdm import tqdm
import re
from library.utils import setup_logging
from library.device_utils import clean_memory_on_device
-from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
+from library.network_utils import initialize_lora, initialize_pissa, initialize_urae
setup_logging()
import logging
@@ -784,6 +784,15 @@ class LoRANetwork(torch.nn.Module):
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.train_double_block_indices:
logger.info(f"train_double_block_indices={self.train_double_block_indices}")
if self.train_single_block_indices:
logger.info(f"train_single_block_indices={self.train_single_block_indices}")
if self.initialize:
logger.info(f"initialization={self.initialize}")
if self.split_qkv:
logger.info("split qkv for LoRA")
if self.train_blocks is not None:
@@ -1318,10 +1327,9 @@ class LoRANetwork(torch.nn.Module):
lora_down_key = f"{lora.lora_name}.lora_down.weight"
lora_up = state_dict[lora_up_key]
lora_down = state_dict[lora_down_key]
+with torch.autocast("cuda"):
+    up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
+    # TODO: Capture option if we should offload
+    # offload to CPU
-up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
-# TODO: Capture option if we should offload
-# offload to CPU
state_dict[lora_up_key] = up.detach()
state_dict[lora_down_key] = down.detach()
progress.update(1)