add min_diff, clamp_quantile args

based on https://github.com/bmaltais/kohya_ss/pull/1332 a9ec90c40a
Author: Won-Kyu Park
Date: 2023-11-08 19:35:10 +09:00
Parent: e20e9f61ac
Commit: 2c1e669bd8


@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
 import lora
-CLAMP_QUANTILE = 0.99
-MIN_DIFF = 1e-1
+#CLAMP_QUANTILE = 0.99
+#MIN_DIFF = 1e-1
 def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype):
         torch.save(model, file_name)
-def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False):
+def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
     def str_to_dtype(p):
         if p == "float":
             return torch.float
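
The str_to_dtype helper shown as context maps the save_precision string onto a torch dtype; only its first branch is visible in this hunk. A minimal sketch of the kind of mapping involved — the fp16/bf16 branches below are assumptions following the visible pattern, not lines taken from this commit:

    import torch

    def str_to_dtype(p):
        # Map a precision name to a torch dtype; None keeps the original dtype.
        if p == "float":
            return torch.float
        if p == "fp16":  # assumed branch, not shown in the hunk
            return torch.float16
        if p == "bf16":  # assumed branch, not shown in the hunk
            return torch.bfloat16
        return None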
@@ -91,9 +91,9 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
         diff = module_t.weight - module_o.weight
         # Text Encoder might be same
-        if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
+        if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
             text_encoder_different = True
-            print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
+            print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
         diff = diff.float()
         diffs[lora_name] = diff
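
This hunk swaps the module-level MIN_DIFF constant for the new min_diff parameter: a text-encoder layer marks the two checkpoints as different only when the largest absolute weight difference exceeds the threshold. A standalone sketch of that test, with illustrative names rather than the script's own loop variables:

    import torch

    def weights_differ(weight_tuned, weight_org, min_diff=0.01):
        # The layer counts as "different" only if the largest absolute
        # element-wise weight difference exceeds min_diff.
        diff = weight_tuned - weight_org
        return torch.max(torch.abs(diff)).item() > min_diff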
@@ -149,7 +149,7 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
             Vh = Vh[:rank, :]
             dist = torch.cat([U.flatten(), Vh.flatten()])
-            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+            hi_val = torch.quantile(dist, clamp_quantile)
             low_val = -hi_val
             U = U.clamp(low_val, hi_val)
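
Here the quantile used for clamping the SVD factors comes from the new clamp_quantile parameter instead of the old constant: hi_val is the clamp_quantile-th quantile of all entries of U and Vh combined, and both factors are clamped into [-hi_val, hi_val] to suppress outliers in the extracted low-rank weights. A self-contained sketch of that step, assuming a 2D weight-difference matrix and a chosen rank:

    import torch

    def lowrank_factors(diff, rank, clamp_quantile=0.99):
        # Low-rank factorization of the weight difference via SVD.
        U, S, Vh = torch.linalg.svd(diff.float())
        U = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into U
        Vh = Vh[:rank, :]

        # Symmetric clamp at the clamp_quantile-th quantile of the combined
        # entries of both factors, limiting extreme values.
        dist = torch.cat([U.flatten(), Vh.flatten()])
        hi_val = torch.quantile(dist, clamp_quantile)
        return U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)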
@@ -252,6 +252,18 @@ def setup_parser() -> argparse.ArgumentParser:
         help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数（rank）（デフォルトNone、適用なし）",
     )
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--clamp_quantile",
+        type=float,
+        default=0.99,
+        help="Quantile clamping value, float, (0-1). Default = 0.99",
+    )
+    parser.add_argument(
+        "--min_diff",
+        type=float,
+        default=0.01,
+        help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
+    )
     parser.add_argument(
         "--no_metadata",
         action="store_true",