format by black

This commit is contained in:
Kohya S
2023-07-24 21:28:37 +09:00
parent b1e44e96bc
commit e83ee217d3

View File

@@ -21,7 +21,7 @@ def save_to_file(file_name, model, state_dict, dtype):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name) save_file(model, file_name)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
@@ -29,11 +29,11 @@ def save_to_file(file_name, model, state_dict, dtype):
def svd(args): def svd(args):
def str_to_dtype(p):
    """Map a precision name ("float", "fp16", "bf16") to a torch dtype.

    Returns None for any unrecognized name.
    """
    # Table lookup instead of an if-chain; .get returns None on a miss,
    # matching the original fall-through behavior.
    name_to_dtype = {
        "float": torch.float,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    return name_to_dtype.get(p)
@@ -53,7 +53,8 @@ def svd(args):
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs) lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len( assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース " lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
# get diffs # get diffs
diffs = {} diffs = {}
@@ -94,7 +95,7 @@ def svd(args):
with torch.no_grad(): with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())): for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = (len(mat.size()) == 4) conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4] kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1) conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -140,9 +141,9 @@ def svd(args):
# make state dict for LoRA # make state dict for LoRA
lora_sd = {} lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items(): for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + '.lora_up.weight'] = up_weight lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + '.lora_down.weight'] = down_weight lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
# load state dict to LoRA and save it # load state dict to LoRA and save it
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
@@ -164,25 +165,42 @@ def svd(args):
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the LoRA extraction script.

    Options cover the original/tuned model paths, output path, precision,
    LoRA ranks (linear and Conv2d-3x3), and compute device.
    """
    parser = argparse.ArgumentParser()
    # Bind the method once; one compact call per option below.
    add = parser.add_argument
    add("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
    add("--save_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
    add("--model_org", type=str, default=None, help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
    add("--model_tuned", type=str, default=None, help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors")
    add("--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
    add("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
    add("--conv_dim", type=int, default=None, help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし")
    add("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
    return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()