From dde7807b000b304018423802bb3d8e774620c489 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 1 Jun 2023 22:21:36 +0900 Subject: [PATCH 01/37] add rank dropout/module dropout --- networks/lora.py | 88 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 12 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 19fbbbdb..1a665fc4 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module): replaces forward method of the original Linear, instead of replacing the original Linear module. """ - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() self.lora_name = lora_name @@ -61,6 +71,8 @@ class LoRAModule(torch.nn.Module): self.multiplier = multiplier self.org_module = org_module # remove in applying self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout def apply_to(self): self.org_forward = self.org_module.forward @@ -68,18 +80,45 @@ class LoRAModule(torch.nn.Module): del self.org_module def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout if self.dropout: - return ( - self.org_forward(x) - + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale - ) - else: - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > 
self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * self.scale class LoRAInfModule(LoRAModule): - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None): - super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout) + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) self.org_module_ref = [org_module] # 後から参照できるように self.enabled = True @@ -395,6 +434,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_block_dims = None conv_block_alphas = None + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoder, @@ -403,6 +450,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un lora_dim=network_dim, alpha=network_alpha, dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, block_dims=block_dims, @@ -679,6 +728,8 @@ class LoRANetwork(torch.nn.Module): lora_dim=4, alpha=1, dropout=None, + rank_dropout=None, + module_dropout=None, conv_lora_dim=None, conv_alpha=None, block_dims=None, @@ -706,18 +757,22 @@ class LoRANetwork(torch.nn.Module): self.conv_lora_dim = conv_lora_dim self.conv_alpha = conv_alpha self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout if modules_dim is 
not None: print(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}") + print(f"create LoRA network from block_dims") + print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") print(f"block_dims: {block_dims}") print(f"block_alphas: {block_alphas}") if conv_block_dims is not None: print(f"conv_block_dims: {conv_block_dims}") print(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}") + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") @@ -764,7 +819,16 @@ class LoRANetwork(torch.nn.Module): skipped.append(lora_name) continue - lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout) + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) loras.append(lora) return loras, skipped From 0f0158ddaac48fb424133b530599dce269c63f4c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 2 Jun 2023 07:29:59 +0900 Subject: [PATCH 02/37] scale in rank dropout, check training in dropout --- networks/lora.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 1a665fc4..aa1c9331 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -83,18 +83,18 @@ class LoRAModule(torch.nn.Module): org_forwarded = self.org_forward(x) # module dropout - if self.module_dropout: + if self.module_dropout is not None and self.training: if 
torch.rand(1) < self.module_dropout: return org_forwarded lx = self.lora_down(x) # normal dropout - if self.dropout: + if self.dropout is not None and self.training: lx = torch.nn.functional.dropout(lx, p=self.dropout) # rank dropout - if self.rank_dropout: + if self.rank_dropout is not None and self.training: mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout if len(lx.size()) == 3: mask = mask.unsqueeze(1) # for Text Encoder @@ -102,9 +102,15 @@ class LoRAModule(torch.nn.Module): mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d lx = lx * mask + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + lx = self.lora_up(lx) - return org_forwarded + lx * self.multiplier * self.scale + return org_forwarded + lx * self.multiplier * scale class LoRAInfModule(LoRAModule): From ec2efe52e45caca863005505b95f5f4a575e1246 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Jun 2023 10:52:22 +0900 Subject: [PATCH 03/37] scale v-pred loss like noise pred --- fine_tune.py | 21 +++++++++++++++++---- library/custom_train_functions.py | 26 ++++++++++++++++++++++++-- library/train_util.py | 7 ++++++- train_db.py | 5 +++++ train_network.py | 9 +++++++-- train_textual_inversion.py | 17 +++++++++++++---- train_textual_inversion_XTI.py | 11 +++++++---- 7 files changed, 79 insertions(+), 17 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 154d3be7..201d4952 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -21,7 +21,14 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset +from library.custom_train_functions import ( + apply_snr_weight, + 
get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) def train(args): @@ -261,6 +268,7 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if accelerator.is_main_process: accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) @@ -327,11 +335,16 @@ def train(args): else: target = noise - if args.min_snr_gamma: - # do not mean over batch dimension for snr weight + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred: + # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + loss = loss.mean() # mean over batch dimension else: loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index f32f050e..9d0dc402 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -5,20 +5,37 @@ import re from typing import List, Optional, Union -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): +def prepare_scheduler_for_custom_training(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return + alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) alpha = sqrt_alphas_cumprod sigma = sqrt_one_minus_alphas_cumprod all_snr = (alpha / sigma) ** 2 - snr = torch.stack([all_snr[t] for t in timesteps]) + + noise_scheduler.all_snr = all_snr.to(device) + + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper loss = loss * snr_weight return loss +def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + scale = snr_t / (snr_t + 1) + + loss = loss * scale + return loss + + # TODO train_utilと分散しているのでどちらかに寄せる @@ -29,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. 
/ 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", ) + parser.add_argument( + "--scale_v_pred_loss_like_noise_pred", + action="store_true", + help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", diff --git a/library/train_util.py b/library/train_util.py index 46c5c3b2..844faca7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2311,6 +2311,11 @@ def verify_training_args(args: argparse.Namespace): if args.adaptive_noise_scale is not None and args.noise_offset is None: raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") + if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: + raise ValueError( + "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" + ) + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool @@ -3638,4 +3643,4 @@ class collater_class: # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) - return examples[0] \ No newline at end of file + return examples[0] diff --git a/train_db.py b/train_db.py index 7ec06354..c81a092d 100644 --- a/train_db.py +++ b/train_db.py @@ -26,8 +26,10 @@ import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, ) # perlin_noise, @@ -240,6 +242,7 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) + 
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if accelerator.is_main_process: accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name) @@ -327,6 +330,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index cd90b0a2..32258e88 100644 --- a/train_network.py +++ b/train_network.py @@ -28,9 +28,11 @@ import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, max_norm, + scale_v_prediction_loss_like_noise_prediction, ) @@ -316,7 +318,7 @@ def train(args): network.prepare_grad_etc(text_encoder, unet) - if not cache_latents: + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() vae.to(accelerator.device, dtype=weight_dtype) @@ -554,6 +556,8 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if accelerator.is_main_process: accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name) @@ -658,6 +662,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -840,7 +846,6 @@ def setup_parser() -> argparse.ArgumentParser: 
nargs="*", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) - return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b73027de..8be0703d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -20,7 +20,13 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) imagenet_templates_small = [ "a photo of a {}", @@ -338,6 +344,7 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if accelerator.is_main_process: accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) @@ -412,12 +419,14 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8c8f7e8b..7b734f28 100644 --- 
a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset +from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -372,6 +372,7 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if accelerator.is_main_process: accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) @@ -451,11 +452,13 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + + loss = loss * loss_weights if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 71a7a27319952c219848b9f5672d387d11932677 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Jun 2023 11:33:18 +0900 Subject: [PATCH 04/37] update readme --- README.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/README.md b/README.md index aefc6c35..0a474442 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,46 @@ The 
majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 3 Jun. 2023, 2023/06/03 + +- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! + - Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details. + - Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`. + +- Three types of dropout have been added to `train_network.py` and LoRA network. + - Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0. + - `--network_dropout` is a normal dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Proposed in [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! + - `--network_dropout=0.1` specifies the dropout probability to `0.1`. + - Note that the specification method is different from LyCORIS. + - For LoRA network, `--network_args` can specify `rank_dropout` to dropout each rank with specified probability. Also `module_dropout` can be specified to dropout each module with specified probability. + - Specify as `--network_args "rank_dropout=0.2" "module_dropout=0.1"`. + - `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time. + - Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified. + - `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet. + +- Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script. 
+ - By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected. + - See [this article](https://xrg.hatenablog.com/entry/2023/06/02/202418) by xrg for details (written in Japanese). Thanks to xrg for the great suggestion! + +- Max Norm Regularizationが`train_network.py`で使えるようになりました。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) AI-Casanova氏に感謝します。 + - Max Norm Regularizationは、ネットワークの重みのノルムを制限することで、ネットワークの学習を安定させる手法です。LoRAの過学習の抑制、他のLoRAと併用した時の安定性の向上が期待できるかもしれません。詳細はPRを参照してください。 + - `--scale_weight_norms=1.0`のように `--scale_weight_norms` で指定してください。`1.0`から試すと良いようです。 + +- `train_network.py` およびLoRAに計三種類のdropoutを追加しました。 + - dropoutはネットワークの一部の出力をランダムに0にすることで、過学習の抑制、ネットワークの性能向上等を図る手法です。 + - `--network_dropout` はニューロン単位の通常のdropoutです。LoRAの場合、downの出力に対して適用されます。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) で提案されました。AI-Casanova氏に感謝します。 + - `--network_dropout=0.1` などとすることで、dropoutの確率を指定できます。 + - LyCORISとは指定方法が異なりますのでご注意ください。 + - LoRAの場合、`--network_args`に`rank_dropout`を指定することで各rankを指定確率でdropoutします。また同じくLoRAの場合、`--network_args`に`module_dropout`を指定することで各モジュールを指定確率でdropoutします。 + - `--network_args "rank_dropout=0.2" "module_dropout=0.1"` のように指定します。 + - `--network_dropout`、`rank_dropout` 、 `module_dropout` は同時に指定できます。 + - それぞれの値は0.1~0.3程度から試してみると良いかもしれません。0.5を超える値は指定しない方が良いでしょう。 + - `rank_dropout`および`module_dropout`は当リポジトリ独自の手法です。有効性の検証はまだ行っていません。 + +- 各学習スクリプトにv-prediction lossをnoise predictionと同様の値にスケールするオプション`--scale_v_pred_loss_like_noise_pred`を追加しました。 + - タイムステップに応じてlossをスケールすることで、 大域的なノイズの予測と局所的なノイズの予測の重みが同じになり、ディテールの改善が期待できるかもしれません。 + - 詳細はxrg氏のこちらの記事をご参照ください:[noise_predictionモデルとv_predictionモデルの損失 - 勾配降下党青年局](https://xrg.hatenablog.com/entry/2023/06/02/202418) xrg氏の素晴らしい記事に感謝します。 + ### 31 May 2023, 2023/05/31 - Show warning when image caption file does not exist during training. 
[PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin! From 5bec05e04523f15dcca5964a7a2053a617d50fa4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Jun 2023 12:42:32 +0900 Subject: [PATCH 05/37] move max_norm to lora to avoid crashing in lycoris --- README.md | 4 +++ library/custom_train_functions.py | 43 ------------------------------- networks/lora.py | 43 +++++++++++++++++++++++++++++++ train_network.py | 10 +++++-- 4 files changed, 55 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 0a474442..3b320acc 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! - Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details. - Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`. + - The networks other than LoRA in this repository (such as LyCORIS) do not support this option. - Three types of dropout have been added to `train_network.py` and LoRA network. - Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0. @@ -156,6 +157,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time. - Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified. - `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet. + - The networks other than LoRA in this repository (such as LyCORIS) do not support these options. 
- Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script. - By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected. @@ -164,6 +166,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Max Norm Regularizationが`train_network.py`で使えるようになりました。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) AI-Casanova氏に感謝します。 - Max Norm Regularizationは、ネットワークの重みのノルムを制限することで、ネットワークの学習を安定させる手法です。LoRAの過学習の抑制、他のLoRAと併用した時の安定性の向上が期待できるかもしれません。詳細はPRを参照してください。 - `--scale_weight_norms=1.0`のように `--scale_weight_norms` で指定してください。`1.0`から試すと良いようです。 + - LyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。 - `train_network.py` およびLoRAに計三種類のdropoutを追加しました。 - dropoutはネットワークの一部の出力をランダムに0にすることで、過学習の抑制、ネットワークの性能向上等を図る手法です。 @@ -175,6 +178,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `--network_dropout`、`rank_dropout` 、 `module_dropout` は同時に指定できます。 - それぞれの値は0.1~0.3程度から試してみると良いかもしれません。0.5を超える値は指定しない方が良いでしょう。 - `rank_dropout`および`module_dropout`は当リポジトリ独自の手法です。有効性の検証はまだ行っていません。 + - これらのdropoutはLyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。 - 各学習スクリプトにv-prediction lossをnoise predictionと同様の値にスケールするオプション`--scale_v_pred_loss_like_noise_pred`を追加しました。 - タイムステップに応じてlossをスケールすることで、 大域的なノイズの予測と局所的なノイズの予測の重みが同じになり、ディテールの改善が期待できるかもしれません。 diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 9d0dc402..0cf0d1e2 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -456,46 +456,3 @@ def perlin_noise(noise, device, octaves): noise += noise_perlin # broadcast for each batch return noise / noise.std() # Scaled back to roughly unit variance """ - - -def max_norm(state_dict, max_norm_value, device): - downkeys = [] - upkeys = [] - alphakeys = [] - norms = [] - keys_scaled = 0 - - for key in 
state_dict.keys(): - if "lora_down" in key and "weight" in key: - downkeys.append(key) - upkeys.append(key.replace("lora_down", "lora_up")) - alphakeys.append(key.replace("lora_down.weight", "alpha")) - - for i in range(len(downkeys)): - down = state_dict[downkeys[i]].to(device) - up = state_dict[upkeys[i]].to(device) - alpha = state_dict[alphakeys[i]].to(device) - dim = down.shape[0] - scale = alpha / dim - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) - else: - updown = up @ down - - updown *= scale - - norm = updown.norm().clamp(min=max_norm_value / 2) - desired = torch.clamp(norm, max=max_norm_value) - ratio = desired.cpu() / norm.cpu() - sqrt_ratio = ratio**0.5 - if ratio != 1: - keys_scaled += 1 - state_dict[upkeys[i]] *= sqrt_ratio - state_dict[downkeys[i]] *= sqrt_ratio - scalednorm = updown.norm() * ratio - norms.append(scalednorm.item()) - - return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora.py b/networks/lora.py index aa1c9331..9f2f5094 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1126,3 +1126,46 @@ class LoRANetwork(torch.nn.Module): org_module._lora_restored = False lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + 
scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/train_network.py b/train_network.py index 32258e88..051d0d18 100644 --- a/train_network.py +++ b/train_network.py @@ -31,7 +31,6 @@ from library.custom_train_functions import ( prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, - max_norm, scale_v_prediction_loss_like_noise_prediction, ) @@ -220,6 +219,11 @@ def train(args): if hasattr(network, "prepare_network"): network.prepare_network(args) + if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): + print( + "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" + ) + args.scale_weight_norms = False train_unet = not args.network_train_text_encoder_only train_text_encoder = not args.network_train_unet_only @@ -677,7 +681,9 @@ def train(args): optimizer.zero_grad(set_to_none=True) if args.scale_weight_norms: - keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device) + keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization( + args.scale_weight_norms, accelerator.device + ) max_mean_logs = {"Keys Scaled": keys_scaled, "Average key 
norm": mean_norm} else: keys_scaled, mean_norm, maximum_norm = None, None, None From 5db792b10b7da7ff2d523e0083430dcbd5426a36 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Jun 2023 19:24:47 +0900 Subject: [PATCH 06/37] initial commit for original U-Net --- gen_img_diffusers.py | 3 +- library/model_util.py | 3 +- library/original_unet.py | 1234 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1238 insertions(+), 2 deletions(-) create mode 100644 library/original_unet.py diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 27bd7460..4d8121ca 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -78,7 +78,7 @@ from diffusers import ( HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, - UNet2DConditionModel, + # UNet2DConditionModel, StableDiffusionPipeline, ) from einops import rearrange @@ -95,6 +95,7 @@ import library.train_util as train_util from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI diff --git a/library/model_util.py b/library/model_util.py index 26f72235..0fbc6590 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,8 +5,9 @@ import math import os import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline #, UNet2DConditionModel from safetensors.torch import load_file, save_file +from library.original_unet import UNet2DConditionModel # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 diff --git a/library/original_unet.py b/library/original_unet.py new file mode 100644 index 00000000..603239d5 --- /dev/null +++ 
b/library/original_unet.py @@ -0,0 +1,1234 @@ +# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる +# コードの多くはDiffusersからコピーしている +# コードが冗長になる部分はコメント等を適宜削除する +# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある + +# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. +# Remove redundant code by deleting comments, etc. as appropriate +# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 + +""" +v1.5とv2.1の相違点は +- attention_head_dimがintかlist[int]か +- cross_attention_dimが768か1024か +- use_linear_projection: trueがない(=False, 1.5)かあるか +- upcast_attentionがFalse(1.5)かTrue(2.1)か +- (以下は多分無視していい) +- sample_sizeが64か96か +- dual_cross_attentionがあるかないか +- num_class_embedsがあるかないか +- only_cross_attentionがあるかないか + +v1.5 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.6.0", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 768, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ] +} + +v2.1 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.10.0.dev0", + "act_fn": "silu", + "attention_head_dim": [ + 5, + 10, + 20, + 20 + ], + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + 
"flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "sample_size": 96, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} +""" + +import math +from typing import Dict, Optional, Tuple, Union +import torch +from torch import nn +from torch.nn import functional as F + +BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) +TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] +TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4 +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +LAYERS_PER_BLOCK: int = 2 +LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 +TIME_EMBED_FLIP_SIN_TO_COS: bool = True +TIME_EMBED_FREQ_SHIFT: int = 0 +RESNET_GROUPS: int = 32 +RESNET_EPS: float = 1e-6 +TRANSFORMER_NORM_NUM_GROUPS = 32 + +DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] +UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class SampleOutput: + def __init__(self, sample): + self.sample = sample + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class 
ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.norm1 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=in_channels, eps=RESNET_EPS, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) + + self.norm2 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=out_channels, eps=RESNET_EPS, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + # if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + + self.use_in_shortcut = self.in_channels != self.out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + ) + 
) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)] + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + 
self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated 
linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + raise NotImplementedError("Memory efficient attention is not implemented for this model.") + + def forward(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ] + ) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, 
hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return SampleOutput(sample=output) + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + cross_attention_dim=1280, + attn_num_head_channels=1, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + resnets = [] + attentions = [] + + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + + resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels)) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + 
cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + use_linear_projection=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + # Middle block has two resnets and one attention + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ] + 
attentions = [ + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + add_upsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), 
hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + add_upsample=True, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, 
return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +def get_down_block( + down_block_type, + in_channels, + out_channels, + add_downsample, + attn_num_head_channels, + cross_attention_dim, + use_linear_projection, + upcast_attention, +): + if down_block_type == "DownBlock2D": + return DownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlock2D": + return CrossAttnDownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +def get_up_block( + up_block_type, + in_channels, + out_channels, + prev_output_channel, + add_upsample, + attn_num_head_channels, + cross_attention_dim=None, + use_linear_projection=False, + upcast_attention=False, +): + if up_block_type == "UpBlock2D": + return UpBlock2D( + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlock2D": + return CrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + 
prev_output_channel=prev_output_channel, + attn_num_head_channels=attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +class UNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + sample_size: Optional[int] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + cross_attention_dim: int = 1280, + use_linear_projection: bool = False, + upcast_attention: bool = False, + **kwargs, + ): + super().__init__() + assert sample_size is not None, "sample_size must be specified" + print( + f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" + ) + + # 外部からの参照用に定義しておく + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + + self.sample_size = sample_size + + # input + self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT) + + self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * 4 + + # down + output_channel = BLOCK_OUT_CHANNELS[0] + for i, down_block_type in enumerate(DOWN_BLOCK_TYPES): + input_channel = output_channel + output_channel = BLOCK_OUT_CHANNELS[i] + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + down_block = get_down_block( + down_block_type, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + attn_num_head_channels=attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + 
self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=BLOCK_OUT_CHANNELS[-1], + attn_num_head_channels=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(UP_BLOCK_TYPES): + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=add_upsample, + attn_num_head_channels=reversed_attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=RESNET_GROUPS, eps=RESNET_EPS) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) + + # region diffusers compatibility + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
+ """ + return get_parameter_dtype(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self._set_gradient_checkpointing(self, value=True) + + def disable_gradient_checkpointing(self): + self._set_gradient_checkpointing(self, value=False) + + def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: + raise NotImplementedError("Memory efficient attention is not supported for this model.") + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + # endregion + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[Dict, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) + + def handle_unusual_timesteps(self, sample, timesteps): + r""" + timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。 + """ + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + return timesteps From 5907bbd9de949a252ba550b27e216c8ae705c2bc Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sat, 3 Jun 2023 21:20:26 +0900 Subject: [PATCH 07/37] =?UTF-8?q?loss=E8=A1=A8=E7=A4=BA=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 051d0d18..c6ea7e4e 100644 --- a/train_network.py +++ b/train_network.py @@ -724,7 +724,7 @@ def train(args): progress_bar.set_postfix(**logs) if args.scale_weight_norms: - progress_bar.set_postfix(**max_mean_logs) + progress_bar.set_postfix(**{**max_mean_logs,**logs}) if args.logging_dir is not None: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) From c0a7df9ee14c715134f0e1cfce0a6256d1b64014 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Jun 2023 21:29:27 +0900 Subject: [PATCH 08/37] fix eps value, enable xformers, etc. 
--- gen_img_diffusers.py | 63 ++++++++++---------- library/original_unet.py | 120 +++++++++++++++++++++++++++++++++------ library/train_util.py | 66 +++++++++++---------- 3 files changed, 171 insertions(+), 78 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 4d8121ca..33b7a65c 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -317,7 +317,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio if mem_eff_attn: replace_unet_cross_attn_to_memory_efficient() elif xformers: - replace_unet_cross_attn_to_xformers() + replace_unet_cross_attn_to_xformers(unet) def replace_unet_cross_attn_to_memory_efficient(): @@ -357,50 +357,55 @@ def replace_unet_cross_attn_to_memory_efficient(): out = self.to_out[1](out) return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + # diffusers.models.attention.CrossAttention.forward = forward_flash_attn + # TODO U-Net側に移す + from library.original_unet import CrossAttention + CrossAttention.forward = forward_flash_attn -def replace_unet_cross_attn_to_xformers(): +def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel): print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork") try: import xformers.ops except ImportError: raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention_xformers(True) - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + # def forward_xformers(self, x, context=None, mask=None): + # h = self.heads + # q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + # context = default(context, x) + # context = context.to(x.dtype) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - 
context_v = context + # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + # context_k, context_v = self.hypernetwork.forward(x, context) + # context_k = context_k.to(x.dtype) + # context_v = context_v.to(x.dtype) + # else: + # context_k = context + # context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + # k_in = self.to_k(context_k) + # v_in = self.to_v(context_v) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + # del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + # q = q.contiguous() + # k = k.contiguous() + # v = v.contiguous() + # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, "b n h d -> b n (h d)", h=h) + # out = rearrange(out, "b n h d -> b n (h d)", h=h) - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # # diffusers 0.7.0~ + # out = self.to_out[0](out) + # out = self.to_out[1](out) + # return out - diffusers.models.attention.CrossAttention.forward = forward_xformers + # diffusers.models.attention.CrossAttention.forward = forward_xformers def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): diff --git a/library/original_unet.py b/library/original_unet.py index 603239d5..47b751c1 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -1,10 +1,10 @@ # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる +# 条件分岐等で不要な部分は削除している # コードの多くはDiffusersからコピーしている -# コードが冗長になる部分はコメント等を適宜削除する # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. -# Remove redundant code by deleting comments, etc. 
as appropriate +# Unnecessary parts are deleted by condition branching. # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 """ @@ -111,6 +111,7 @@ from typing import Dict, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F +from einops import rearrange BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -121,8 +122,8 @@ LAYERS_PER_BLOCK: int = 2 LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 TIME_EMBED_FLIP_SIN_TO_COS: bool = True TIME_EMBED_FREQ_SHIFT: int = 0 -RESNET_GROUPS: int = 32 -RESNET_EPS: float = 1e-6 +NORM_GROUPS: int = 32 +NORM_EPS: float = 1e-5 TRANSFORMER_NORM_NUM_GROUPS = 32 DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] @@ -233,13 +234,13 @@ class ResnetBlock2D(nn.Module): self.in_channels = in_channels self.out_channels = out_channels - self.norm1 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=in_channels, eps=RESNET_EPS, affine=True) + self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) - self.norm2 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=out_channels, eps=RESNET_EPS, affine=True) + self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) # if non_linearity == "swish": @@ -304,6 +305,9 @@ class DownBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + pass + def forward(self, hidden_states, temb=None): output_states = () @@ -372,6 +376,11 @@ class CrossAttention(nn.Module): 
self.to_out.append(nn.Linear(inner_dim, query_dim)) # no dropout here + self.use_memory_efficient_attention_xformers = False + + def set_use_memory_efficient_attention_xformers(self, value): + self.use_memory_efficient_attention_xformers = value + def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads @@ -387,6 +396,9 @@ class CrossAttention(nn.Module): return tensor def forward(self, hidden_states, context=None, mask=None): + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + query = self.to_q(hidden_states) context = context if context is not None else hidden_states key = self.to_k(context) @@ -427,6 +439,30 @@ class CrossAttention(nn.Module): hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + # feedforward class GEGLU(nn.Module): @@ -506,8 +542,9 @@ class BasicTransformerBlock(nn.Module): # 3. 
Feed-forward self.norm3 = nn.LayerNorm(dim) - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - raise NotImplementedError("Memory efficient attention is not implemented for this model.") + def set_use_memory_efficient_attention_xformers(self, value: bool): + self.attn1.set_use_memory_efficient_attention_xformers(value) + self.attn2.set_use_memory_efficient_attention_xformers(value) def forward(self, hidden_states, context=None, timestep=None): # 1. Self-Attention @@ -566,6 +603,10 @@ class Transformer2DModel(nn.Module): else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + def set_use_memory_efficient_attention_xformers(self, value): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention_xformers(value) + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # 1. Input batch, _, height, weight = hidden_states.shape @@ -643,6 +684,10 @@ class CrossAttnDownBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + for attn in self.attentions: + attn.set_use_memory_efficient_attention_xformers(value) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -714,11 +759,37 @@ class UNetMidBlock2DCrossAttn(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention_xformers(self, value): + for attn in self.attentions: + attn.set_use_memory_efficient_attention_xformers(value) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states).sample - hidden_states = resnet(hidden_states, temb) + 
for i, resnet in enumerate(self.resnets): + attn = None if i == 0 else self.attentions[i - 1] + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + if attn is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -792,6 +863,9 @@ class UpBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + pass + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states @@ -868,6 +942,10 @@ class CrossAttnUpBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + for attn in self.attentions: + attn.set_use_memory_efficient_attention_xformers(value) + def forward( self, hidden_states, @@ -991,6 +1069,8 @@ class UNet2DConditionModel(nn.Module): self.sample_size = sample_size + # state_dictの書式が変わるのでmoduleの持ち方は変えられない + # input self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) @@ -1069,7 +1149,7 @@ class UNet2DConditionModel(nn.Module): prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=RESNET_GROUPS, eps=RESNET_EPS) + self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS) self.conv_act = nn.SiLU() 
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) @@ -1088,16 +1168,20 @@ class UNet2DConditionModel(nn.Module): return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) def enable_gradient_checkpointing(self): - self._set_gradient_checkpointing(self, value=True) + self.set_gradient_checkpointing(value=True) def disable_gradient_checkpointing(self): - self._set_gradient_checkpointing(self, value=False) + self.set_gradient_checkpointing(value=False) def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: - raise NotImplementedError("Memory efficient attention is not supported for this model.") + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + module.set_use_memory_efficient_attention_xformers(valid) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + def set_gradient_checkpointing(self, value=False): + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + print(module.__class__.__name__, module.gradient_checkpointing, "->", value) module.gradient_checkpointing = value # endregion diff --git a/library/train_util.py b/library/train_util.py index 844faca7..b7cee937 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1792,7 +1792,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio if mem_eff_attn: replace_unet_cross_attn_to_memory_efficient() elif xformers: - replace_unet_cross_attn_to_xformers() + replace_unet_cross_attn_to_xformers(unet) def replace_unet_cross_attn_to_memory_efficient(): @@ -1827,55 +1827,59 @@ def replace_unet_cross_attn_to_memory_efficient(): out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) out = self.to_out[0](out) - out = self.to_out[1](out) + # out = self.to_out[1](out) return out - 
diffusers.models.attention.CrossAttention.forward = forward_flash_attn + # diffusers.models.attention.CrossAttention.forward = forward_flash_attn + from library.original_unet import CrossAttention + + CrossAttention.forward = forward_flash_attn -def replace_unet_cross_attn_to_xformers(): +def replace_unet_cross_attn_to_xformers(unet): print("CrossAttention.forward has been replaced to enable xformers.") try: import xformers.ops except ImportError: raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + unet.set_use_memory_efficient_attention_xformers(True) - context = default(context, x) - context = context.to(x.dtype) + # def forward_xformers(self, x, context=None, mask=None): + # h = self.heads + # q_in = self.to_q(x) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + # context = default(context, x) + # context = context.to(x.dtype) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + # context_k, context_v = self.hypernetwork.forward(x, context) + # context_k = context_k.to(x.dtype) + # context_v = context_v.to(x.dtype) + # else: + # context_k = context + # context_v = context - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + # k_in = self.to_k(context_k) + # v_in = self.to_v(context_v) - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + # del q_in, k_in, v_in - out = rearrange(out, "b n h d -> b n (h d)", h=h) + # q = 
q.contiguous() + # k = k.contiguous() + # v = v.contiguous() + # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # out = rearrange(out, "b n h d -> b n (h d)", h=h) - diffusers.models.attention.CrossAttention.forward = forward_xformers + # # diffusers 0.7.0~ + # out = self.to_out[0](out) + # out = self.to_out[1](out) + # return out + + # diffusers.models.attention.CrossAttention.forward = forward_xformers """ From 24823b061df14e0d5a947ba453997aaaa7b5a903 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 6 Jun 2023 21:53:58 +0900 Subject: [PATCH 09/37] support BREAK in generation script --- gen_img_diffusers.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 27bd7460..33c40441 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -457,7 +457,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform upsampler.forward = make_replacer(upsampler) """ - + def replace_vae_attn_to_memory_efficient(): print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") @@ -1795,6 +1795,9 @@ def parse_prompt_attention(text): for p in range(start_position, len(res)): res[p][1] *= multiplier + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + for m in re_attention.finditer(text): text = m.group(0) weight = m.group(1) @@ -1826,7 +1829,7 @@ def parse_prompt_attention(text): # merge runs of identical weights i = 0 while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": res[i][0] += res[i + 1][0] res.pop(i + 1) else: @@ -1843,11 +1846,25 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens = [] weights = [] truncated = False + for text in prompt: 
texts_and_weights = parse_prompt_attention(text) text_token = [] text_weight = [] for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(pipe.tokenizer.eos_token_id) + # else: + text_token.append(pipe.tokenizer.pad_token_id) + text_weight.append(1.0) + continue + # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1] From bb91a10b5f5c3947bbc715963ff2b91fd7e18719 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 6 Jun 2023 21:59:57 +0900 Subject: [PATCH 10/37] fix to work LyCORIS<0.1.6 --- networks/lora.py | 4 ++-- train_network.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 9f2f5094..27f59344 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -400,7 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs): return down_lr_weight, mid_lr_weight, up_lr_weight -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs): +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -455,7 +455,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha, - dropout=dropout, + dropout=neuron_dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, conv_lora_dim=conv_dim, diff --git a/train_network.py b/train_network.py index c6ea7e4e..b62aef7e 100644 --- a/train_network.py +++ b/train_network.py @@ -212,7 +212,7 @@ def train(args): else: # LyCORIS will work with this... 
network = network_module.create_network( - 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs + 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs ) if network is None: return @@ -724,7 +724,7 @@ def train(args): progress_bar.set_postfix(**logs) if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs,**logs}) + progress_bar.set_postfix(**{**max_mean_logs, **logs}) if args.logging_dir is not None: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) From 4e24733f1c51573e5a423d78e821f65ee509f480 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 6 Jun 2023 22:03:21 +0900 Subject: [PATCH 11/37] update readme --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 3b320acc..cb3803f0 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 6 Jun. 2023, 2023/06/06 + +- Fix `train_network.py` to probably work with older versions of LyCORIS. +- `gen_img_diffusers.py` now supports `BREAK` syntax. +- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。 +- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。 + ### 3 Jun. 2023, 2023/06/03 - Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! 
From dccdb8771c6facfb40087fc670e2698a2deb3bce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 7 Jun 2023 08:12:52 +0900 Subject: [PATCH 12/37] support sample generation in training --- library/original_unet.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index 47b751c1..e8920727 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -107,6 +107,7 @@ v2.1 """ import math +from types import SimpleNamespace from typing import Dict, Optional, Tuple, Union import torch from torch import nn @@ -134,6 +135,10 @@ def get_parameter_dtype(parameter: torch.nn.Module): return next(parameter.parameters()).dtype +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, @@ -1068,6 +1073,7 @@ class UNet2DConditionModel(nn.Module): self.out_channels = OUT_CHANNELS self.sample_size = sample_size + self.prepare_config() # state_dictの書式が変わるのでmoduleの持ち方は変えられない @@ -1154,13 +1160,19 @@ class UNet2DConditionModel(nn.Module): self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) # region diffusers compatibility + def prepare_config(self): + self.config = SimpleNamespace() + @property def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - """ + # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). return get_parameter_dtype(self) + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). 
+ return get_parameter_device(self) + def set_attention_slice(self, slice_size): raise NotImplementedError("Attention slicing is not supported for this model.") From 045cd38b6e1e4d1a341f27980068d773aef1433e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Jun 2023 22:02:46 +0900 Subject: [PATCH 13/37] fix clip_skip not work in weight capt, sample gen --- library/custom_train_functions.py | 12 ++++++------ library/lpw_stable_diffusion.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 0cf0d1e2..8b44874b 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -265,11 +265,6 @@ def get_unweighted_text_embeddings( text_embedding = enc_out["hidden_states"][-clip_skip] text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0] - if no_boseos_middle: if i == 0: # discard the ending token @@ -284,7 +279,12 @@ def get_unweighted_text_embeddings( text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: - text_embeddings = text_encoder(text_input)[0] + if clip_skip is None or clip_skip == 1: + text_embeddings = text_encoder(text_input)[0] + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) return text_embeddings diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3e04b887..58b1171e 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -245,11 +245,6 @@ def get_unweighted_text_embeddings( text_embedding = 
enc_out["hidden_states"][-clip_skip] text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0] - if no_boseos_middle: if i == 0: # discard the ending token @@ -264,7 +259,12 @@ def get_unweighted_text_embeddings( text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: - text_embeddings = pipe.text_encoder(text_input)[0] + if clip_skip is None or clip_skip == 1: + text_embeddings = pipe.text_encoder(text_input)[0] + else: + enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) return text_embeddings From 8088c04a71534ececd129b10312e451b42ade2b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Jun 2023 22:06:34 +0900 Subject: [PATCH 14/37] update readme --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index cb3803f0..8234a89e 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 8 Jun. 2023, 2023/06/08 + +- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training. +- 重みづけキャプションでの学習時(`--weighted_captions`指定時)および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。 + ### 6 Jun. 2023, 2023/06/06 - Fix `train_network.py` to probably work with older versions of LyCORIS. 
From 334d07bf96e0dc1f2e21864b6a21816e2f26f003 Mon Sep 17 00:00:00 2001 From: mio <74481573+mio2333@users.noreply.github.com> Date: Thu, 8 Jun 2023 23:39:06 +0800 Subject: [PATCH 15/37] Update make_captions.py Append sys path for make_captions.py to load blip module in the same folder to fix the error when you don't run this script under the folder --- finetune/make_captions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 9e51037f..b20c4106 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -3,6 +3,7 @@ import glob import os import json import random +import sys from pathlib import Path from PIL import Image @@ -11,6 +12,7 @@ import numpy as np import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode +sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder import library.train_util as train_util From cc274fb7fb0e5d081009feffe815d91af4ece02b Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 16:54:10 +0900 Subject: [PATCH 16/37] update diffusers ver, remove tensorflow --- requirements.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 801cf321..7252f745 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -accelerate==0.15.0 +accelerate==0.16.0 transformers==4.26.0 ftfy==6.1.1 albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.6.0 -diffusers[torch]==0.10.2 +diffusers[torch]==0.17.0 pytorch-lightning==1.9.0 bitsandbytes==0.35.0 tensorboard==2.10.1 @@ -14,13 +14,12 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 +huggingface-hub==0.13.3 # for BLIP captioning requests==2.28.2 timm==0.6.12 fairscale==0.4.13 # for WD14 captioning -# tensorflow<2.11 -tensorflow==2.10.1 -huggingface-hub==0.13.3 +# tensorflow==2.10.1 # for kohya_ss library . 
From 7f6b581ef82933bf5f1695b17676efeb64ce7b32 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 16:54:41 +0900 Subject: [PATCH 17/37] support memory efficient attn (not xformers) --- library/original_unet.py | 242 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 225 insertions(+), 17 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index e8920727..0e64280b 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -131,6 +131,187 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDo UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] +# region memory effcient attention + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + 
q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +# endregion + + def get_parameter_dtype(parameter: torch.nn.Module): return next(parameter.parameters()).dtype @@ -310,7 +491,7 @@ class DownBlock2D(nn.Module): self.gradient_checkpointing = False - def set_use_memory_efficient_attention_xformers(self, value): + def set_use_memory_efficient_attention(self, xformers, mem_eff): pass def forward(self, hidden_states, temb=None): @@ -382,9 +563,11 @@ class CrossAttention(nn.Module): # no dropout here self.use_memory_efficient_attention_xformers = False + self.use_memory_efficient_attention_mem_eff = False - def set_use_memory_efficient_attention_xformers(self, value): - self.use_memory_efficient_attention_xformers = value + def set_use_memory_efficient_attention(self, xformers, mem_eff): + self.use_memory_efficient_attention_xformers = xformers + self.use_memory_efficient_attention_mem_eff = mem_eff def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape @@ -403,6 +586,8 @@ class CrossAttention(nn.Module): def forward(self, hidden_states, context=None, mask=None): if self.use_memory_efficient_attention_xformers: return 
self.forward_memory_efficient_xformers(hidden_states, context, mask) + if self.use_memory_efficient_attention_mem_eff: + return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) query = self.to_q(hidden_states) context = context if context is not None else hidden_states @@ -468,6 +653,29 @@ class CrossAttention(nn.Module): out = self.to_out[0](out) return out + def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): + flash_func = FlashAttentionFunction + + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = self.to_out[0](out) + return out + # feedforward class GEGLU(nn.Module): @@ -547,9 +755,9 @@ class BasicTransformerBlock(nn.Module): # 3. Feed-forward self.norm3 = nn.LayerNorm(dim) - def set_use_memory_efficient_attention_xformers(self, value: bool): - self.attn1.set_use_memory_efficient_attention_xformers(value) - self.attn2.set_use_memory_efficient_attention_xformers(value) + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): + self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) + self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) def forward(self, hidden_states, context=None, timestep=None): # 1. 
Self-Attention @@ -608,9 +816,9 @@ class Transformer2DModel(nn.Module): else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - def set_use_memory_efficient_attention_xformers(self, value): + def set_use_memory_efficient_attention(self, xformers, mem_eff): for transformer in self.transformer_blocks: - transformer.set_use_memory_efficient_attention_xformers(value) + transformer.set_use_memory_efficient_attention(xformers, mem_eff) def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # 1. Input @@ -689,9 +897,9 @@ class CrossAttnDownBlock2D(nn.Module): self.gradient_checkpointing = False - def set_use_memory_efficient_attention_xformers(self, value): + def set_use_memory_efficient_attention(self, xformers, mem_eff): for attn in self.attentions: - attn.set_use_memory_efficient_attention_xformers(value) + attn.set_use_memory_efficient_attention(xformers, mem_eff) def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -766,9 +974,9 @@ class UNetMidBlock2DCrossAttn(nn.Module): self.gradient_checkpointing = False - def set_use_memory_efficient_attention_xformers(self, value): + def set_use_memory_efficient_attention(self, xformers, mem_eff): for attn in self.attentions: - attn.set_use_memory_efficient_attention_xformers(value) + attn.set_use_memory_efficient_attention(xformers, mem_eff) def forward(self, hidden_states, temb=None, encoder_hidden_states=None): for i, resnet in enumerate(self.resnets): @@ -868,7 +1076,7 @@ class UpBlock2D(nn.Module): self.gradient_checkpointing = False - def set_use_memory_efficient_attention_xformers(self, value): + def set_use_memory_efficient_attention(self, xformers, mem_eff): pass def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): @@ -947,9 +1155,9 @@ class CrossAttnUpBlock2D(nn.Module): self.gradient_checkpointing = False - def 
set_use_memory_efficient_attention_xformers(self, value): + def set_use_memory_efficient_attention(self, xformers, mem_eff): for attn in self.attentions: - attn.set_use_memory_efficient_attention_xformers(value) + attn.set_use_memory_efficient_attention(xformers, mem_eff) def forward( self, @@ -1185,10 +1393,10 @@ class UNet2DConditionModel(nn.Module): def disable_gradient_checkpointing(self): self.set_gradient_checkpointing(value=False) - def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: + def set_use_memory_efficient_attention(self, xformers: bool,mem_eff:bool) -> None: modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - module.set_use_memory_efficient_attention_xformers(valid) + module.set_use_memory_efficient_attention(xformers,mem_eff) def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks From 4e25c8f78e873224e753af6d0509bc9281acff65 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 16:57:17 +0900 Subject: [PATCH 18/37] fix to work with Diffusers 0.17.0 --- gen_img_diffusers.py | 399 ++++++-------------------------- library/lpw_stable_diffusion.py | 4 +- library/model_util.py | 24 +- library/train_util.py | 2 +- 4 files changed, 86 insertions(+), 343 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 28f7323a..34857af3 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -82,7 +82,6 @@ from diffusers import ( StableDiffusionPipeline, ) from einops import rearrange -from torch import einsum from tqdm import tqdm from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig @@ -96,6 +95,7 @@ from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel +from library.original_unet import 
FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -136,341 +136,36 @@ USE_CUTOUTS = False 高速化のためのモジュール入れ替え """ -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE -# constants +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") -EPSILON = 1e-6 + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - 
all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None + unet.set_use_memory_efficient_attention(True, False) # TODO common train_util.py -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers(unet) - - -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n 
(h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - # diffusers.models.attention.CrossAttention.forward = forward_flash_attn - # TODO U-Net側に移す - from library.original_unet import CrossAttention - CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel): - print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention_xformers(True) - - # def forward_xformers(self, x, context=None, mask=None): - # h = self.heads - # q_in = self.to_q(x) - - # context = default(context, x) - # context = context.to(x.dtype) - - # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - # context_k, context_v = self.hypernetwork.forward(x, context) - # context_k = context_k.to(x.dtype) - # context_v = context_v.to(x.dtype) - # else: - # context_k = context - # context_v = context - - # k_in = self.to_k(context_k) - # v_in = self.to_v(context_v) - - # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - # del q_in, k_in, v_in - - # q = q.contiguous() - # k = k.contiguous() - # v = v.contiguous() - # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - # out = rearrange(out, "b n h d -> b n (h d)", h=h) - - # # diffusers 0.7.0~ - # out = self.to_out[0](out) - # out = self.to_out[1](out) - # return out - - # diffusers.models.attention.CrossAttention.forward = forward_xformers - - def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): if mem_eff_attn: replace_vae_attn_to_memory_efficient() elif xformers: # 
とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") - vae.set_use_memory_efficient_attention_xformers(True) - - """ - # VAEがbfloat16でメモリ消費が大きい問題を解決する - upsamplers = [] - for block in vae.decoder.up_blocks: - if block.upsamplers is not None: - upsamplers.extend(block.upsamplers) - - def forward_upsample(_self, hidden_states, output_size=None): - assert hidden_states.shape[1] == _self.channels - if _self.use_conv_transpose: - return _self.conv(hidden_states) - - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - assert output_size is None - # repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する - hidden_states = hidden_states.repeat_interleave(2, dim=-1) - hidden_states = hidden_states.repeat_interleave(2, dim=-2) - else: - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest") - - if _self.use_conv: - if _self.name == "conv": - hidden_states = _self.conv(hidden_states) - else: - hidden_states = _self.Conv2d_0(hidden_states) - return hidden_states - - # replace upsamplers - for upsampler in upsamplers: - # make new scope - def make_replacer(upsampler): - def forward(hidden_states, output_size=None): - return forward_upsample(upsampler, hidden_states, output_size) - - return forward - - upsampler.forward = make_replacer(upsampler) -""" - + replace_vae_attn_to_xformers() def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func =FlashAttentionFunction - def 
forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 k_bucket_size = 1024 @@ -483,12 +178,12 @@ def replace_vae_attn_to_memory_efficient(): hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) ) out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) @@ -496,14 +191,62 @@ def replace_vae_attn_to_memory_efficient(): out = rearrange(out, "b h n d -> b n (h d)") # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) 
+ + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + diffusers.models.attention_processor.Attention.forward = forward_xformers # endregion diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3e04b887..883707f7 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -464,7 +464,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, - clip_skip: int, + # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, @@ -479,7 +479,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): feature_extractor=feature_extractor, requires_safety_checker=requires_safety_checker, ) - self.clip_skip = clip_skip + # self.clip_skip = clip_skip self.__init__additional__() # else: diff --git a/library/model_util.py b/library/model_util.py index 0fbc6590..ea1be513 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,7 +5,7 @@ import math 
import os import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline #, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel @@ -127,17 +127,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -192,8 +192,8 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + if ".attentions." 
in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] @@ -362,7 +362,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): # SDのv2では1*1のconv2dがlinearに変わっている # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 - if v2 and not config.get('use_linear_projection', False): + if v2 and not config.get("use_linear_projection", False): linear_transformer_to_conv(new_checkpoint) return new_checkpoint diff --git a/library/train_util.py b/library/train_util.py index b7cee937..8aa7f987 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3467,11 +3467,11 @@ def sample_images( unet=unet, tokenizer=tokenizer, scheduler=scheduler, - clip_skip=args.clip_skip, safety_checker=None, feature_extractor=None, requires_safety_checker=False, ) + pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する pipeline.to(device) save_dir = args.output_dir + "/sample" From 035dd3a900ce9208c5d703c354ec839166b8f9cc Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 17:08:21 +0900 Subject: [PATCH 19/37] fix mem_eff_attn does not work --- library/original_unet.py | 10 ++-- library/train_util.py | 101 ++++----------------------------------- 2 files changed, 15 insertions(+), 96 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index 0e64280b..36318eb9 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -278,7 +278,7 @@ class FlashAttentionFunction(torch.autograd.Function): for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): k_start_index = k_ind * k_bucket_size - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale + attn_weights = torch.einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale if causal and q_start_index < (k_start_index + k_bucket_size - 1): causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( @@ -293,14 +293,14 @@ class FlashAttentionFunction(torch.autograd.Function): p = exp_attn_weights / lc - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) D = (doc * oc).sum(dim=-1, keepdims=True) ds = p * scale * (dp - D) - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) dqc.add_(dq_chunk) dkc.add_(dk_chunk) diff --git a/library/train_util.py b/library/train_util.py index 8aa7f987..7d7eb325 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -63,6 +63,7 @@ import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util +from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -1787,100 +1788,18 @@ class FlashAttentionFunction(torch.autograd.function.Function): return dq, dk, dv, None, None, None, None -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - # unet is not used currently, but it is here for future use +def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers): if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() + print("Enable memory efficient attention for U-Net") + unet.set_use_memory_efficient_attention(False, True) elif xformers: - 
replace_unet_cross_attn_to_xformers(unet) - - -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - out = self.to_out[0](out) - # out = self.to_out[1](out) - return out - - # diffusers.models.attention.CrossAttention.forward = forward_flash_attn - from library.original_unet import CrossAttention - - CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(unet): - print("CrossAttention.forward has been replaced to enable xformers.") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention_xformers(True) - - # def forward_xformers(self, x, context=None, mask=None): - # h = self.heads - # q_in = self.to_q(x) - - # context = default(context, x) - # context = context.to(x.dtype) - - # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - # context_k, context_v = self.hypernetwork.forward(x, context) - # context_k = context_k.to(x.dtype) - # context_v = context_v.to(x.dtype) - # else: - # context_k = context - # context_v = context - - # k_in = 
self.to_k(context_k) - # v_in = self.to_v(context_v) - - # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - # del q_in, k_in, v_in - - # q = q.contiguous() - # k = k.contiguous() - # v = v.contiguous() - # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - # out = rearrange(out, "b n h d -> b n (h d)", h=h) - - # # diffusers 0.7.0~ - # out = self.to_out[0](out) - # out = self.to_out[1](out) - # return out - - # diffusers.models.attention.CrossAttention.forward = forward_xformers + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + unet.set_use_memory_efficient_attention(True, False) """ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): From 4b7b3bc04a5e0d6b74b4e92ade5bbacb3f095c9c Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 17:35:00 +0900 Subject: [PATCH 20/37] fix saved SD dict is invalid for VAE --- library/model_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index ea1be513..0773188c 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -783,10 +783,10 @@ def convert_vae_state_dict(vae_state_dict): vae_conversion_map_attn = [ # (stable-diffusion, HF Diffusers) ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), ] mapping = {k: k for k in vae_state_dict.keys()} @@ -804,7 +804,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") + # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = 
reshape_weight_for_sd(v) return new_state_dict From 0315611b11c92ab7ad60ac82bf285d8461b78910 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 18:32:14 +0900 Subject: [PATCH 21/37] remove workaround for accelerator=0.15, fix XTI --- XTI_hijack.py | 210 ++++++++++++++++----------------- fine_tune.py | 14 +-- library/train_util.py | 18 +-- train_db.py | 14 +-- train_network.py | 8 +- train_textual_inversion.py | 14 ++- train_textual_inversion_XTI.py | 34 ++++-- 7 files changed, 153 insertions(+), 159 deletions(-) diff --git a/XTI_hijack.py b/XTI_hijack.py index f39cc8e7..36b5d3f2 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -2,132 +2,123 @@ import torch from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput -def unet_forward_XTI(self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. +from library.original_unet import SampleOutput - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 
- # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None +def unet_forward_XTI( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, +) -> Union[Dict, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**self.num_upsamplers - # 0. 
center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) + # 1. time + timesteps = timestep + timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 - t_emb = self.time_proj(timesteps) + t_emb = self.time_proj(timesteps) - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? 
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) - if self.config.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb + # 2. pre-process + sample = self.conv_in(sample) - # 2. pre-process - sample = self.conv_in(sample) + # 3. down + down_block_res_samples = (sample,) + down_i = 0 + for downsample_block in self.down_blocks: + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2], + ) + down_i += 2 + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - # 3. down - down_block_res_samples = (sample,) - down_i = 0 - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], - ) - down_i += 2 - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples - down_block_res_samples += res_samples + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) - # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) + # 5. up + up_i = 7 + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 - # 5. 
up - up_i = 7 - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3], + upsample_size=upsample_size, + ) + up_i += 3 + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], - upsample_size=upsample_size, - ) - up_i += 3 - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) + # 6. 
post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) - if not return_dict: - return (sample,) + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) - return UNet2DConditionOutput(sample=sample) def downblock_forward_XTI( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None @@ -166,6 +157,7 @@ def downblock_forward_XTI( return hidden_states, output_states + def upblock_forward_XTI( self, hidden_states, @@ -199,11 +191,11 @@ def upblock_forward_XTI( else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample - + i += 1 if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/fine_tune.py b/fine_tune.py index 201d4952..881845c5 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,7 +91,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -385,8 +385,8 @@ def train(args): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) @@ -428,8 +428,8 @@ def train(args): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) @@ -437,8 +437,8 @@ def train(args): is_main_process = accelerator.is_main_process if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) + unet = accelerator.unwrap_model(unet) + 
text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() diff --git a/library/train_util.py b/library/train_util.py index 7d7eb325..3ae5d0f3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2904,23 +2904,9 @@ def prepare_accelerator(args: argparse.Namespace): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=log_with, - logging_dir=logging_dir, + project_dir=logging_dir, ) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) - - return accelerator, unwrap_model + return accelerator def prepare_dtype(args: argparse.Namespace): diff --git a/train_db.py b/train_db.py index c81a092d..09f8d361 100644 --- a/train_db.py +++ b/train_db.py @@ -95,7 +95,7 @@ def train(args): f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" ) - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -372,8 +372,8 @@ def train(args): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) @@ -420,8 +420,8 @@ def train(args): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) @@ -429,8 +429,8 @@ def train(args): is_main_process = accelerator.is_main_process if is_main_process: - unet = unwrap_model(unet) - 
text_encoder = unwrap_model(text_encoder) + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() diff --git a/train_network.py b/train_network.py index b62aef7e..7c74ae5d 100644 --- a/train_network.py +++ b/train_network.py @@ -150,7 +150,7 @@ def train(args): # acceleratorを準備する print("preparing accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process # mixed precisionに対応した型を用意しておき適宜castする @@ -702,7 +702,7 @@ def train(args): accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, unwrap_model(network), global_step, epoch) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) if args.save_state: train_util.save_and_remove_state_stepwise(args, accelerator, global_step) @@ -744,7 +744,7 @@ def train(args): saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: ckpt_name = train_util.get_epoch_ckpt_name(args, "." 
+ args.save_model_as, epoch + 1) - save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) if remove_epoch_no is not None: @@ -762,7 +762,7 @@ def train(args): metadata["ss_training_finished_at"] = str(time.time()) if is_main_process: - network = unwrap_model(network) + network = accelerator.unwrap_model(network) accelerator.end_training() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 8be0703d..9dd846bd 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -98,7 +98,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -291,7 +291,7 @@ def train(args): index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] # print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder text_encoder.requires_grad_(True) @@ -440,7 +440,7 @@ def train(args): # Let's make sure we don't update any embedding weights besides the newly added token with torch.no_grad(): - unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ index_no_updates ] @@ -457,7 +457,9 @@ def train(args): if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() if 
accelerator.is_main_process: - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + updated_embs = ( + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + ) ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) save_model(ckpt_name, updated_embs, global_step, epoch) @@ -493,7 +495,7 @@ def train(args): accelerator.wait_for_everyone() - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs @@ -517,7 +519,7 @@ def train(args): is_main_process = accelerator.is_main_process if is_main_process: - text_encoder = unwrap_model(text_encoder) + text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7b734f28..1ea6dfc6 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -11,6 +11,7 @@ import torch from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler +import library import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -20,7 +21,14 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, 
+) +import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -98,7 +106,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -257,9 +265,9 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI - diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI - diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + original_unet.UNet2DConditionModel.forward = unet_forward_XTI + original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI + original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI # 学習を準備する if cache_latents: @@ -319,7 +327,7 @@ def train(args): index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] # print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder text_encoder.requires_grad_(True) @@ -473,7 +481,7 @@ def train(args): # Let's make sure we don't update any embedding weights besides the newly added token with torch.no_grad(): - unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ index_no_updates ] @@ -490,7 +498,13 @@ def train(args): if 
args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + updated_embs = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[token_ids_XTI] + .data.detach() + .clone() + ) ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) save_model(ckpt_name, updated_embs, global_step, epoch) @@ -526,7 +540,7 @@ def train(args): accelerator.wait_for_everyone() - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs @@ -551,7 +565,7 @@ def train(args): is_main_process = accelerator.is_main_process if is_main_process: - text_encoder = unwrap_model(text_encoder) + text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() From 4d0c06e3975e03f4f4e7c648602824d272151297 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 18:54:50 +0900 Subject: [PATCH 22/37] support both 0.10.2 and 0.17.0 for Diffusers --- gen_img_diffusers.py | 84 ++++++++++++++++++++++++++++++++++++++++--- library/model_util.py | 66 +++++++++++++++++++++++++--------- 2 files changed, 129 insertions(+), 21 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 34857af3..71daa9a1 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -161,9 +161,45 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ replace_vae_attn_to_xformers() + def replace_vae_attn_to_memory_efficient(): print("VAE Attention.forward has been 
replaced to FlashAttention (not xformers)") - flash_func =FlashAttentionFunction + flash_func = FlashAttentionFunction + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.proj_attn(hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 @@ -202,13 +238,50 @@ def replace_vae_attn_to_memory_efficient(): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - diffusers.models.attention_processor.Attention.forward = forward_flash_attn + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn def replace_vae_attn_to_xformers(): print("VAE: Attention.forward has been replaced to xformers") import xformers.ops - + + def forward_xformers_0_14(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + 
# norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + def forward_xformers(self, hidden_states, **kwargs): residual = hidden_states batch, channel, height, width = hidden_states.shape @@ -246,7 +319,10 @@ def replace_vae_attn_to_xformers(): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - diffusers.models.attention_processor.Attention.forward = forward_xformers + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers # endregion diff --git a/library/model_util.py b/library/model_util.py index 0773188c..7ca8f194 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -4,6 +4,7 @@ import math import os import torch +import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import 
load_file, save_file @@ -127,17 +128,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") + if diffusers.__version__ < "0.15.0": + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + else: + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -192,7 +206,15 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if ".attentions." 
in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + reshaping = False + if diffusers.__version__ < "0.15.0": + if "proj_attn.weight" in new_path: + reshaping = True + else: + if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + reshaping = True + + if reshaping: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] @@ -780,14 +802,24 @@ def convert_vae_state_dict(vae_state_dict): sd_mid_res_prefix = f"mid.block_{i+1}." vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "to_q."), - ("k.", "to_k."), - ("v.", "to_v."), - ("proj_out.", "to_out.0."), - ] + if diffusers.__version__ < "0.15.0": + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + else: + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), + ] mapping = {k: k for k in vae_state_dict.keys()} for k, v in mapping.items(): From 9e1683cf2b85438c3ddc54b9c545a12bcf670e6d Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 21:26:15 +0900 Subject: [PATCH 23/37] support sdpa --- README.md | 26 +++++- fine_tune.py | 2 +- gen_img_diffusers.py | 158 ++++++++++++++++++--------------- library/original_unet.py | 61 ++++++++++++- library/train_util.py | 6 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 9 files changed, 177 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index 8234a89e..634b9494 100644 --- a/README.md +++ b/README.md @@ -75,8 +75,6 @@ cp .\bitsandbytes_windows\main.py 
.\venv\Lib\site-packages\bitsandbytes\cuda_set accelerate config ``` -update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python). - Answers to accelerate config: ```txt @@ -94,6 +92,30 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) +### Experimental: Use PyTorch 2.0 + +In this case, you need to install PyTorch 2.0 and xformers 0.0.20. Instead of the above, please type the following: + +```powershell +git clone https://github.com/kohya-ss/sd-scripts.git +cd sd-scripts + +python -m venv venv +.\venv\Scripts\activate + +pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install --upgrade -r requirements.txt +pip install xformers==0.0.20 + +cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ +cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py +cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py + +accelerate config +``` + +Answers to accelerate config should be the same as above. + ### about PyTorch and xformers Other versions of PyTorch and xformers seem to have problems with training. 
diff --git a/fine_tune.py b/fine_tune.py index 881845c5..120f3d0f 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -141,7 +141,7 @@ def train(args): # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある print("Disable Diffusers' xformers") set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 71daa9a1..889b4c4c 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -137,7 +137,7 @@ USE_CUTOUTS = False """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: print("Enable memory efficient attention for U-Net") @@ -151,56 +151,26 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio raise ImportError("No xformers / xformersがインストールされていないようです") unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) # TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): if mem_eff_attn: replace_vae_attn_to_memory_efficient() elif xformers: - # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ replace_vae_attn_to_xformers() + elif sdpa: + replace_vae_attn_to_sdpa() def replace_vae_attn_to_memory_efficient(): print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - 
residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 k_bucket_size = 1024 @@ -238,6 +208,15 @@ def replace_vae_attn_to_memory_efficient(): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 else: @@ -248,40 +227,6 @@ def replace_vae_attn_to_xformers(): print("VAE: Attention.forward has been replaced to xformers") import xformers.ops - def forward_xformers_0_14(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = 
self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - def forward_xformers(self, hidden_states, **kwargs): residual = hidden_states batch, channel, height, width = hidden_states.shape @@ -319,12 +264,75 @@ def replace_vae_attn_to_xformers(): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 else: diffusers.models.attention_processor.Attention.forward = forward_xformers +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + 
hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + # endregion # region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 @@ -2082,8 +2090,9 @@ def main(args): # xformers、Hypernetwork対応 if not args.diffusers_xformers: - replace_unet_modules(unet, not args.xformers, args.xformers) - replace_vae_modules(vae, not args.xformers, args.xformers) + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む print("loading tokenizer") @@ 
-3176,6 +3185,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") parser.add_argument( "--diffusers_xformers", action="store_true", diff --git a/library/original_unet.py b/library/original_unet.py index 36318eb9..e22b16c0 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -494,6 +494,9 @@ class DownBlock2D(nn.Module): def set_use_memory_efficient_attention(self, xformers, mem_eff): pass + def set_use_sdpa(self, sdpa): + pass + def forward(self, hidden_states, temb=None): output_states = () @@ -564,11 +567,15 @@ class CrossAttention(nn.Module): self.use_memory_efficient_attention_xformers = False self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False def set_use_memory_efficient_attention(self, xformers, mem_eff): self.use_memory_efficient_attention_xformers = xformers self.use_memory_efficient_attention_mem_eff = mem_eff + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads @@ -588,6 +595,8 @@ class CrossAttention(nn.Module): return self.forward_memory_efficient_xformers(hidden_states, context, mask) if self.use_memory_efficient_attention_mem_eff: return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) query = self.to_q(hidden_states) context = context if context is not None else hidden_states @@ -676,6 +685,26 @@ class CrossAttention(nn.Module): out = self.to_out[0](out) return out + def forward_sdpa(self, x, context=None, mask=None): + import xformers.ops + + 
h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + # feedforward class GEGLU(nn.Module): @@ -759,6 +788,10 @@ class BasicTransformerBlock(nn.Module): self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + def forward(self, hidden_states, context=None, timestep=None): # 1. Self-Attention norm_hidden_states = self.norm1(hidden_states) @@ -820,6 +853,10 @@ class Transformer2DModel(nn.Module): for transformer in self.transformer_blocks: transformer.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # 1. 
Input batch, _, height, weight = hidden_states.shape @@ -901,6 +938,10 @@ class CrossAttnDownBlock2D(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -978,6 +1019,10 @@ class UNetMidBlock2DCrossAttn(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): for i, resnet in enumerate(self.resnets): attn = None if i == 0 else self.attentions[i - 1] @@ -1079,6 +1124,9 @@ class UpBlock2D(nn.Module): def set_use_memory_efficient_attention(self, xformers, mem_eff): pass + def set_use_sdpa(self, sdpa): + pass + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states @@ -1159,6 +1207,10 @@ class CrossAttnUpBlock2D(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, spda): + for attn in self.attentions: + attn.set_use_sdpa(spda) + def forward( self, hidden_states, @@ -1393,10 +1445,15 @@ class UNet2DConditionModel(nn.Module): def disable_gradient_checkpointing(self): self.set_gradient_checkpointing(value=False) - def set_use_memory_efficient_attention(self, xformers: bool,mem_eff:bool) -> None: + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - module.set_use_memory_efficient_attention(xformers,mem_eff) + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + modules = self.down_blocks + [self.mid_block] + 
self.up_blocks + for module in modules: + module.set_use_sdpa(sdpa) def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks diff --git a/library/train_util.py b/library/train_util.py index 3ae5d0f3..30380262 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1788,7 +1788,7 @@ class FlashAttentionFunction(torch.autograd.function.Function): return dq, dk, dv, None, None, None, None -def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers): +def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: print("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) @@ -1800,6 +1800,9 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers): raise ImportError("No xformers / xformersがインストールされていないようです") unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_sdpa(True) """ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): @@ -2048,6 +2051,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)") parser.add_argument( "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" ) diff --git a/train_db.py b/train_db.py index 09f8d361..895b8b24 100644 --- a/train_db.py +++ b/train_db.py @@ -119,7 +119,7 @@ def train(args): use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) # モデルに xformers とか memory 
efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: diff --git a/train_network.py b/train_network.py index 7c74ae5d..9ea9bf9c 100644 --- a/train_network.py +++ b/train_network.py @@ -160,7 +160,7 @@ def train(args): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 差分追加学習のためにモデルを読み込む import sys diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9dd846bd..4f31220d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -231,7 +231,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 1ea6dfc6..69f618cc 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -264,7 +264,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) original_unet.UNet2DConditionModel.forward = unet_forward_XTI original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI original_unet.CrossAttnUpBlock2D.forward = 
upblock_forward_XTI From 0dfffcd88ac46214ac8d5596c3060f87843f010f Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 21:46:05 +0900 Subject: [PATCH 24/37] remove unnecessary import --- library/original_unet.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index e22b16c0..94d11290 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -686,8 +686,6 @@ class CrossAttention(nn.Module): return out def forward_sdpa(self, x, context=None, mask=None): - import xformers.ops - h = self.heads q_in = self.to_q(x) context = context if context is not None else x From 67f09b7d7ed247c14133feb584519428a396f329 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Jun 2023 12:29:44 +0900 Subject: [PATCH 25/37] change ver no for Diffusers VAE changing --- library/model_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index 7ca8f194..d59f5ef4 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -128,7 +128,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - if diffusers.__version__ < "0.15.0": + if diffusers.__version__ < "0.17.0": new_item = new_item.replace("q.weight", "query.weight") new_item = new_item.replace("q.bias", "query.bias") @@ -207,7 +207,7 @@ def assign_to_checkpoint( # proj_attn.weight has to be converted from conv 1D to linear reshaping = False - if diffusers.__version__ < "0.15.0": + if diffusers.__version__ < "0.17.0": if "proj_attn.weight" in new_path: reshaping = True else: @@ -802,7 +802,7 @@ def convert_vae_state_dict(vae_state_dict): sd_mid_res_prefix = f"mid.block_{i+1}." 
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - if diffusers.__version__ < "0.15.0": + if diffusers.__version__ < "0.17.0": vae_conversion_map_attn = [ # (stable-diffusion, HF Diffusers) ("norm.", "group_norm."), From 9aee793078ac3ab20bf1296756013458fabc78ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 14 Jun 2023 12:49:12 +0900 Subject: [PATCH 26/37] support arbitrary dataset for train_network.py --- library/train_util.py | 61 ++++++++++++++++++++++++++++++ train_network.py | 87 +++++++++++++++++++++++++++---------------- 2 files changed, 115 insertions(+), 33 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 844faca7..e1046d58 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1518,6 +1518,67 @@ def glob_images_pathlib(dir_path, recursive): return image_paths +class MinimalDataset(BaseDataset): + def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + self.num_train_images = 0 # update in subclass + self.num_reg_images = 0 # update in subclass + self.datasets = [self] + self.batch_size = 1 # update in subclass + + self.subsets = [self] + self.num_repeats = 1 # update in subclass if needed + self.img_count = 1 # update in subclass if needed + self.bucket_info = {} + self.is_reg = False + self.image_dir = "dummy" # for metadata + + def is_latent_cacheable(self) -> bool: + return False + + def __len__(self): + raise NotImplementedError + + # override to avoid shuffling buckets + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def __getitem__(self, idx): + r""" + The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects. + + Returns: example like this: + + for i in range(batch_size): + image_key = ... # whatever hashable + image_keys.append(image_key) + + image = ... 
# PIL Image + img_tensor = self.image_transforms(img) + images.append(img_tensor) + + caption = ... # str + input_ids = self.get_input_ids(caption) + input_ids_list.append(input_ids) + + captions.append(caption) + + images = torch.stack(images, dim=0) + input_ids_list = torch.stack(input_ids_list, dim=0) + example = { + "images": images, + "input_ids": input_ids_list, + "captions": captions, # for debug_dataset + "latents": None, + "image_keys": image_keys, # for debug_dataset + "loss_weights": torch.ones(batch_size, dtype=torch.float32), + } + return example + """ + raise NotImplementedError + + # endregion # region モジュール入れ替え部 diff --git a/train_network.py b/train_network.py index b62aef7e..6f845b5a 100644 --- a/train_network.py +++ b/train_network.py @@ -92,42 +92,56 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) - if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Loading dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - if use_dreambooth_method: - print("Using DreamBooth method.") - user_config = { - "datasets": [ - 
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } else: - print("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + # use arbitrary dataset class + module = ".".join(args.dataset_class.split(".")[:-1]) + dataset_class = args.dataset_class.split(".")[-1] + module = importlib.import_module(module) + dataset_class = getattr(module, dataset_class) + train_dataset_group: train_util.MinimalDataset = dataset_class( + tokenizer, args.max_token_length, args.resolution, args.debug_dataset + ) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -185,6 +199,7 @@ def train(args): module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") print(f"all weights merged: {', '.join(args.base_weights)}") + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -852,6 +867,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model before training / 
学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) + parser.add_argument( + "--dataset_class", + type=str, + default=None, + help="dataset class for arbitrary dataset / 任意のデータセットのクラス名", + ) return parser From 449ad7502cb0f36cd8b94b2c7d98ec204af234a9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 14 Jun 2023 22:26:05 +0900 Subject: [PATCH 27/37] use original unet for HF models, don't download TE --- gen_img_diffusers.py | 17 +++++++++++------ library/model_util.py | 29 +++++++++++++++++++++++++---- library/train_util.py | 17 ++++++++++++++++- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 889b4c4c..2c36329e 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -99,12 +99,6 @@ from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う - -DEFAULT_TOKEN_LENGTH = 75 - # scheduler: SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 @@ -2066,6 +2060,17 @@ def main(args): tokenizer = loading_pipe.tokenizer del loading_pipe + # Diffusers U-Net to original U-Net + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) diff --git a/library/model_util.py b/library/model_util.py index d59f5ef4..63a395f8 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -933,10 +933,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt else: converted_text_encoder_checkpoint = 
convert_ldm_clip_checkpoint_v1(state_dict) - logging.set_verbosity_error() # don't show annoying warning - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) - logging.set_verbosity_warning() - + # logging.set_verbosity_error() # don't show annoying warning + # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + # logging.set_verbosity_warning() + # print(f"config: {text_model.config}") + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + torch_dtype="float32", + ) + text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) print("loading text encoder:", info) diff --git a/library/train_util.py b/library/train_util.py index 30380262..f13d7252 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -36,7 +36,6 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer import transformers -import diffusers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, @@ -52,6 +51,7 @@ from diffusers import ( KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, ) +from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import albumentations as albu import numpy as np @@ -2947,11 +2947,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): print( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? 
/ 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) + raise ex text_encoder = pipe.text_encoder vae = pipe.vae unet = pipe.unet del pipe + # Diffusers U-Net to original U-Net + # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう + # print(f"unet config: {unet.config}") + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + print("U-Net converted to original U-Net") + # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) From d4ba37f54399ce81c3b1a3c1260c6dbf9ab447e9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 13:22:06 +0900 Subject: [PATCH 28/37] supprot dynamic prompt variants --- gen_img_diffusers.py | 295 +++++++++++++++++++++++++++++-------------- 1 file changed, 203 insertions(+), 92 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 33c40441..01001646 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -46,6 +46,7 @@ VGG( ) """ +import itertools import json from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable import glob @@ -2159,6 +2160,102 @@ def preprocess_mask(mask): return mask +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return 
[prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separater = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separater)) + else: + # make random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separater)) + + # make each prompt + if not enumerating: + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0]) + prompts.append(current) + else: + prompts = [prompt] + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: # enumerating + new_prompts = [] + for current in prompts: + replecements = 
replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement)) + prompts = new_prompts + for found, replacer in zip(founds, replacers): + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) + + return prompts + + # endregion @@ -2776,6 +2873,7 @@ def main(args): # seed指定時はseedを決めておく if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう random.seed(args.seed) predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] if len(predefined_seeds) == 1: @@ -3058,121 +3156,134 @@ def main(args): while not valid: print("\nType prompt:") try: - prompt = input() + raw_prompt = input() except EOFError: break - valid = len(prompt.strip().split(" --")[0].strip()) > 0 + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 if not valid: # EOF, end app break else: - prompt = prompt_list[prompt_index] + raw_prompt = prompt_list[prompt_index] - # parse prompt - width = args.W - height = args.H - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None + # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - prompt_args = prompt.strip().split(" --") - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + # repeat prompt + for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0] - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue 
+ if prompt_index == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = 
float(m.group(1)) - print(f"strength: {strength}") - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue - if seeds is not None: - # 数が足りないなら繰り返す - if len(seeds) < args.images_per_prompt: - seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) - seeds = seeds[: args.images_per_prompt] - else: - if predefined_seeds is not None: - seeds = predefined_seeds[-args.images_per_prompt :] - predefined_seeds = predefined_seeds[: -args.images_per_prompt] - elif args.iter_same_seed: - seeds = [iter_seed] * 
args.images_per_prompt + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) else: - seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seeds}") + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None - init_image = mask_image = guide_image = None - for seed in seeds: # images_per_promptの数だけ # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する if init_images is not None: init_image = init_images[global_step % len(init_images)] From 624fbadea2b742f2bf32d82efeb332f24695881c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 19:19:16 +0900 Subject: [PATCH 29/37] fix dynamic prompt with from_file --- gen_img_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 01001646..7b5cee1f 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -3170,10 +3170,10 @@ def main(args): raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) # repeat prompt - for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0] 
+ for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if prompt_index == 0 or len(raw_prompts) > 1: + if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing width = args.W height = args.H From f2989b36c2dfcd799460a22a90de5d7bdba8d2ec Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:37:01 +0900 Subject: [PATCH 30/37] fix typos, add comment --- gen_img_diffusers.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 7b5cee1f..acff1ea4 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2180,12 +2180,14 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): enumerating = False replacers = [] for found in founds: + # if "e$$" is found, enumerate all variants found_enumerating = found.group(2) is not None enumerating = enumerating or found_enumerating - separater = ", " if found.group(6) is None else found.group(6) + separator = ", " if found.group(6) is None else found.group(6) variants = found.group(7).split("|") + # parse count range count_range = found.group(4) if count_range is None: count_range = [1, 1] @@ -2206,7 +2208,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): count_range[1] = len(variants) if found_enumerating: - # make all combinations + # make function to enumerate all combinations def make_replacer_enum(vari, cr, sep): def replacer(): values = [] @@ -2217,9 +2219,9 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): return replacer - replacers.append(make_replacer_enum(variants, count_range, separater)) + replacers.append(make_replacer_enum(variants, count_range, separator)) else: - # make random combinations + # make function to choose random combinations def make_replacer_single(vari, cr, sep): def replacer(): count = random.randint(cr[0], cr[1]) @@ -2228,10 
+2230,11 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): return replacer - replacers.append(make_replacer_single(variants, count_range, separater)) + replacers.append(make_replacer_single(variants, count_range, separator)) # make each prompt - if not enumerating: + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly prompts = [] for _ in range(repeat_count): current = prompt @@ -2239,16 +2242,21 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): current = current.replace(found.group(0), replacer()[0]) prompts.append(current) else: + # if enumerating, iterate all combinations for previous prompts prompts = [prompt] + for found, replacer in zip(founds, replacers): - if found.group(2) is not None: # enumerating + if found.group(2) is not None: + # make all combinations for existing prompts new_prompts = [] for current in prompts: replecements = replacer() for replecement in replecements: new_prompts.append(current.replace(found.group(0), replecement)) prompts = new_prompts + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts if found.group(2) is None: for i in range(len(prompts)): prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) @@ -3166,7 +3174,8 @@ def main(args): else: raw_prompt = prompt_list[prompt_index] - # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) # repeat prompt From 9806b00f74d1ee6be4d792e107cbd1b59b7addbb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:39:39 +0900 Subject: [PATCH 31/37] add arbitrary dataset feature to each script --- fine_tune.py | 54 ++++++++++++++------------ library/train_util.py | 17 +++++++- train_db.py | 40 ++++++++++--------- train_network.py | 14 
+------ train_textual_inversion.py | 71 ++++++++++++++++++---------------- train_textual_inversion_XTI.py | 11 +++++- 6 files changed, 115 insertions(+), 92 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 201d4952..308f90ef 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -42,33 +42,37 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + else: + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = 
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/library/train_util.py b/library/train_util.py index e1046d58..4a25e00d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1579,6 +1579,15 @@ class MinimalDataset(BaseDataset): raise NotImplementedError +def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: + module = ".".join(args.dataset_class.split(".")[:-1]) + dataset_class = args.dataset_class.split(".")[-1] + module = importlib.import_module(module) + dataset_class = getattr(module, dataset_class) + train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset) + return train_dataset_group + + # endregion # region モジュール入れ替え部 @@ -2455,7 +2464,6 @@ def add_dataset_arguments( default=1, help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", ) - parser.add_argument( "--token_warmup_step", type=float, @@ -2463,6 +2471,13 @@ def add_dataset_arguments( help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--dataset_class", + type=str, + default=None, + help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)", + ) + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに diff --git a/train_db.py b/train_db.py index c81a092d..115855c1 100644 --- a/train_db.py +++ b/train_db.py @@ -46,26 +46,30 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) - if args.dataset_config is not None: - print(f"Load dataset config from 
{args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_network.py b/train_network.py index 6f845b5a..abec3d41 100644 --- a/train_network.py +++ b/train_network.py @@ -135,13 +135,7 @@ def train(args): train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary 
dataset class - module = ".".join(args.dataset_class.split(".")[:-1]) - dataset_class = args.dataset_class.split(".")[-1] - module = importlib.import_module(module) - dataset_class = getattr(module, dataset_class) - train_dataset_group: train_util.MinimalDataset = dataset_class( - tokenizer, args.max_token_length, args.resolution, args.debug_dataset - ) + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -867,12 +861,6 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) - parser.add_argument( - "--dataset_class", - type=str, - default=None, - help="dataset class for arbitrary dataset / 任意のデータセットのクラス名", - ) return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 8be0703d..48713fc1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -153,43 +153,46 @@ def train(args): print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if 
any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } else: - print("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7b734f28..bf7d5bb0 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -20,7 +20,13 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, 
prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -88,6 +94,9 @@ def train(args): print( "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" ) + assert ( + args.dataset_class is None + ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません" cache_latents = args.cache_latents From f0bb3ae825efe6720f10301ee788072542b2e3ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:56:12 +0900 Subject: [PATCH 32/37] add an option to disable controlnet in 2nd stage --- gen_img_diffusers.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index acff1ea4..93a876ab 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -615,11 +615,15 @@ class PipelineLike: # ControlNet self.control_nets: List[ControlNetInfo] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + def replace_token(self, tokens, layer=None): new_tokens = [] for token in tokens: @@ -1112,7 +1116,7 @@ class PipelineLike: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - if self.control_nets: + if self.control_nets and self.control_net_enabled: if reginonal_network: 
num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -2233,7 +2237,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): replacers.append(make_replacer_single(variants, count_range, separator)) # make each prompt - if not enumerating: + if not enumerating: # if not enumerating, repeat the prompt, replace each variant randomly prompts = [] for _ in range(repeat_count): @@ -2254,7 +2258,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): for replecement in replecements: new_prompts.append(current.replace(found.group(0), replecement)) prompts = new_prompts - + for found, replacer in zip(founds, replacers): # make random selection for existing prompts if found.group(2) is None: @@ -2933,6 +2937,8 @@ def main(args): ext.num_sub_prompts, ) batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する @@ -2976,6 +2982,9 @@ def main(args): batch_2nd.append(bd_2nd) batch = batch_2nd + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + # このバッチの情報を取り出す ( return_latents, @@ -3574,6 +3583,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) parser.add_argument( "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" From e97d67a68121df2ec57270d131c76ec8cb2e312d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Thu, 15 Jun 2023 20:12:53 +0800 Subject: [PATCH 
33/37] Support for Prodigy(Dadapt variety for Dylora) (#585) * Update train_util.py for DAdaptLion * Update train_README-zh.md for dadaptlion * Update train_README-ja.md for DAdaptLion * add DAdatpt V3 * Alignment * Update train_util.py for experimental * Update train_util.py V3 * Update train_README-zh.md * Update train_README-ja.md * Update train_util.py fix * Update train_util.py * support Prodigy * add lower --- docs/train_README-ja.md | 1 + docs/train_README-zh.md | 3 ++- fine_tune.py | 2 +- library/train_util.py | 32 ++++++++++++++++++++++++++++++++ train_db.py | 2 +- train_network.py | 4 ++-- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 8 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index b64b1808..158363b3 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptAdanIP : 引数は同上 - DAdaptLion : 引数は同上 - DAdaptSGD : 引数は同上 + - Prodigy : https://github.com/konstmish/prodigy - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任意のオプティマイザ diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 678832d2..454d5456 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptAdam : 参数同上 - DAdaptAdaGrad : 参数同上 - DAdaptAdan : 参数同上 - - DAdaptAdanIP : 引数は同上 + - DAdaptAdanIP : 参数同上 - DAdaptLion : 参数同上 - DAdaptSGD : 参数同上 + - Prodigy : https://github.com/konstmish/prodigy - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任何优化器 diff --git a/fine_tune.py b/fine_tune.py index 308f90ef..d0013d53 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -397,7 +397,7 @@ def train(args): current_loss = loss.detach().item() # 
平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/library/train_util.py b/library/train_util.py index 4a25e00d..5b5d99ac 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2808,6 +2808,38 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "Prodigy".lower(): + # Prodigy + # check Prodigy is installed + try: + import prodigyopt + except ImportError: + raise ImportError("No Prodigy / Prodigy がインストールされていないようです") + + # check lr and lr_count, and print warning + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + ) + print("recommend option: lr=1.0 / 推奨は1.0です") + if lr_count > 1: + print( + f"when multiple learning rates are specified with Prodigy (e.g. 
for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + ) + + print(f"use Prodigy optimizer | {optimizer_kwargs}") + optimizer_class = prodigyopt.Prodigy + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "Adafactor".lower(): # 引数を確認して適宜補正する if "relative_step" not in optimizer_kwargs: diff --git a/train_db.py b/train_db.py index 115855c1..927e79de 100644 --- a/train_db.py +++ b/train_db.py @@ -384,7 +384,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_network.py b/train_network.py index abec3d41..da0ca1c9 100644 --- a/train_network.py +++ b/train_network.py @@ -57,7 +57,7 @@ def generate_step_logs( logs["lr/textencoder"] = float(lrs[0]) logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value of unet. + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet. 
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -67,7 +67,7 @@ def generate_step_logs( for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()): + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 48713fc1..d08251e1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -476,7 +476,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index bf7d5bb0..f44d565c 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -515,7 +515,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) From 5845de7d7c6c9d8dd6123e7b29f39302e8a8140a Mon 
Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 21:47:37 +0900 Subject: [PATCH 34/37] common lr checking for dadaptation and prodigy --- library/train_util.py | 114 +++++++++++++++++------------------------- 1 file changed, 47 insertions(+), 67 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5b5d99ac..acfb503b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2752,15 +2752,7 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - elif optimizer_type.startswith("DAdapt".lower()): - # DAdaptation family - # check dadaptation is installed - try: - import dadaptation - import dadaptation.experimental as experimental - except ImportError: - raise ImportError("No dadaptation / dadaptation がインストールされていないようです") - + elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): # check lr and lr_count, and print warning actual_lr = lr lr_count = 1 @@ -2773,72 +2765,60 @@ def get_optimizer(args, trainable_params): if actual_lr <= 0.1: print( - f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) print("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: print( - f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + f"when multiple learning rates are specified with dadaptation (e.g. 
for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) - # set optimizer - if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): - optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdaGrad".lower(): - optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdam".lower(): - optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdan".lower(): - optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdanIP".lower(): - optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptLion".lower(): - optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptSGD".lower(): - optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + if optimizer_type.startswith("DAdapt".lower()): + # DAdaptation family + # check dadaptation is installed + try: + import dadaptation + import dadaptation.experimental as experimental + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + + # set optimizer + if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): + optimizer_class = experimental.DAdaptAdamPreprint + print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdaGrad".lower(): + optimizer_class = dadaptation.DAdaptAdaGrad + print(f"use D-Adaptation AdaGrad 
optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdam".lower(): + optimizer_class = dadaptation.DAdaptAdam + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdan".lower(): + optimizer_class = dadaptation.DAdaptAdan + print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdanIP".lower(): + optimizer_class = experimental.DAdaptAdanIP + print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptLion".lower(): + optimizer_class = dadaptation.DAdaptLion + print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptSGD".lower(): + optimizer_class = dadaptation.DAdaptSGD + print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") + # Prodigy + # check Prodigy is installed + try: + import prodigyopt + except ImportError: + raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "Prodigy".lower(): - # Prodigy - # check Prodigy is installed - try: - import prodigyopt - except ImportError: - raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - - # check lr and lr_count, and print warning - actual_lr = lr - lr_count = 1 - if type(trainable_params) == list and type(trainable_params[0]) == dict: - lrs = set() - actual_lr = trainable_params[0].get("lr", actual_lr) - for group in trainable_params: - lrs.add(group.get("lr", actual_lr)) - lr_count = len(lrs) - - if actual_lr <= 0.1: - print( - f"learning rate is too low. 
If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" - ) - print("recommend option: lr=1.0 / 推奨は1.0です") - if lr_count > 1: - print( - f"when multiple learning rates are specified with Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" - ) - - print(f"use Prodigy optimizer | {optimizer_kwargs}") - optimizer_class = prodigyopt.Prodigy - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + print(f"use Prodigy optimizer | {optimizer_kwargs}") + optimizer_class = prodigyopt.Prodigy + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "Adafactor".lower(): # 引数を確認して適宜補正する From 18156bf2a18f29e56d8f7dbb9de71d09399dde1d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 22:22:12 +0900 Subject: [PATCH 35/37] fix same replacement multiple times in dyn prompt --- gen_img_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 93a876ab..ffb79aa3 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2243,7 +2243,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): for _ in range(repeat_count): current = prompt for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0]) + current = current.replace(found.group(0), replacer()[0], 1) prompts.append(current) else: # if enumerating, iterate all combinations for previous prompts @@ -2256,14 +2256,14 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): for current in prompts: replecements = replacer() for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement)) + new_prompts.append(current.replace(found.group(0), replecement, 1)) prompts = new_prompts for found, replacer in zip(founds, replacers): # make random selection 
for existing prompts if found.group(2) is None: for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) return prompts From 5d1b54de45c142261d7d93467d94ef14e369188d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 22:27:47 +0900 Subject: [PATCH 36/37] update readme --- README.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/README.md b/README.md index 8234a89e..e6202bae 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,42 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 15 Jun. 2023, 2023/06/15 + +- Prodigy optimizer is supported in each training script. It is a member of D-Adaptation and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds! + - Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`. +- Arbitrary Dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions. + - Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script. + - Please refer to `MinimalDataset` for implementation. I will prepare a sample later. +- The following features have been added to the generation script. + - Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. Fix. Please try it if the image is disturbed by some ControlNet such as Canny. + - Added Variants similar to sd-dynamic-prompts in the prompt. + - If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected.
+ - If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected. + - If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `. + - You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`. + - It can also be specified for the prompt option. + - If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots. + - You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3. + - There is no weighting function. + +- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。 + - `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します。 +- 各学習スクリプトで任意のDatasetをサポートしました(XTIを除く)。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます。 + - Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。 + - 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します。 +- 生成スクリプトに以下の機能追加を行いました。 + - Highres. 
Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。 + - プロンプトでsd-dynamic-propmptsに似たVariantをサポートしました。 + - `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。 + - `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます。 + - `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。 + - 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません。 + - プロンプトオプションに対しても指定可能です。 + - `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます)。X/Y plotの作成に便利かもしれません。 + - `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます。 + - Weightingの機能はありません。 + ### 8 Jun. 2023, 2023/06/08 - Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training. From 1e0b0599821f97819eb47ca76cdcc1a54eb5bb83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Jun 2023 12:10:18 +0900 Subject: [PATCH 37/37] fix same seed is used for multiple generation --- gen_img_diffusers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index ffb79aa3..9ac5cd17 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -3294,6 +3294,9 @@ def main(args): seed = None elif args.iter_same_seed: seeds = iter_seed + else: + seed = None # 前のを消す + if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: