From 2e9f7b5f9135dd9a970bac863907f63adb53f943 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Apr 2023 23:10:39 +0900 Subject: [PATCH 01/12] cache latents to disk in dreambooth method --- fine_tune.py | 4 ++- library/train_util.py | 62 ++++++++++++++++++++++++++++------ train_db.py | 4 ++- train_network.py | 4 ++- train_textual_inversion.py | 4 ++- train_textual_inversion_XTI.py | 4 ++- 6 files changed, 67 insertions(+), 15 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 2157de98..47454670 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -142,12 +142,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + # 学習を準備する:モデルを適切な状態にする training_models = [] if args.gradient_checkpointing: diff --git a/library/train_util.py b/library/train_util.py index 56eef81f..6b398707 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -722,7 +722,7 @@ class BaseDataset(torch.utils.data.Dataset): def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) - def cache_latents(self, vae, vae_batch_size=1): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # ちょっと速くした print("caching latents.") @@ -740,11 +740,38 @@ class BaseDataset(torch.utils.data.Dataset): if info.latents_npz is not None: info.latents = self.load_latents_from_npz(info, False) info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + + # might be None, but that's ok because check is done in dataset + info.latents_flipped = self.load_latents_from_npz(info, True) if info.latents_flipped is not None: info.latents_flipped = torch.FloatTensor(info.latents_flipped) continue + # check disk cache exists and size of latents + if cache_to_disk: + # TODO: refactor to unify with FineTuningDataset + info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz" + if not is_main_process: + continue + + cache_available = False + expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 + if os.path.exists(info.latents_npz): + cached_latents = np.load(info.latents_npz) + if cached_latents["latents"].shape[1:3] == expected_latents_size: + cache_available = True + + if subset.flip_aug: + cache_available = False + if os.path.exists(info.latents_npz_flipped): + cached_latents_flipped = np.load(info.latents_npz_flipped) + if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size: + cache_available = True + + if cache_available: + continue + # if last member of batch has different resolution, flush the batch if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: batches.append(batch) @@ -760,6 +787,9 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) > 0: batches.append(batch) + if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only + return + # iterate batches for batch in tqdm(batches, smoothing=1, total=len(batches)): images = [] @@ -773,14 +803,21 @@ class BaseDataset(torch.utils.data.Dataset): img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): - info.latents = latent + if cache_to_disk: + np.savez(info.latents_npz, latent.float().numpy()) + else: + info.latents = latent if subset.flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") for info, latent in zip(batch, latents): - info.latents_flipped = latent + if cache_to_disk: + np.savez(info.latents_npz_flipped, latent.float().numpy()) + else: + info.latents_flipped = latent def get_image_size(self, image_path): image = Image.open(image_path) @@ -873,10 +910,10 @@ class BaseDataset(torch.utils.data.Dataset): loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) # image/latentsを処理する - if image_info.latents is not None: + if image_info.latents is not None: # cache_latents=Trueの場合 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped image = None - elif image_info.latents_npz is not None: + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) latents = torch.FloatTensor(latents) image = None @@ -1340,10 +1377,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): for i, dataset in enumerate(self.datasets): print(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -2144,9 +2181,14 @@ def add_dataset_arguments( parser.add_argument( "--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)", + help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", ) parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") + parser.add_argument( + "--cache_latents_to_disk", + action="store_true", + help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", + ) parser.add_argument( "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" ) @@ -3203,4 +3245,4 @@ class collater_class: # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) - return examples[0] \ No newline at end of file + return examples[0] diff --git a/train_db.py b/train_db.py index e72dc889..eddf8f68 100644 --- a/train_db.py +++ b/train_db.py @@ -117,12 +117,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 diff --git a/train_network.py b/train_network.py index ef630969..fb3d6130 100644 --- a/train_network.py +++ b/train_network.py @@ -172,12 +172,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + # prepare network import sys diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 98639345..88ddebdd 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -233,12 +233,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index db46ad1b..d302491e 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -267,12 +267,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() From 893c2fc08a84f6d4e77a92d254594eb49f883b27 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Apr 2023 23:14:09 +0900 Subject: [PATCH 02/12] add DyLoRA (experimental) --- networks/dylora.py | 448 +++++++++++++++++++++++++++ networks/extract_lora_from_dylora.py | 261 ++++++++++++++++ train_network.py | 33 +- 3 files changed, 731 insertions(+), 11 deletions(-) create mode 100644 networks/dylora.py create mode 100644 networks/extract_lora_from_dylora.py diff --git a/networks/dylora.py b/networks/dylora.py new file mode 100644 index 00000000..e588813e --- /dev/null +++ b/networks/dylora.py @@ -0,0 +1,448 @@ +# some codes are copied from: +# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/ + +# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# Changes made to the original code: +# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +import math +import os +import random +from typing import List, Tuple, Union +import torch +from torch import nn + + +class DyLoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + # NOTE: support dropout in future + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1): + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + self.unit = unit + assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit" + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3) + + if self.is_conv2d and self.is_conv2d_3x3: + kernel_size = org_module.kernel_size + self.stride = org_module.stride + self.padding = org_module.padding + self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size))) + self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1))) + else: + self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim))) + self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim))) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_B) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + result = self.org_forward(x) + + # specify the dynamic rank + trainable_rank = random.randint(0, self.lora_dim - 1) + trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit + + # 一部のパラメータを固定して、残りのパラメータを学習する + + # make lora_A + if trainable_rank > 0: + lora_A_nt1 = [self.lora_A[:trainable_rank].detach()] + else: + lora_A_nt1 = [] + + lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit] + + if trainable_rank < self.lora_dim - self.unit: + lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()] + else: + lora_A_nt2 = [] + + lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0) + + # make lora_B + if trainable_rank > 0: + lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()] + else: + lora_B_nt1 = [] + + lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit] + + if trainable_rank < self.lora_dim - self.unit: + lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()] + else: + lora_B_nt2 = [] + + lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1) + + # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size()) + + # calculate with lora_A and lora_B + if self.is_conv2d_3x3: + ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding) + ab = torch.nn.functional.conv2d(ab, lora_B) + else: + ab = x + if self.is_conv2d: + ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) + + ab = torch.nn.functional.linear(ab, lora_A) + ab = torch.nn.functional.linear(ab, lora_B) + + if self.is_conv2d: + ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) + + # 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな) + result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit)) + + # NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも + return result + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # state dictを通常のLoRAと同じにする + state_dict = super().state_dict(destination, prefix, keep_vars) + + lora_A_weight = state_dict.pop(self.lora_name + ".lora_A") + if self.is_conv2d and not self.is_conv2d_3x3: + lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1) + state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight + + lora_B_weight = state_dict.pop(self.lora_name + ".lora_B") + if self.is_conv2d and not self.is_conv2d_3x3: + lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1) + state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight + + return state_dict + + def load_state_dict(self, state_dict, strict=True): + # 通常のLoRAと同じstate dictを読み込めるようにする + state_dict = state_dict.copy() + + lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight") + if self.is_conv2d and not self.is_conv2d_3x3: + lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1) + state_dict[self.lora_name + ".lora_A"] = lora_A_weight + + lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight") + if self.is_conv2d and not self.is_conv2d_3x3: + lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1) + state_dict[self.lora_name + ".lora_B"] = lora_B_weight + + super().load_state_dict(state_dict, strict=strict) + + +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + unit = kwargs.get("unit", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + assert conv_dim == network_dim, "conv_dim must be same as network_dim" + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + if unit is not None: + unit = int(unit) + else: + unit = 1 + + network = DyLoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + apply_to_conv=conv_dim is not None, + unit=unit, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha = modules_dim[key] + + module_class = DyLoRAModule + + network = DyLoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) + return network, weights_sd + + +class DyLoRANetwork(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + alpha=1, + apply_to_conv=False, + modules_dim=None, + modules_alpha=None, + unit=1, + module_class=DyLoRAModule, + varbose=False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.apply_to_conv = apply_to_conv + + if modules_dim is not None: + print(f"create LoRA network from weights") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + if self.apply_to_conv: + print(f"apply LoRA to Conv2d with kernel size (3,3).") + + # create module instances + def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: + prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + if modules_dim is not None: + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + if is_linear or is_conv2d_1x1 or apply_to_conv: + dim = self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + continue + + # dropout and fan_in_fan_out is default + lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) + loras.append(lora) + return loras + + self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.apply_to_conv: + target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(True, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + """ + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + print(f"weights are merged") + """ + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + pass + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + pass diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py new file mode 100644 index 00000000..1c71f734 --- /dev/null +++ b/networks/extract_lora_from_dylora.py @@ -0,0 +1,261 @@ +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file, safe_open +from tqdm import tqdm +from library import train_util, model_util +import numpy as np + + +def load_state_dict(file_name): + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None + + return sd, metadata + + +def save_to_file(file_name, model, metadata): + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) + else: + torch.save(model, file_name) + + +# Indexing functions + + +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + s_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / s_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index + + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +# Calculate new rank + + +def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method == "sv_ratio": + # Calculate new dim and alpha based off ratio + new_rank = index_sv_ratio(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + + elif dynamic_method == "sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + + elif dynamic_method == "sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + else: + new_rank = rank + new_alpha = float(scale * new_rank) + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale * new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale * new_rank) + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro / s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank) / s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0] / S[new_rank - 1] + + return param_dict + + +def split_lora_model(lora_sd, unit): + max_rank = 0 + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if "lora_down" in key: + rank = value.size()[0] + if rank > max_rank: + max_rank = rank + print(f"Max rank: {max_rank}") + + rank = unit + splitted_models = [] + while rank < max_rank: + print(f"Splitting rank {rank}") + new_sd = {} + for key, value in lora_sd.items(): + if "lora_down" in key: + new_sd[key] = value[:rank].contiguous() + elif "lora_up" in key: + new_sd[key] = value[:, :rank].contiguous() + else: + new_sd[key] = value # alpha and other parameters + + splitted_models.append((new_sd, rank)) + rank += unit + + return max_rank, splitted_models + + +def split(args): + print("loading Model...") + lora_sd, metadata = load_state_dict(args.model) + + print("Splitting Model...") + original_rank, splitted_models = split_lora_model(lora_sd, args.unit) + + comment = metadata.get("ss_training_comment", "") + for state_dict, new_rank in splitted_models: + # update metadata + if metadata is None: + new_metadata = {} + else: + new_metadata = metadata.copy() + + new_metadata["ss_training_comment"] = f"split from DyLoRA from {original_rank} to {new_rank}; {comment}" + new_metadata["ss_network_dim"] = str(new_rank) + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + filename, ext = os.path.splitext(args.save_to) + model_file_name = filename + f"-{new_rank:04d}{ext}" + + print(f"saving model to: {model_file_name}") + save_to_file(model_file_name, state_dict, new_metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ") + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + split(args) diff --git a/train_network.py b/train_network.py index fb3d6130..658138b7 100644 --- a/train_network.py +++ b/train_network.py @@ -197,7 +197,7 @@ def train(args): network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return - + if hasattr(network, "prepare_network"): network.prepare_network(args) @@ -221,7 +221,9 @@ def train(args): try: trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) except TypeError: - print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)") + print( + "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) @@ -541,6 +543,12 @@ def train(args): loss_list = [] loss_total = 0.0 del train_dataset_group + + # if hasattr(network, "on_step_start"): + # on_step_start = network.on_step_start + # else: + # on_step_start = lambda *args, **kwargs: None + for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") @@ -553,6 +561,8 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): + # on_step_start(text_encoder, unet) + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -565,16 +575,17 @@ def train(args): with torch.set_grad_enabled(train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings(tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -759,4 +770,4 @@ if __name__ == "__main__": args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) \ No newline at end of file + train(args) From e09966024cd84051fabe0c2ae0566f05ebb441d0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Apr 2023 23:16:47 +0900 Subject: [PATCH 03/12] delete unnecessary lines --- networks/extract_lora_from_dylora.py | 147 +-------------------------- 1 file changed, 1 insertion(+), 146 deletions(-) diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 1c71f734..0037636f 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -30,151 +30,6 @@ def save_to_file(file_name, model, metadata): torch.save(model, file_name) -# Indexing functions - - -def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums = torch.cumsum(S, dim=0) / original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S) - 1)) - - return index - - -def index_sv_fro(S, target): - S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0) / s_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S) - 1)) - - return index - - -def index_sv_ratio(S, target): - max_sv = S[0] - min_sv = max_sv / target - index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S) - 1)) - - return index - - -# Modified from Kohaku-blueleaf's extract/merge functions -def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): - out_size, in_size, kernel_size, _ = weight.size() - U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) - - param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) - lora_rank = param_dict["new_rank"] - - U = U[:, :lora_rank] - S = S[:lora_rank] - U = U @ torch.diag(S) - Vh = Vh[:lora_rank, :] - - param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() - param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() - del U, S, Vh, weight - return param_dict - - -def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): - out_size, in_size = weight.size() - - U, S, Vh = torch.linalg.svd(weight.to(device)) - - param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) - lora_rank = param_dict["new_rank"] - - U = U[:, :lora_rank] - S = S[:lora_rank] - U = U @ torch.diag(S) - Vh = Vh[:lora_rank, :] - - param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() - param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() - del U, S, Vh, weight - return param_dict - - -def merge_conv(lora_down, lora_up, device): - in_rank, in_size, kernel_size, k_ = lora_down.shape - out_size, out_rank, _, _ = lora_up.shape - assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" - - lora_down = lora_down.to(device) - lora_up = lora_up.to(device) - - merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) - weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) - del lora_up, lora_down - return weight - - -def merge_linear(lora_down, lora_up, device): - in_rank, in_size = lora_down.shape - out_size, out_rank = lora_up.shape - assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" - - lora_down = lora_down.to(device) - lora_up = lora_up.to(device) - - weight = lora_up @ lora_down - del lora_up, lora_down - return weight - - -# Calculate new rank - - -def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): - param_dict = {} - - if dynamic_method == "sv_ratio": - # Calculate new dim and alpha based off ratio - new_rank = index_sv_ratio(S, dynamic_param) + 1 - new_alpha = float(scale * new_rank) - - elif dynamic_method == "sv_cumulative": - # Calculate new dim and alpha based off cumulative sum - new_rank = index_sv_cumulative(S, dynamic_param) + 1 - new_alpha = float(scale * new_rank) - - elif dynamic_method == "sv_fro": - # Calculate new dim and alpha based off sqrt sum of squares - new_rank = index_sv_fro(S, dynamic_param) + 1 - new_alpha = float(scale * new_rank) - else: - new_rank = rank - new_alpha = float(scale * new_rank) - - if S[0] <= MIN_SV: # Zero matrix, set dim to 1 - new_rank = 1 - new_alpha = float(scale * new_rank) - elif new_rank > rank: # cap max rank at rank - new_rank = rank - new_alpha = float(scale * new_rank) - - # Calculate resize info - s_sum = torch.sum(torch.abs(S)) - s_rank = torch.sum(torch.abs(S[:new_rank])) - - S_squared = S.pow(2) - s_fro = torch.sqrt(torch.sum(S_squared)) - s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) - fro_percent = float(s_red_fro / s_fro) - - param_dict["new_rank"] = new_rank - param_dict["new_alpha"] = new_alpha - param_dict["sum_retained"] = (s_rank) / s_sum - param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0] / S[new_rank - 1] - - return param_dict - - def split_lora_model(lora_sd, unit): max_rank = 0 @@ -220,7 +75,7 @@ def split(args): else: new_metadata = metadata.copy() - new_metadata["ss_training_comment"] = f"split from DyLoRA from {original_rank} to {new_rank}; {comment}" + new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" new_metadata["ss_network_dim"] = str(new_rank) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) From 68e0767404d88911079a052a2c40762f98ae58e1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Apr 2023 23:40:10 +0900 Subject: [PATCH 04/12] add comment about scaling --- networks/extract_lora_from_dylora.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0037636f..9ae4056e 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -3,6 +3,7 @@ # Thanks to cloneofsimo import argparse +import math import os import torch from safetensors.torch import load_file, save_file, safe_open @@ -43,6 +44,7 @@ def split_lora_model(lora_sd, unit): rank = unit splitted_models = [] + new_alpha = None while rank < max_rank: print(f"Splitting rank {rank}") new_sd = {} @@ -52,9 +54,15 @@ def split_lora_model(lora_sd, unit): elif "lora_up" in key: new_sd[key] = value[:, :rank].contiguous() else: - new_sd[key] = value # alpha and other parameters + # なぜかscaleするとおかしくなる…… + # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] + # scale = math.sqrt(this_rank / rank) # rank is > unit + # print(key, value.size(), this_rank, rank, value, scale) + # new_alpha = value * scale # always same + # new_sd[key] = new_alpha + new_sd[key] = value - splitted_models.append((new_sd, rank)) + splitted_models.append((new_sd, rank, new_alpha)) rank += unit return max_rank, splitted_models @@ -68,7 +76,7 @@ def split(args): original_rank, splitted_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") - for state_dict, new_rank in splitted_models: + for state_dict, new_rank, new_alpha in splitted_models: # update metadata if metadata is None: new_metadata = {} @@ -77,6 +85,7 @@ def split(args): new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" new_metadata["ss_network_dim"] = str(new_rank) + # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash From a097c4257953013a79f76c0c4d362f66c438adb3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 21:07:22 +0900 Subject: [PATCH 05/12] update docs --- train_README-ja.md | 15 +++- train_network_README-ja.md | 165 ++++++++++++++++++++++++++++++++++--- train_ti_README-ja.md | 2 +- 3 files changed, 167 insertions(+), 15 deletions(-) diff --git a/train_README-ja.md b/train_README-ja.md index 032e006b..fd66458a 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -2,7 +2,7 @@ __ドキュメント更新中のため記述に誤りがあるかもしれませ # 学習について、共通編 -当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。 +当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversion([XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)を含む)の学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。 # 概要 @@ -535,7 +535,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - `--debug_dataset` - このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。 + このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。 ※Linux環境(Colabを含む)では画像は表示されません。 @@ -545,6 +545,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。 +- `--cache_latents` + + 使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。 + +- `--min_snr_gamma` + + Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。 ## オプティマイザ関係 @@ -570,7 +577,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 学習率のスケジューラ関連の指定です。 - lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。 + lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, 任意のスケジューラから選べます。デフォルトはconstantです。 lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。 @@ -578,6 +585,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 詳細については各自お調べください。 + 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。 + ### オプティマイザの指定について オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。 diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 152ff9af..cb7cd726 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -12,11 +12,31 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 +# 学習できるLoRAの種類 + +以下の二種類をサポートします。以下は当リポジトリ内の独自の名称です。 + +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) + + Linear およびカーネルサイズ 1x1 の Conv2d に適用されるLoRA + +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) + + 1.に加え、カーネルサイズ 3x3 の Conv2d に適用されるLoRA + +LoRA-LierLaに比べ、LoRA-C3Liarは適用される層が増える分、高い精度が期待できるかもしれません。 + +また学習時は __DyLoRA__ を使用することもできます(後述します)。 + ## 学習したモデルに関する注意 -cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 +LoRA-LierLa は、AUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。 -WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 +LoRA-C3Liarを使いWeb UIで生成するには、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 + +いずれも学習したLoRAのモデルを、Stable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージすることもできます。 + +cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 # 学習の手順 @@ -31,9 +51,9 @@ WebUI等で画像生成する場合には、学習したLoRAのモデルを学 `train_network.py`を用います。 -`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。 +`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのは`network.lora`となりますので、それを指定してください。 -なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。 +なお学習率は通常のDreamBoothやfine tuningよりも高めの、`1e-4`~`1e-3`程度を指定するとよいようです。 以下はコマンドラインの例です。 @@ -56,6 +76,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py --network_module=networks.lora ``` +このコマンドラインでは LoRA-LierLa が学習されます。 + `--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。 その他、以下のオプションが指定できます。 @@ -83,22 +105,143 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 -## LoRA を Conv2d に拡大して適用する +# その他の学習方法 -通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。 +## LoRA-C3Lier を学習する `--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。 ``` ---network_args "conv_dim=1" "conv_alpha=1" +--network_args "conv_dim=4" "conv_alpha=1" ``` 以下のように alpha 省略時は1になります。 ``` ---network_args "conv_dim=1" +--network_args "conv_dim=4" ``` +## DyLoRA + +DyLoRAはこちらの論文で提案されたものです。[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558) 公式実装は[こちら](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)です。 + +論文によると、LoRAのrankは必ずしも高いほうが良いわけではなく、対象のモデル、データセット、タスクなどにより適切なrankを探す必要があるようです。DyLoRAを使うと、指定したdim(rank)以下のさまざまなrankで同時にLoRAを学習します。これにより最適なrankをそれぞれ学習して探す手間を省くことができます。 + +当リポジトリの実装は公式実装をベースに独自の拡張を加えています(そのため不具合などあるかもしれません)。 + +### 当リポジトリのDyLoRAの特徴 + +学習後のDyLoRAのモデルファイルはLoRAと互換性があります。また、モデルファイルから指定したdim(rank)以下の複数のdimのLoRAを抽出できます。 + +DyLoRA-LierLa、DyLoRA-C3Lierのどちらも学習できます。 + +### DyLoRAで学習する + +`--network_module=networks.dylora` のように、DyLoRAに対応する`network.dylora`を指定してください。 + +また `--network_args` に、たとえば`--network_args "unit=4"`のように`unit`を指定します。`unit`はrankを分割する単位です。たとえば`--network_dim=16 --network_args "unit=4"` のように指定します。`unit`は`network_dim`を割り切れる値(`network_dim`は`unit`の倍数)としてください。 + +`unit`を指定しない場合は、`unit=1`として扱われます。 + +記述例は以下です。 + +``` +--network_module=networks.dylora --network_dim=16 --network_args "unit=4" + +--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4" +``` + +DyLoRA-C3Lierの場合は、`--network_args` に`"conv_dim=4"`のように`conv_dim`を指定します。通常のLoRAと異なり、`conv_dim`は`network_dim`と同じ値である必要があります。記述例は以下です。 + +``` +--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4" + +--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8" +``` + +たとえばdim=16、unit=4(後述)で学習すると、4、8、12、16の4つのrankのLoRAを学習、抽出できます。抽出した各モデルで画像を生成し、比較することで、最適なrankのLoRAを選択できます。 + +その他のオプションは通常のLoRAと同じです。 + +※ `unit`は当リポジトリの独自拡張で、DyLoRAでは同dim(rank)の通常LoRAに比べると学習時間が長くなることが予想されるため、分割単位を大きくしたものです。 + +### DyLoRAのモデルからLoRAモデルを抽出する + +`networks`フォルダ内の `extract_lora_from_dylora.py`を使用します。指定した`unit`単位で、DyLoRAのモデルからLoRAのモデルを抽出します。 + +コマンドラインはたとえば以下のようになります。 + +```powershell +python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4 +``` + +`--model` にはDyLoRAのモデルファイルを指定します。`--save_to` には抽出したモデルを保存するファイル名を指定します(rankの数値がファイル名に付加されます)。`--unit` にはDyLoRAの学習時の`unit`を指定します。 + +## 階層別学習率 + +詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。 + +フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 + +`--network_args` で以下の引数を指定してください。 + +- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 + - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 + - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 +- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 +- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 +- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 +- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 + +### 階層別学習率コマンドライン指定例: + +```powershell +--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5" + +--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5" +``` + +### 階層別学習率tomlファイル指定例: + +```toml +network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",] + +network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ] +``` + +## 階層別dim (rank) + +フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 + +`--network_args` で以下の引数を指定してください。 + +- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 +- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。 +- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。 +- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。 + +### 階層別dim (rank)コマンドライン指定例: + +```powershell +--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" + +--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2" + +--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2" +``` + +### 階層別dim (rank)tomlファイル指定例: + +```toml +network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",] + +network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",] +``` + +# その他のスクリプト + +マージ等LoRAに関連するスクリプト群です。 + ## マージスクリプトについて merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。 @@ -323,14 +466,14 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256 - 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。 -## 追加情報 +# 追加情報 -### cloneofsimo氏のリポジトリとの違い +## cloneofsimo氏のリポジトリとの違い 2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。 またモジュール入れ替え機構は全く異なります。 -### 将来拡張について +## 将来拡張について LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。 diff --git a/train_ti_README-ja.md b/train_ti_README-ja.md index 90873696..86f45a5d 100644 --- a/train_ti_README-ja.md +++ b/train_ti_README-ja.md @@ -4,7 +4,7 @@ 実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 -学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。 +学習したモデルはWeb UIでもそのまま使えます。 # 学習の手順 From 9ff32fd4c01668749058e1b7f2f2a87b3a5e6ca0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 21:14:20 +0900 Subject: [PATCH 06/12] fix parameters are not freezed --- networks/dylora.py | 104 ++++++++++++++++++++++----------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index e588813e..c6c782fc 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -50,15 +50,17 @@ class DyLoRAModule(torch.nn.Module): kernel_size = org_module.kernel_size self.stride = org_module.stride self.padding = org_module.padding - self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size))) - self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1))) + self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)]) + self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)]) else: - self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim))) - self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim))) + self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)]) + self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)]) # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_B) + for lora in self.lora_A: + torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5)) + for lora in self.lora_B: + torch.nn.init.zeros_(lora) self.multiplier = multiplier self.org_module = org_module # remove in applying @@ -76,38 +78,18 @@ class DyLoRAModule(torch.nn.Module): trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit # 一部のパラメータを固定して、残りのパラメータを学習する + for i in range(0, trainable_rank): + self.lora_A[i].requires_grad = False + self.lora_B[i].requires_grad = False + for i in range(trainable_rank, trainable_rank + self.unit): + self.lora_A[i].requires_grad = True + self.lora_B[i].requires_grad = True + for i in range(trainable_rank + self.unit, self.lora_dim): + self.lora_A[i].requires_grad = False + self.lora_B[i].requires_grad = False - # make lora_A - if trainable_rank > 0: - lora_A_nt1 = [self.lora_A[:trainable_rank].detach()] - else: - lora_A_nt1 = [] - - lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit] - - if trainable_rank < self.lora_dim - self.unit: - lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()] - else: - lora_A_nt2 = [] - - lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0) - - # make lora_B - if trainable_rank > 0: - lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()] - else: - lora_B_nt1 = [] - - lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit] - - if trainable_rank < self.lora_dim - self.unit: - lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()] - else: - lora_B_nt2 = [] - - lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1) - - # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size()) + lora_A = torch.cat(tuple(self.lora_A), dim=0) + lora_B = torch.cat(tuple(self.lora_B), dim=1) # calculate with lora_A and lora_B if self.is_conv2d_3x3: @@ -116,13 +98,13 @@ class DyLoRAModule(torch.nn.Module): else: ab = x if self.is_conv2d: - ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) + ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C) ab = torch.nn.functional.linear(ab, lora_A) ab = torch.nn.functional.linear(ab, lora_B) if self.is_conv2d: - ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) + ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W) # 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな) result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit)) @@ -131,34 +113,52 @@ class DyLoRAModule(torch.nn.Module): return result def state_dict(self, destination=None, prefix="", keep_vars=False): - # state dictを通常のLoRAと同じにする - state_dict = super().state_dict(destination, prefix, keep_vars) + # state dictを通常のLoRAと同じにする: + # nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える + sd = super().state_dict(destination, prefix, keep_vars) - lora_A_weight = state_dict.pop(self.lora_name + ".lora_A") + lora_A_weight = torch.cat(tuple(self.lora_A), dim=0) if self.is_conv2d and not self.is_conv2d_3x3: lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1) - state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight - lora_B_weight = state_dict.pop(self.lora_name + ".lora_B") + lora_B_weight = torch.cat(tuple(self.lora_B), dim=1) if self.is_conv2d and not self.is_conv2d_3x3: lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1) - state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight - return state_dict + sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach() + sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach() + + i = 0 + while True: + key_a = f"{self.lora_name}.lora_A.{i}" + key_b = f"{self.lora_name}.lora_B.{i}" + if key_a in sd: + sd.pop(key_a) + sd.pop(key_b) + else: + break + i += 1 + return sd def load_state_dict(self, state_dict, strict=True): # 通常のLoRAと同じstate dictを読み込めるようにする state_dict = state_dict.copy() - lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight") + lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None) + lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None) + + if lora_A_weight is None or lora_B_weight is None: + if strict: + raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found") + else: + return + if self.is_conv2d and not self.is_conv2d_3x3: lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1) - state_dict[self.lora_name + ".lora_A"] = lora_A_weight - - lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight") - if self.is_conv2d and not self.is_conv2d_3x3: lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1) - state_dict[self.lora_name + ".lora_B"] = lora_B_weight + + state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))}) + state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))}) super().load_state_dict(state_dict, strict=strict) From a8632b7329b8c6f558f0c707b21d5ead40cb33cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 21:14:39 +0900 Subject: [PATCH 07/12] fix latents disk cache --- library/train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6b398707..013cc81c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -758,15 +758,15 @@ class BaseDataset(torch.utils.data.Dataset): cache_available = False expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 if os.path.exists(info.latents_npz): - cached_latents = np.load(info.latents_npz) - if cached_latents["latents"].shape[1:3] == expected_latents_size: + cached_latents = np.load(info.latents_npz)["arr_0"] + if cached_latents.shape[1:3] == expected_latents_size: cache_available = True if subset.flip_aug: cache_available = False if os.path.exists(info.latents_npz_flipped): - cached_latents_flipped = np.load(info.latents_npz_flipped) - if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size: + cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"] + if cached_latents_flipped.shape[1:3] == expected_latents_size: cache_available = True if cache_available: From 2de9a51591f8474bb7c5ccb534050949ccf6aa9c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 21:18:18 +0900 Subject: [PATCH 08/12] fix typos --- networks/extract_lora_from_dylora.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 9ae4056e..0abee983 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -43,7 +43,7 @@ def split_lora_model(lora_sd, unit): print(f"Max rank: {max_rank}") rank = unit - splitted_models = [] + split_models = [] new_alpha = None while rank < max_rank: print(f"Splitting rank {rank}") @@ -62,10 +62,10 @@ def split_lora_model(lora_sd, unit): # new_sd[key] = new_alpha new_sd[key] = value - splitted_models.append((new_sd, rank, new_alpha)) + split_models.append((new_sd, rank, new_alpha)) rank += unit - return max_rank, splitted_models + return max_rank, split_models def split(args): @@ -73,10 +73,10 @@ def split(args): lora_sd, metadata = load_state_dict(args.model) print("Splitting Model...") - original_rank, splitted_models = split_lora_model(lora_sd, args.unit) + original_rank, split_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") - for state_dict, new_rank, new_alpha in splitted_models: + for state_dict, new_rank, new_alpha in split_models: # update metadata if metadata is None: new_metadata = {} From 9fc27403b2457a56a914773951727ee1b6227847 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 21:40:34 +0900 Subject: [PATCH 09/12] support disk cache: same as #164, might fix #407 --- fine_tune.py | 5 +++-- train_textual_inversion.py | 21 ++++++++++++--------- train_textual_inversion_XTI.py | 3 ++- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 47454670..61f6c191 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -275,7 +275,7 @@ def train(args): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() @@ -313,7 +313,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 88ddebdd..611adff7 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -185,10 +185,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - current_epoch = Value('i',0) - current_step = Value('i',0) + current_epoch = Value("i", 0) + current_step = Value("i", 0) ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: @@ -264,7 +264,9 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 @@ -339,7 +341,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch+1 + current_epoch.value = epoch + 1 text_encoder.train() @@ -359,7 +361,7 @@ def train(args): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - # weight_dtype) use float instead of fp16/bf16 because text encoder is float + # use float instead of fp16/bf16 because text encoder is float encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) # Sample noise that we'll add to the latents @@ -377,7 +379,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training @@ -387,9 +390,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) - + if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index d302491e..54c4b4e5 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -418,7 +418,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training From 423e6c229c3b0a4747b622f0b5f3c6ac09a32c6c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 22:12:13 +0900 Subject: [PATCH 10/12] support metadata json+.npz caching (no prepare) --- library/train_util.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 013cc81c..b249e61d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1200,19 +1200,27 @@ class FineTuningDataset(BaseDataset): tags_list = [] for image_key, img_md in metadata.items(): # path情報を作る + abs_path = None + + # まず画像を優先して探す if os.path.exists(image_key): abs_path = image_key - elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, image_key) + if len(paths) > 0: + abs_path = paths[0] + + # なければnpzを探す + if abs_path is None: + if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + + assert abs_path is not None, f"no image / 画像がありません: {image_key}" caption = img_md.get("caption") tags = img_md.get("tags") From 849bc24d205a35fbe1b2a4063edd7172533c1c01 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 22:24:47 +0900 Subject: [PATCH 11/12] update readme --- README.md | 83 +++++++++++++++++++++++++++---------------------------- 1 file changed, 40 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 6f117840..a22b65b6 100644 --- a/README.md +++ b/README.md @@ -127,56 +127,53 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History -### 8 Apr. 2021, 2021/4/8: +### Naming of LoRA -- Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution! - - Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336) - - Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI. - - This option is also applicable to token strings of the DreamBooth method. - - The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible. - - If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses. +The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository. -- 重みづけキャプションによる学習に対応しました。 AI-Casanova 氏の素晴らしい貢献に感謝します。 - - 詳細はこちらをご確認ください。[PR #336](https://github.com/kohya-ss/sd-scripts/pull/336) - - `--weighted_captions` オプションを指定してください。Textual InversionおよびXTIを除く学習スクリプトで使用可能です。 - - キャプションだけでなく DreamBooth 手法の token string でも有効です。 - - 重みづけキャプションの記法はWeb UIとほぼ同じで、`(abc)`や`[abc]`、`(abc:1.23)`などが使用できます。入れ子も可能です。 - - 括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。 - -### 6 Apr. 2023, 2023/4/6: -- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers) -- Added a feature to upload model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) - - When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time as saving the model. - - Please note that the access token is handled with caution. Please refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens). - - For example, specify other arguments as follows. - - `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere` - - If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted or `private` (or anything other than `public`) is specified, it will be private. - - If you specify `--save_state` and `--save_state_to_huggingface`, the state will also be uploaded. - - If you specify `--resume` and `--resume_from_huggingface`, the state will be downloaded from HuggingFace and resumed. - - In this case, the `--resume` option is `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model` - - If you specify `--async_upload`, the upload will be done asynchronously. -- Added the documentation for applying LoRA to generate with the standard pipeline of Diffusers. [training LoRA](./train_network_README-ja.md#diffusersのpipelineで生成する) (Japanese only) -- Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`. - - If you use ` AND ` to separate the prompts, each sub-prompt is sequentially applied to LoRA. `--mask_path` is treated as a mask image. The number of sub-prompts and the number of LoRA must match. + LoRA for Linear layers and Conv2d layers with 1x1 kernel +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers) -- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。 + In addition to 1., LoRA for Conv2d layers with 3x3 kernel + +LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI. -- モデルおよびstateをHuggingFaceにアップロードする機能を各スクリプトに追加しました。 [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) ddPn08 氏の貢献に感謝します。 - - `--huggingface_repo_id`が指定されているとモデル保存時に同時にHuggingFaceにアップロードします。 - - アクセストークンの取り扱いに注意してください。[HuggingFaceのドキュメント](https://huggingface.co/docs/hub/security-tokens)を参照してください。 - - 他の引数をたとえば以下のように指定してください。 - - `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere` - - `--huggingface_repo_visibility`に`public`を指定するとリポジトリが公開されます。省略時または`private`(など`public`以外)を指定すると非公開になります。 - - `--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。 - - `--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します。 - - その時の `--resume`オプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model` - - `--async_upload`オプションを指定するとアップロードを非同期で行います。 -- [LoRAの文書](./train_network_README-ja.md#diffusersのpipelineで生成する)に、LoRAを適用してDiffusersの標準的なパイプラインで生成する方法を追記しました。 -- `gen_img_diffusers.py` で Attention Couple および領域別LoRAに対応しました。 - - プロンプトを` AND `で区切ると各サブプロンプトが順にLoRAに適用されます。`--mask_path` がマスク画像として扱われます。サブプロンプトの数とLoRAの数は一致している必要があります。 +To use LoRA-C3Liar with Web UI, please use our extension. +### LoRAの名称について + +`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。 + +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) + + Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA + +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) + + 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA + +LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。 + +LoRA-C3Liarを使いWeb UIで生成するには拡張を使用してください。 + +### 13 Apr. 2023, 2023/4/13: + +- Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README.md) for details (currently only in Japanese). +- Added support for caching latents to disk in each training script. Please specify __both__ `--cache_latents` and `--cache_latents_to_disk` options. + - The files are saved in the same folder as the images with the extension `.npz`. If you specify the `--flip_aug` option, the files with `_flip.npz` will also be saved. + - Multi-GPU training has not been tested. + - This feature is not tested with all combinations of datasets and training scripts, so there may be bugs. +- Added workaround for an error that occurs when training with `fp16` or `bf16` in `fine_tune.py`. + +- `train_network.py`でDyLoRAをサポートしました。詳細は[こちら](./train_network_README-ja.md)をご覧ください。 +- 各学習スクリプトでlatentのディスクへのキャッシュをサポートしました。`--cache_latents`オプションに __加えて__、`--cache_latents_to_disk`オプションを指定してください。 + - 画像と同じフォルダに、拡張子 `.npz` で保存されます。`--flip_aug`オプションを指定した場合、`_flip.npz`が付いたファイルにも保存されます。 + - マルチGPUでの学習は未テストです。 + - すべてのDataset、学習スクリプトの組み合わせでテストしたわけではないため、不具合があるかもしれません。 +- `fine_tune.py`で、`fp16`および`bf16`の学習時にエラーが出る不具合に対して対策を行いました。 ## Sample image generation during training A prompt file might look like this, for example From 06a9f51431f27900c3e1c95a55964891bc700f66 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 22:27:00 +0900 Subject: [PATCH 12/12] fix link in readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a22b65b6..9d077a15 100644 --- a/README.md +++ b/README.md @@ -161,14 +161,14 @@ LoRA-C3Liarを使いWeb UIで生成するには拡張を使用してください ### 13 Apr. 2023, 2023/4/13: -- Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README.md) for details (currently only in Japanese). +- Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README-ja.md#dylora) for details (currently only in Japanese). - Added support for caching latents to disk in each training script. Please specify __both__ `--cache_latents` and `--cache_latents_to_disk` options. - The files are saved in the same folder as the images with the extension `.npz`. If you specify the `--flip_aug` option, the files with `_flip.npz` will also be saved. - Multi-GPU training has not been tested. - This feature is not tested with all combinations of datasets and training scripts, so there may be bugs. - Added workaround for an error that occurs when training with `fp16` or `bf16` in `fine_tune.py`. -- `train_network.py`でDyLoRAをサポートしました。詳細は[こちら](./train_network_README-ja.md)をご覧ください。 +- `train_network.py`でDyLoRAをサポートしました。詳細は[こちら](./train_network_README-ja.md#dylora)をご覧ください。 - 各学習スクリプトでlatentのディスクへのキャッシュをサポートしました。`--cache_latents`オプションに __加えて__、`--cache_latents_to_disk`オプションを指定してください。 - 画像と同じフォルダに、拡張子 `.npz` で保存されます。`--flip_aug`オプションを指定した場合、`_flip.npz`が付いたファイルにも保存されます。 - マルチGPUでの学習は未テストです。