mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support conv2d 3x3 LoRA
This commit is contained in:
@@ -2209,7 +2209,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
|||||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisable or end of epoch
|
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||||
|
|||||||
@@ -45,8 +45,13 @@ def svd(args):
|
|||||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||||
|
|
||||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
if args.conv_dim is None:
|
||||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
kwargs = {}
|
||||||
|
else:
|
||||||
|
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
||||||
|
|
||||||
|
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
|
||||||
|
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
|
||||||
assert len(lora_network_o.text_encoder_loras) == len(
|
assert len(lora_network_o.text_encoder_loras) == len(
|
||||||
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||||
|
|
||||||
@@ -85,13 +90,27 @@ def svd(args):
|
|||||||
|
|
||||||
# make LoRA with svd
|
# make LoRA with svd
|
||||||
print("calculating by svd")
|
print("calculating by svd")
|
||||||
rank = args.dim
|
|
||||||
lora_weights = {}
|
lora_weights = {}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_name, mat in tqdm(list(diffs.items())):
|
for lora_name, mat in tqdm(list(diffs.items())):
|
||||||
|
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||||
conv2d = (len(mat.size()) == 4)
|
conv2d = (len(mat.size()) == 4)
|
||||||
|
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||||
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
|
||||||
|
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
||||||
|
out_dim, in_dim = mat.size()[0:2]
|
||||||
|
|
||||||
|
if args.device:
|
||||||
|
mat = mat.to(args.device)
|
||||||
|
# print(mat.size(), mat.device, rank, in_dim, out_dim)
|
||||||
|
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
mat = mat.squeeze()
|
if conv2d_3x3:
|
||||||
|
mat = mat.flatten(start_dim=1)
|
||||||
|
else:
|
||||||
|
mat = mat.squeeze()
|
||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(mat)
|
U, S, Vh = torch.linalg.svd(mat)
|
||||||
|
|
||||||
@@ -108,6 +127,13 @@ def svd(args):
|
|||||||
U = U.clamp(low_val, hi_val)
|
U = U.clamp(low_val, hi_val)
|
||||||
Vh = Vh.clamp(low_val, hi_val)
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
U = U.reshape(out_dim, rank, 1, 1)
|
||||||
|
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
|
|
||||||
|
U = U.to("cpu").contiguous()
|
||||||
|
Vh = Vh.to("cpu").contiguous()
|
||||||
|
|
||||||
lora_weights[lora_name] = (U, Vh)
|
lora_weights[lora_name] = (U, Vh)
|
||||||
|
|
||||||
# make state dict for LoRA
|
# make state dict for LoRA
|
||||||
@@ -124,8 +150,8 @@ def svd(args):
|
|||||||
|
|
||||||
weights = lora_weights[lora_name][i]
|
weights = lora_weights[lora_name][i]
|
||||||
# print(key, i, weights.size(), lora_sd[key].size())
|
# print(key, i, weights.size(), lora_sd[key].size())
|
||||||
if len(lora_sd[key].size()) == 4:
|
# if len(lora_sd[key].size()) == 4:
|
||||||
weights = weights.unsqueeze(2).unsqueeze(3)
|
# weights = weights.unsqueeze(2).unsqueeze(3)
|
||||||
|
|
||||||
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
||||||
lora_sd[key] = weights
|
lora_sd[key] = weights
|
||||||
@@ -139,7 +165,7 @@ def svd(args):
|
|||||||
os.makedirs(dir_name, exist_ok=True)
|
os.makedirs(dir_name, exist_ok=True)
|
||||||
|
|
||||||
# minimum metadata
|
# minimum metadata
|
||||||
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||||
|
|
||||||
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {args.save_to}")
|
||||||
@@ -158,6 +184,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--save_to", type=str, default=None,
|
parser.add_argument("--save_to", type=str, default=None,
|
||||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||||
|
parser.add_argument("--conv_dim", type=int, default=None,
|
||||||
|
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
if module.__class__.__name__ in target_replace_modules:
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
for child_name, child_module in module.named_modules():
|
for child_name, child_module in module.named_modules():
|
||||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||||
lora_name = prefix + '.' + name + '.' + child_name
|
lora_name = prefix + '.' + name + '.' + child_name
|
||||||
lora_name = lora_name.replace('.', '_')
|
lora_name = lora_name.replace('.', '_')
|
||||||
name_to_module[lora_name] = child_module
|
name_to_module[lora_name] = child_module
|
||||||
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
|
|
||||||
# W <- W + U * D
|
# W <- W + U * D
|
||||||
weight = module.weight
|
weight = module.weight
|
||||||
|
# print(module_name, down_weight.size(), up_weight.size())
|
||||||
if len(weight.size()) == 2:
|
if len(weight.size()) == 2:
|
||||||
# linear
|
# linear
|
||||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||||
else:
|
elif down_weight.size()[2:4] == (1, 1):
|
||||||
# conv2d
|
# conv2d 1x1
|
||||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||||
).unsqueeze(2).unsqueeze(3) * scale
|
).unsqueeze(2).unsqueeze(3) * scale
|
||||||
|
else:
|
||||||
|
# conv2d 3x3
|
||||||
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
|
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
|
weight = weight + ratio * conved * scale
|
||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
module.weight = torch.nn.Parameter(weight)
|
||||||
|
|
||||||
@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|||||||
alphas[lora_module_name] = alpha
|
alphas[lora_module_name] = alpha
|
||||||
if lora_module_name not in base_alphas:
|
if lora_module_name not in base_alphas:
|
||||||
base_alphas[lora_module_name] = alpha
|
base_alphas[lora_module_name] = alpha
|
||||||
|
|
||||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|||||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||||
else:
|
else:
|
||||||
merged_sd[key] = lora_sd[key] * scale
|
merged_sd[key] = lora_sd[key] * scale
|
||||||
|
|
||||||
# set alpha to sd
|
# set alpha to sd
|
||||||
for lora_module_name, alpha in base_alphas.items():
|
for lora_module_name, alpha in base_alphas.items():
|
||||||
key = lora_module_name + ".alpha"
|
key = lora_module_name + ".alpha"
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ def save_to_file(file_name, model, state_dict, dtype):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||||
|
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
print(f"loading: {model}")
|
||||||
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|||||||
in_dim = down_weight.size()[1]
|
in_dim = down_weight.size()[1]
|
||||||
out_dim = up_weight.size()[0]
|
out_dim = up_weight.size()[0]
|
||||||
conv2d = len(down_weight.size()) == 4
|
conv2d = len(down_weight.size()) == 4
|
||||||
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
||||||
|
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||||
|
|
||||||
# make original weight if not exist
|
# make original weight if not exist
|
||||||
if lora_module_name not in merged_sd:
|
if lora_module_name not in merged_sd:
|
||||||
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||||
if device:
|
if device:
|
||||||
weight = weight.to(device)
|
weight = weight.to(device)
|
||||||
else:
|
else:
|
||||||
@@ -77,9 +79,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|||||||
scale = (alpha / network_dim)
|
scale = (alpha / network_dim)
|
||||||
if not conv2d: # linear
|
if not conv2d: # linear
|
||||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||||
else:
|
elif kernel_size == (1, 1):
|
||||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||||
).unsqueeze(2).unsqueeze(3) * scale
|
).unsqueeze(2).unsqueeze(3) * scale
|
||||||
|
else:
|
||||||
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
|
weight = weight + ratio * conved * scale
|
||||||
|
|
||||||
merged_sd[lora_module_name] = weight
|
merged_sd[lora_module_name] = weight
|
||||||
|
|
||||||
@@ -89,16 +94,25 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||||
conv2d = (len(mat.size()) == 4)
|
conv2d = (len(mat.size()) == 4)
|
||||||
|
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||||
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
out_dim, in_dim = mat.size()[0:2]
|
||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
mat = mat.squeeze()
|
if conv2d_3x3:
|
||||||
|
mat = mat.flatten(start_dim=1)
|
||||||
|
else:
|
||||||
|
mat = mat.squeeze()
|
||||||
|
|
||||||
|
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(mat)
|
U, S, Vh = torch.linalg.svd(mat)
|
||||||
|
|
||||||
U = U[:, :new_rank]
|
U = U[:, :module_new_rank]
|
||||||
S = S[:new_rank]
|
S = S[:module_new_rank]
|
||||||
U = U @ torch.diag(S)
|
U = U @ torch.diag(S)
|
||||||
|
|
||||||
Vh = Vh[:new_rank, :]
|
Vh = Vh[:module_new_rank, :]
|
||||||
|
|
||||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
@@ -107,16 +121,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
|||||||
U = U.clamp(low_val, hi_val)
|
U = U.clamp(low_val, hi_val)
|
||||||
Vh = Vh.clamp(low_val, hi_val)
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
||||||
|
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
|
|
||||||
up_weight = U
|
up_weight = U
|
||||||
down_weight = Vh
|
down_weight = Vh
|
||||||
|
|
||||||
if conv2d:
|
|
||||||
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
|
||||||
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
|
||||||
|
|
||||||
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
||||||
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
||||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
|
||||||
|
|
||||||
return merged_lora_sd
|
return merged_lora_sd
|
||||||
|
|
||||||
@@ -138,7 +152,8 @@ def merge(args):
|
|||||||
if save_dtype is None:
|
if save_dtype is None:
|
||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
|
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||||
|
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
print(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||||
@@ -158,6 +173,8 @@ if __name__ == '__main__':
|
|||||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||||
parser.add_argument("--new_rank", type=int, default=4,
|
parser.add_argument("--new_rank", type=int, default=4,
|
||||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||||
|
parser.add_argument("--new_conv_rank", type=int, default=None,
|
||||||
|
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
|||||||
|
|
||||||
### よく使われるオプションについて
|
### よく使われるオプションについて
|
||||||
|
|
||||||
以下の場合にはオプションに関するドキュメントを参照してください。
|
以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。
|
||||||
|
|
||||||
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
||||||
- clip skipを2以上を前提としたモデルを学習する
|
- clip skipを2以上を前提としたモデルを学習する
|
||||||
|
|||||||
@@ -1,118 +1,99 @@
|
|||||||
## LoRAの学習について
|
# LoRAの学習について
|
||||||
|
|
||||||
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。
|
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。
|
||||||
|
|
||||||
[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。
|
[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。
|
||||||
|
|
||||||
|
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
||||||
|
|
||||||
|
Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
|
||||||
|
|
||||||
8GB VRAMでもぎりぎり動作するようです。
|
8GB VRAMでもぎりぎり動作するようです。
|
||||||
|
|
||||||
|
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
||||||
|
|
||||||
## 学習したモデルに関する注意
|
## 学習したモデルに関する注意
|
||||||
|
|
||||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||||
|
|
||||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||||
|
|
||||||
## 学習方法
|
# 学習の手順
|
||||||
|
|
||||||
train_network.pyを用います。
|
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
||||||
|
|
||||||
DreamBoothの手法(identifier(sksなど)とclass、オプションで正則化画像を用いる)と、キャプションを用いるfine tuningの手法の両方で学習できます。
|
## データの準備
|
||||||
|
|
||||||
どちらの方法も既存のスクリプトとほぼ同じ方法で学習できます。異なる点については後述します。
|
[学習データの準備について](./train_README-ja.md) を参照してください。
|
||||||
|
|
||||||
### DreamBoothの手法を用いる場合
|
|
||||||
|
|
||||||
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
|
## 学習の実行
|
||||||
|
|
||||||
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
`train_network.py`を用います。
|
||||||
|
|
||||||
ほぼすべてのオプション(Stable Diffusionのモデル保存関係を除く)が使えますが、stop_text_encoder_trainingはサポートしていません。
|
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
||||||
|
|
||||||
### キャプションを用いる場合
|
|
||||||
|
|
||||||
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
|
|
||||||
|
|
||||||
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
|
||||||
|
|
||||||
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時(またはキャッシュ時)にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
|
|
||||||
|
|
||||||
### LoRAの学習のためのオプション
|
|
||||||
|
|
||||||
train_network.pyでは--network_moduleオプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
|
||||||
|
|
||||||
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
|
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
|
||||||
|
|
||||||
以下はコマンドラインの例です(DreamBooth手法)。
|
以下はコマンドラインの例です。
|
||||||
|
|
||||||
```
|
```
|
||||||
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||||
--pretrained_model_name_or_path=..\models\model.ckpt
|
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
||||||
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
|
--dataset_config=<データ準備で作成した.tomlファイル>
|
||||||
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
|
--output_dir=<学習したモデルの出力先フォルダ>
|
||||||
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
|
--output_name=<学習したモデル出力時のファイル名>
|
||||||
--max_train_steps=400 --optimizer_type=AdamW8bit --xformers --mixed_precision=fp16
|
--save_model_as=safetensors
|
||||||
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
|
--prior_loss_weight=1.0
|
||||||
|
--max_train_steps=400
|
||||||
|
--learning_rate=1e-4
|
||||||
|
--optimizer_type="AdamW8bit"
|
||||||
|
--xformers
|
||||||
|
--mixed_precision="fp16"
|
||||||
|
--cache_latents
|
||||||
|
--gradient_checkpointing
|
||||||
|
--save_every_n_epochs=1
|
||||||
--network_module=networks.lora
|
--network_module=networks.lora
|
||||||
```
|
```
|
||||||
|
|
||||||
(2023/2/22:オプティマイザの指定方法が変わりました。[こちら](#オプティマイザの指定について)をご覧ください。)
|
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
|
||||||
|
|
||||||
--output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。
|
|
||||||
|
|
||||||
その他、以下のオプションが指定できます。
|
その他、以下のオプションが指定できます。
|
||||||
|
|
||||||
* --network_dim
|
* `--network_dim`
|
||||||
* LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
* LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
||||||
* --network_alpha
|
* `--network_alpha`
|
||||||
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
||||||
* --network_weights
|
* `--network_weights`
|
||||||
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
||||||
* --network_train_unet_only
|
* `--network_train_unet_only`
|
||||||
* U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。
|
* U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。
|
||||||
* --network_train_text_encoder_only
|
* `--network_train_text_encoder_only`
|
||||||
* Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。
|
* Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。
|
||||||
* --unet_lr
|
* `--unet_lr`
|
||||||
* U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。
|
* U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。
|
||||||
* --text_encoder_lr
|
* `--text_encoder_lr`
|
||||||
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
|
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
|
||||||
|
* `--network_args`
|
||||||
|
* 複数の引数を指定できます。後述します。
|
||||||
|
|
||||||
--network_train_unet_onlyと--network_train_text_encoder_onlyの両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
||||||
|
|
||||||
## オプティマイザの指定について
|
## LoRA を Conv2d に拡大して適用する
|
||||||
|
|
||||||
--optimizer_type オプションでオプティマイザの種類を指定します。以下が指定できます。
|
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
||||||
|
|
||||||
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
|
||||||
- 過去のバージョンのオプション未指定時と同じ
|
|
||||||
- AdamW8bit : 引数は同上
|
|
||||||
- 過去のバージョンの--use_8bit_adam指定時と同じ
|
|
||||||
- Lion : https://github.com/lucidrains/lion-pytorch
|
|
||||||
- 過去のバージョンの--use_lion_optimizer指定時と同じ
|
|
||||||
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
|
||||||
- SGDNesterov8bit : 引数は同上
|
|
||||||
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
|
||||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
|
||||||
- 任意のオプティマイザ
|
|
||||||
|
|
||||||
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
|
```
|
||||||
|
--network_args "conv_dim=1" "conv_alpha=1"
|
||||||
|
```
|
||||||
|
|
||||||
オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。
|
以下のように alpha 省略時は1になります。
|
||||||
|
|
||||||
一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されます(SGDNesterovのmomentumなど)。コンソールの出力を確認してください。
|
```
|
||||||
|
--network_args "conv_dim=1"
|
||||||
D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。
|
```
|
||||||
|
|
||||||
AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます(省略時はデフォルトで追加されます)。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。
|
|
||||||
|
|
||||||
自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。
|
|
||||||
|
|
||||||
学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。
|
|
||||||
|
|
||||||
### 任意のオプティマイザを使う
|
|
||||||
|
|
||||||
``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。
|
|
||||||
|
|
||||||
(内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。)
|
|
||||||
|
|
||||||
## マージスクリプトについて
|
## マージスクリプトについて
|
||||||
|
|
||||||
@@ -176,6 +157,27 @@ v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``
|
|||||||
* save_precision
|
* save_precision
|
||||||
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
|
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
|
||||||
|
|
||||||
|
|
||||||
|
## 複数のrankが異なるLoRAのモデルをマージする
|
||||||
|
|
||||||
|
複数のLoRAをひとつのLoRAで近似します(完全な再現はできません)。`svd_merge_lora.py`を用います。たとえば以下のようなコマンドラインになります。
|
||||||
|
|
||||||
|
```
|
||||||
|
python networks\svd_merge_lora.py
|
||||||
|
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||||
|
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
||||||
|
--ratios 0.6 0.4 --new_rank 32 --device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
`merge_lora.py` と主なオプションは同一です。以下のオプションが追加されています。
|
||||||
|
|
||||||
|
- `--new_rank`
|
||||||
|
- 作成するLoRAのrankを指定します。
|
||||||
|
- `--new_conv_rank`
|
||||||
|
- 作成する Conv2d 3x3 LoRA の rank を指定します。省略時は `new_rank` と同じになります。
|
||||||
|
- `--device`
|
||||||
|
- `--device cuda`としてcudaを指定すると計算をGPU上で行います。処理が速くなります。
|
||||||
|
|
||||||
## 当リポジトリ内の画像生成スクリプトで生成する
|
## 当リポジトリ内の画像生成スクリプトで生成する
|
||||||
|
|
||||||
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
||||||
@@ -209,12 +211,14 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA
|
|||||||
|
|
||||||
### その他のオプション
|
### その他のオプション
|
||||||
|
|
||||||
- --v2
|
- `--v2`
|
||||||
- v2.xのStable Diffusionモデルを使う場合に指定してください。
|
- v2.xのStable Diffusionモデルを使う場合に指定してください。
|
||||||
- --device
|
- `--device`
|
||||||
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
|
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
|
||||||
- --save_precision
|
- `--save_precision`
|
||||||
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
||||||
|
- `--conv_dim`
|
||||||
|
- 指定するとLoRAの適用範囲を Conv2d 3x3 へ拡大します。Conv2d 3x3 の rank を指定します。
|
||||||
|
|
||||||
## 画像リサイズスクリプト
|
## 画像リサイズスクリプト
|
||||||
|
|
||||||
@@ -252,7 +256,7 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256
|
|||||||
|
|
||||||
### cloneofsimo氏のリポジトリとの違い
|
### cloneofsimo氏のリポジトリとの違い
|
||||||
|
|
||||||
12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
||||||
|
|
||||||
またモジュール入れ替え機構は全く異なります。
|
またモジュール入れ替え機構は全く異なります。
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user