mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add comment about scaling
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
# Thanks to cloneofsimo
|
# Thanks to cloneofsimo
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file, safe_open
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
@@ -43,6 +44,7 @@ def split_lora_model(lora_sd, unit):
|
|||||||
|
|
||||||
rank = unit
|
rank = unit
|
||||||
splitted_models = []
|
splitted_models = []
|
||||||
|
new_alpha = None
|
||||||
while rank < max_rank:
|
while rank < max_rank:
|
||||||
print(f"Splitting rank {rank}")
|
print(f"Splitting rank {rank}")
|
||||||
new_sd = {}
|
new_sd = {}
|
||||||
@@ -52,9 +54,15 @@ def split_lora_model(lora_sd, unit):
|
|||||||
elif "lora_up" in key:
|
elif "lora_up" in key:
|
||||||
new_sd[key] = value[:, :rank].contiguous()
|
new_sd[key] = value[:, :rank].contiguous()
|
||||||
else:
|
else:
|
||||||
new_sd[key] = value # alpha and other parameters
|
# なぜかscaleするとおかしくなる……
|
||||||
|
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
||||||
|
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
||||||
|
# print(key, value.size(), this_rank, rank, value, scale)
|
||||||
|
# new_alpha = value * scale # always same
|
||||||
|
# new_sd[key] = new_alpha
|
||||||
|
new_sd[key] = value
|
||||||
|
|
||||||
splitted_models.append((new_sd, rank))
|
splitted_models.append((new_sd, rank, new_alpha))
|
||||||
rank += unit
|
rank += unit
|
||||||
|
|
||||||
return max_rank, splitted_models
|
return max_rank, splitted_models
|
||||||
@@ -68,7 +76,7 @@ def split(args):
|
|||||||
original_rank, splitted_models = split_lora_model(lora_sd, args.unit)
|
original_rank, splitted_models = split_lora_model(lora_sd, args.unit)
|
||||||
|
|
||||||
comment = metadata.get("ss_training_comment", "")
|
comment = metadata.get("ss_training_comment", "")
|
||||||
for state_dict, new_rank in splitted_models:
|
for state_dict, new_rank, new_alpha in splitted_models:
|
||||||
# update metadata
|
# update metadata
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
new_metadata = {}
|
new_metadata = {}
|
||||||
@@ -77,6 +85,7 @@ def split(args):
|
|||||||
|
|
||||||
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
||||||
new_metadata["ss_network_dim"] = str(new_rank)
|
new_metadata["ss_network_dim"] = str(new_rank)
|
||||||
|
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
|
||||||
|
|
||||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
metadata["sshs_model_hash"] = model_hash
|
metadata["sshs_model_hash"] = model_hash
|
||||||
|
|||||||
Reference in New Issue
Block a user