format by black, add ja comment

This commit is contained in:
Kohya S
2023-11-25 21:05:55 +09:00
parent 97958400fb
commit 0fb9ecf1f3

View File

@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name)
def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
@@ -39,9 +53,7 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
return torch.bfloat16
return None
assert v2 != sdxl or (
not v2 and not sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if v_parameterization is None:
v_parameterization = v2
@@ -199,9 +211,7 @@ def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=Non
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, v2, v_parameterization, sdxl, True, False, time.time(), title=title
)
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)
lora_network_save.save_weights(save_to, save_dtype, metadata)
@@ -242,7 +252,11 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument(
@@ -256,13 +270,14 @@ def setup_parser() -> argparse.ArgumentParser:
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99",
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
)
parser.add_argument(
"--no_metadata",