mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add comment about scaling
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
@@ -43,6 +44,7 @@ def split_lora_model(lora_sd, unit):
|
||||
|
||||
rank = unit
|
||||
splitted_models = []
|
||||
new_alpha = None
|
||||
while rank < max_rank:
|
||||
print(f"Splitting rank {rank}")
|
||||
new_sd = {}
|
||||
@@ -52,9 +54,15 @@ def split_lora_model(lora_sd, unit):
|
||||
elif "lora_up" in key:
|
||||
new_sd[key] = value[:, :rank].contiguous()
|
||||
else:
|
||||
new_sd[key] = value # alpha and other parameters
|
||||
# なぜかscaleするとおかしくなる……
|
||||
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
||||
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
||||
# print(key, value.size(), this_rank, rank, value, scale)
|
||||
# new_alpha = value * scale # always same
|
||||
# new_sd[key] = new_alpha
|
||||
new_sd[key] = value
|
||||
|
||||
splitted_models.append((new_sd, rank))
|
||||
splitted_models.append((new_sd, rank, new_alpha))
|
||||
rank += unit
|
||||
|
||||
return max_rank, splitted_models
|
||||
@@ -68,7 +76,7 @@ def split(args):
|
||||
original_rank, splitted_models = split_lora_model(lora_sd, args.unit)
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
for state_dict, new_rank in splitted_models:
|
||||
for state_dict, new_rank, new_alpha in splitted_models:
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
new_metadata = {}
|
||||
@@ -77,6 +85,7 @@ def split(args):
|
||||
|
||||
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
||||
new_metadata["ss_network_dim"] = str(new_rank)
|
||||
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
|
||||
Reference in New Issue
Block a user