mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
format by black, add ja comment
This commit is contained in:
@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
|
|||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
#CLAMP_QUANTILE = 0.99
|
# CLAMP_QUANTILE = 0.99
|
||||||
#MIN_DIFF = 1e-1
|
# MIN_DIFF = 1e-1
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
|
def svd(
|
||||||
|
model_org=None,
|
||||||
|
model_tuned=None,
|
||||||
|
save_to=None,
|
||||||
|
dim=4,
|
||||||
|
v2=None,
|
||||||
|
sdxl=None,
|
||||||
|
conv_dim=None,
|
||||||
|
v_parameterization=None,
|
||||||
|
device=None,
|
||||||
|
save_precision=None,
|
||||||
|
clamp_quantile=0.99,
|
||||||
|
min_diff=0.01,
|
||||||
|
no_metadata=False,
|
||||||
|
):
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
return torch.float
|
return torch.float
|
||||||
@@ -39,9 +53,7 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
|
|||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
return None
|
return None
|
||||||
|
|
||||||
assert v2 != sdxl or (
|
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||||
not v2 and not sdxl
|
|
||||||
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
|
||||||
if v_parameterization is None:
|
if v_parameterization is None:
|
||||||
v_parameterization = v2
|
v_parameterization = v2
|
||||||
|
|
||||||
@@ -199,9 +211,7 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
|
|||||||
|
|
||||||
if not no_metadata:
|
if not no_metadata:
|
||||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||||
sai_metadata = sai_model_spec.build_metadata(
|
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||||
None, v2, v_parameterization, sdxl, True, False, time.time(), title=title
|
|
||||||
)
|
|
||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||||
@@ -242,7 +252,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
"--save_to",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -256,13 +270,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--clamp_quantile",
|
"--clamp_quantile",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.99,
|
default=0.99,
|
||||||
help="Quantile clamping value, float, (0-1). Default = 0.99",
|
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min_diff",
|
"--min_diff",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.01,
|
||||||
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
|
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
||||||
|
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_metadata",
|
"--no_metadata",
|
||||||
|
|||||||
Reference in New Issue
Block a user