mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Fix LoRA dtype when saving PiSSA
Rename the library.lora_util module to library.network_utils so the import matches the project's terminology.
This commit is contained in:
@@ -105,7 +105,7 @@ def initialize_pissa(
|
||||
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
|
||||
|
||||
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
|
||||
lora_down.weight.data = down.to(dtype=lora_up.weight.dtype)
|
||||
lora_down.weight.data = down.to(dtype=lora_down.weight.dtype)
|
||||
|
||||
weight = weight.data - scale * (up @ down)
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
@@ -17,7 +17,7 @@ from tqdm import tqdm
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -784,6 +784,15 @@ class LoRANetwork(torch.nn.Module):
|
||||
if ggpo_beta is not None and ggpo_sigma is not None:
|
||||
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
|
||||
|
||||
if self.train_double_block_indices:
|
||||
logger.info(f"train_double_block_indices={self.train_double_block_indices}")
|
||||
|
||||
if self.train_single_block_indices:
|
||||
logger.info(f"train_single_block_indices={self.train_single_block_indices}")
|
||||
|
||||
if self.initialize:
|
||||
logger.info(f"initialization={self.initialize}")
|
||||
|
||||
if self.split_qkv:
|
||||
logger.info("split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
@@ -1318,10 +1327,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
||||
lora_up = state_dict[lora_up_key]
|
||||
lora_down = state_dict[lora_down_key]
|
||||
with torch.autocast("cuda"):
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
||||
# TODO: Capture option if we should offload
|
||||
# offload to CPU
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
||||
# TODO: Capture option if we should offload
|
||||
# offload to CPU
|
||||
state_dict[lora_up_key] = up.detach()
|
||||
state_dict[lora_down_key] = down.detach()
|
||||
progress.update(1)
|
||||
|
||||
Reference in New Issue
Block a user