mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add min_diff, clamp_quantile args
based on https://github.com/bmaltais/kohya_ss/pull/1332 a9ec90c40a
This commit is contained in:
@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
|
|||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
#CLAMP_QUANTILE = 0.99
|
||||||
MIN_DIFF = 1e-1
|
#MIN_DIFF = 1e-1
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
@@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False):
|
def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
return torch.float
|
return torch.float
|
||||||
@@ -91,9 +91,9 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
|
|||||||
diff = module_t.weight - module_o.weight
|
diff = module_t.weight - module_o.weight
|
||||||
|
|
||||||
# Text Encoder might be same
|
# Text Encoder might be same
|
||||||
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
|
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||||
text_encoder_different = True
|
text_encoder_different = True
|
||||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
|
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||||
|
|
||||||
diff = diff.float()
|
diff = diff.float()
|
||||||
diffs[lora_name] = diff
|
diffs[lora_name] = diff
|
||||||
@@ -149,7 +149,7 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
|
|||||||
Vh = Vh[:rank, :]
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
hi_val = torch.quantile(dist, clamp_quantile)
|
||||||
low_val = -hi_val
|
low_val = -hi_val
|
||||||
|
|
||||||
U = U.clamp(low_val, hi_val)
|
U = U.clamp(low_val, hi_val)
|
||||||
@@ -252,6 +252,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
parser.add_argument(
|
||||||
|
"--clamp_quantile",
|
||||||
|
type=float,
|
||||||
|
default=0.99,
|
||||||
|
help="Quantile clamping value, float, (0-1). Default = 0.99",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min_diff",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_metadata",
|
"--no_metadata",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user