From a26522597256abc1e6ddc40936dd0e6c2179c266 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Mon, 20 Mar 2023 22:51:38 +0000 Subject: [PATCH 01/28] Min-SNR Weighting Strategy --- library/train_util.py | 2 +- train_network.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 7d311827..e68444a0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1927,7 +1927,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: diff --git a/train_network.py b/train_network.py index 7f910df4..5cb08f15 100644 --- a/train_network.py +++ b/train_network.py @@ -548,6 +548,16 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + gamma = args.min_snr_gamma + if gamma: + sigma = torch.sub(noisy_latents, latents) #find noise as applied + zeros = torch.zeros_like(sigma) + alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square + sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square + snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 64c923230e54a92b1a132e4918dad190935da461 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Wed, 22 Mar 2023 01:25:49 +0000 Subject: [PATCH 02/28] Min-SNR Weighting Strategy: Refactored and added to all trainers --- fine_tune.py | 8 +++++++- library/custom_train_functions.py | 17 +++++++++++++++++ library/train_util.py | 2 +- train_db.py | 7 ++++++- train_network.py | 17 ++++++----------- train_textual_inversion.py | 6 ++++++ 6 files changed, 43 insertions(+), 14 deletions(-) create mode 100644 library/custom_train_functions.py diff --git a/fine_tune.py b/fine_tune.py index 1acf478f..ff33eb9c 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -19,7 +19,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -304,6 +305,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] @@ -396,6 +400,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py new file mode 100644 index 00000000..f60ec743 --- /dev/null +++ b/library/custom_train_functions.py @@ -0,0 +1,17 @@ +import torch +import argparse + +def apply_snr_weight(loss, latents, noisy_latents, gamma): + sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler + zeros = torch.zeros_like(sigma) + alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment + sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment + snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight + print(snr_weight) + return loss + +def add_custom_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") \ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index ffe81d69..a0e98cb1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1963,7 +1963,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") + def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: diff --git a/train_db.py b/train_db.py index 527f8e9b..ee9beda9 100644 --- a/train_db.py +++ b/train_db.py @@ -21,7 +21,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -291,6 +292,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) @@ -390,6 +394,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--no_token_padding", diff --git a/train_network.py b/train_network.py index dce70618..715da8c1 100644 --- a/train_network.py +++ b/train_network.py @@ -23,7 +23,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -548,16 +549,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - gamma = args.min_snr_gamma - if gamma: - sigma = torch.sub(noisy_latents, latents) #find noise as applied - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square - snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares - gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper - loss = loss * snr_weight + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -662,6 +656,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument( diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 85f0d57c..5fe662f6 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -17,6 +17,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ "a photo of a {}", @@ -377,6 +379,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights @@ -534,6 +539,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--save_model_as", From a9b26b73e0134884e945e0f5d15c6e804e046759 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Thu, 23 Mar 2023 07:37:14 +0900 Subject: [PATCH 03/28] implement token warmup --- fine_tune.py | 4 +++ library/config_util.py | 6 +++++ library/train_util.py | 55 ++++++++++++++++++++++++++++++++++++-- train_db.py | 4 +++ train_network.py | 4 +++ train_textual_inversion.py | 4 +++ 6 files changed, 75 insertions(+), 2 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index d927bd73..473a13ec 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -197,6 +197,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -263,6 +266,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/library/config_util.py b/library/config_util.py index e62bfb89..98d89b7e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -56,6 +56,8 @@ class BaseSubsetParams: caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: Union[float,int] = 0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): @@ -137,6 +139,8 @@ class ConfigSanitizer: "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, + "token_warmup_min": int, + "token_warmup_step": Union[float,int], } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, """), " ") if is_dreambooth: diff --git a/library/train_util.py b/library/train_util.py index 7d311827..52b51314 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -277,6 +277,8 @@ class BaseSubset: caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, + token_warmup_min: int, + token_warmup_step: Union[float,int], ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -290,6 +292,9 @@ class BaseSubset: self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.token_warmup_min = token_warmup_min # step=0におけるタグの数 + self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.img_count = 0 @@ -310,6 +315,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -325,6 +332,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.is_reg = is_reg @@ -352,6 +361,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -367,6 +378,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -405,6 +418,9 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + self.current_step: int = 0 + self.max_train_steps: int = 0 + # augmentation self.aug_helper = AugHelper() @@ -424,6 +440,12 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch = epoch self.shuffle_buckets() + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) self.tag_frequency[dir_name] = frequency_for_dir @@ -453,7 +475,7 @@ class BaseDataset(torch.utils.data.Dataset): if is_drop_out: caption = "" else: - if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -474,8 +496,15 @@ class BaseDataset(torch.utils.data.Dataset): random.shuffle(flex_tokens) flex_tokens = dropout_tags(flex_tokens) + tokens = fixed_tokens + flex_tokens - caption = ", ".join(fixed_tokens + flex_tokens) + if subset.token_warmup_step < 1: + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens = tokens[:tokens_len] + + caption = ", ".join(tokens) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -1249,6 +1278,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.set_current_epoch(epoch) + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + def disable_token_padding(self): for dataset in self.datasets: dataset.disable_token_padding() @@ -2001,6 +2038,20 @@ def add_dataset_arguments( "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" ) + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", + ) + + parser.add_argument( + "--token_warmup_steps", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", + ) + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに diff --git a/train_db.py b/train_db.py index 81aeda19..164e354e 100644 --- a/train_db.py +++ b/train_db.py @@ -162,6 +162,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + if args.stop_text_encoder_training is None: args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end @@ -246,6 +249,7 @@ def train(args): text_encoder.requires_grad_(False) with accelerator.accumulate(unet): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): # latentに変換 if cache_latents: diff --git a/train_network.py b/train_network.py index 7f910df4..16f41ebb 100644 --- a/train_network.py +++ b/train_network.py @@ -200,6 +200,9 @@ def train(args): if is_main_process: print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -505,6 +508,7 @@ def train(args): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index e4ab7b5c..b3467d94 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -260,6 +260,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -338,6 +341,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) From 447c56bf505c2a84d00e88ac173a1b6961894429 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Thu, 23 Mar 2023 09:53:14 +0900 Subject: [PATCH 04/28] =?UTF-8?q?typo=E4=BF=AE=E6=AD=A3=E3=80=81step?= =?UTF-8?q?=E3=82=92global=5Fstep=E3=81=AB=E4=BF=AE=E6=AD=A3=E3=80=81?= =?UTF-8?q?=E3=83=90=E3=82=B0=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 2 +- library/config_util.py | 4 ++-- library/train_util.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 473a13ec..def942fa 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -265,8 +265,8 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/library/config_util.py b/library/config_util.py index 98d89b7e..84bbf308 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -57,7 +57,7 @@ class BaseSubsetParams: caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 - token_warmup_step: Union[float,int] = 0 + token_warmup_step: float = 0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): @@ -140,7 +140,7 @@ class ConfigSanitizer: "shuffle_caption": bool, "keep_tokens": int, "token_warmup_min": int, - "token_warmup_step": Union[float,int], + "token_warmup_step": Any(float,int), } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { diff --git a/library/train_util.py b/library/train_util.py index 52b51314..83e9372b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2046,7 +2046,7 @@ def add_dataset_arguments( ) parser.add_argument( - "--token_warmup_steps", + "--token_warmup_step", type=float, default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", diff --git a/train_db.py b/train_db.py index 164e354e..e17a8b79 100644 --- a/train_db.py +++ b/train_db.py @@ -241,6 +241,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") @@ -249,7 +250,6 @@ def train(args): text_encoder.requires_grad_(False) with accelerator.accumulate(unet): - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): # latentに変換 if cache_latents: diff --git a/train_network.py b/train_network.py index 16f41ebb..6d23ab07 100644 --- a/train_network.py +++ b/train_network.py @@ -507,8 +507,8 @@ def train(args): network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) with accelerator.accumulate(network): - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b3467d94..42746169 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -340,8 +340,8 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) with accelerator.accumulate(text_encoder): - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) From dbadc40ec2eb2de92b21fd3b5aa82994899705cc Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Thu, 23 Mar 2023 12:33:03 +0900 Subject: [PATCH 05/28] =?UTF-8?q?persistent=5Fworkers=E3=82=92=E6=9C=89?= =?UTF-8?q?=E5=8A=B9=E3=81=AB=E3=81=97=E3=81=9F=E9=9A=9B=E3=81=AB=E3=82=AD?= =?UTF-8?q?=E3=83=A3=E3=83=97=E3=82=B7=E3=83=A7=E3=83=B3=E3=81=8C=E5=A4=89?= =?UTF-8?q?=E5=8C=96=E3=81=97=E3=81=AA=E3=81=8F=E3=81=AA=E3=82=8B=E3=83=90?= =?UTF-8?q?=E3=82=B0=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 3 ++- library/config_util.py | 8 ++++++++ train_db.py | 3 ++- train_network.py | 3 ++- train_textual_inversion.py | 3 ++- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index def942fa..ff580435 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -62,6 +62,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: @@ -259,13 +260,13 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) for m in training_models: m.train() loss_total = 0 for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/library/config_util.py b/library/config_util.py index 84bbf308..efeb8016 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -497,6 +497,14 @@ def load_user_config(file: str) -> dict: return config +def blueprint_args_conflict(args,blueprint:Blueprint): + # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする + for b in blueprint.dataset_group.datasets: + for t in b.subsets: + if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): + print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step is ignored because --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。"%(t.params.image_dir)) + t.params.caption_dropout_every_n_epochs = 0 + t.params.token_warmup_step = 0 # for config test if __name__ == "__main__": diff --git a/train_db.py b/train_db.py index e17a8b79..87fe771b 100644 --- a/train_db.py +++ b/train_db.py @@ -57,6 +57,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.no_token_padding: @@ -233,6 +234,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -241,7 +243,6 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") diff --git a/train_network.py b/train_network.py index 6d23ab07..02a2d925 100644 --- a/train_network.py +++ b/train_network.py @@ -98,6 +98,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: @@ -501,13 +502,13 @@ def train(args): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) metadata["ss_epoch"] = str(epoch + 1) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 42746169..63b63426 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -183,6 +183,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 @@ -335,12 +336,12 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) text_encoder.train() loss_total = 0 for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From a3c7d711e4f1160fb873755c77074a89f5bf700a Mon Sep 17 00:00:00 2001 From: AI-Casanova <54461896+AI-Casanova@users.noreply.github.com> Date: Tue, 21 Mar 2023 20:38:27 -0500 Subject: [PATCH 06/28] Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation --- library/custom_train_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index f60ec743..5e880c9a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -10,8 +10,8 @@ def apply_snr_weight(loss, latents, noisy_latents, gamma): gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper loss = loss * snr_weight - print(snr_weight) + #print(snr_weight) return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") \ No newline at end of file + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") From 518a18aeff9c2f9a830565e0623729fd938590d3 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Thu, 23 Mar 2023 12:34:49 +0000 Subject: [PATCH 07/28] (ACTUAL) Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation --- fine_tune.py | 2 +- library/custom_train_functions.py | 20 ++++++++++++-------- train_db.py | 3 ++- train_network.py | 4 +--- train_textual_inversion.py | 2 +- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index ff33eb9c..45f4b9db 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -306,7 +306,7 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 5e880c9a..b080b40c 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,16 +1,20 @@ import torch import argparse +import numpy as np -def apply_snr_weight(loss, latents, noisy_latents, gamma): - sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment - snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.to(loss.device) + snr = torch.stack([all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper loss = loss * snr_weight - #print(snr_weight) return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): diff --git a/train_db.py b/train_db.py index ee9beda9..52195b92 100644 --- a/train_db.py +++ b/train_db.py @@ -293,7 +293,8 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 715da8c1..145dd600 100644 --- a/train_network.py +++ b/train_network.py @@ -489,7 +489,6 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) - if accelerator.is_main_process: accelerator.init_trackers("network_train") @@ -529,7 +528,6 @@ def train(args): # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -551,7 +549,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5fe662f6..0694dbb6 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -381,7 +381,7 @@ def train(args): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights From 143c26e55219ba9fe51fd1b50f8922d2f2de9c8a Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 24 Mar 2023 13:08:56 +0900 Subject: [PATCH 08/28] =?UTF-8?q?=E7=AB=B6=E5=90=88=E6=99=82=E3=81=ABpersi?= =?UTF-8?q?stant=5Fdata=5Floader=E5=81=B4=E3=82=92=E7=84=A1=E5=8A=B9?= =?UTF-8?q?=E3=81=AB=E3=81=99=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB=E5=A4=89?= =?UTF-8?q?=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/config_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index efeb8016..9c8c90c2 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -502,9 +502,8 @@ def blueprint_args_conflict(args,blueprint:Blueprint): for b in blueprint.dataset_group.datasets: for t in b.subsets: if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step is ignored because --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。"%(t.params.image_dir)) - t.params.caption_dropout_every_n_epochs = 0 - t.params.token_warmup_step = 0 + print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) + args.persistent_data_loader_workers = False # for config test if __name__ == "__main__": From 1b89b2a10e1f623efd3945d422dcd0640ac4f0fd Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 24 Mar 2023 13:44:30 +0900 Subject: [PATCH 09/28] =?UTF-8?q?=E3=82=B7=E3=83=A3=E3=83=83=E3=83=95?= =?UTF-8?q?=E3=83=AB=E5=89=8D=E3=81=AB=E3=82=BF=E3=82=B0=E3=82=92=E5=88=87?= =?UTF-8?q?=E3=82=8A=E8=A9=B0=E3=82=81=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB?= =?UTF-8?q?=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/train_util.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 83e9372b..d1df9c58 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -477,6 +477,13 @@ class BaseDataset(torch.utils.data.Dataset): else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + tokens = [t.strip() for t in caption.strip().split(",")] + if subset.token_warmup_step < 1: + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens = tokens[:tokens_len] + def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: return tokens @@ -487,24 +494,17 @@ class BaseDataset(torch.utils.data.Dataset): return l fixed_tokens = [] - flex_tokens = [t.strip() for t in caption.strip().split(",")] + flex_tokens = tokens[:] if subset.keep_tokens > 0: fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = flex_tokens[subset.keep_tokens :] + flex_tokens = tokens[subset.keep_tokens :] if subset.shuffle_caption: random.shuffle(flex_tokens) flex_tokens = dropout_tags(flex_tokens) - tokens = fixed_tokens + flex_tokens - if subset.token_warmup_step < 1: - subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) - if subset.token_warmup_step and self.current_step < subset.token_warmup_step: - tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min - tokens = tokens[:tokens_len] - - caption = ", ".join(tokens) + caption = ", ".join(fixed_tokens + flex_tokens) # textual inversion対応 for str_from, str_to in self.replacements.items(): From b2c5b96f2a4e77fab69831567bf2c43095c3332c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 24 Mar 2023 20:19:05 +0900 Subject: [PATCH 10/28] format by black --- gen_img_diffusers.py | 4909 ++++++++++++++++++++++-------------------- 1 file changed, 2590 insertions(+), 2319 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 38bc86e9..94ec8179 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,11 +65,22 @@ import diffusers import numpy as np import torch import torchvision -from diffusers import (AutoencoderKL, DDPMScheduler, - EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, - KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, - UNet2DConditionModel, StableDiffusionPipeline) +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + UNet2DConditionModel, + StableDiffusionPipeline, +) from einops import rearrange from torch import einsum from tqdm import tqdm @@ -86,7 +97,7 @@ from tools.original_control_net import ControlNetInfo # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う DEFAULT_TOKEN_LENGTH = 75 @@ -94,7 +105,7 @@ DEFAULT_TOKEN_LENGTH = 75 SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = 'scaled_linear' +SCHEDLER_SCHEDULE = "scaled_linear" # その他の設定 LATENT_CHANNELS = 4 @@ -133,11 +144,12 @@ EPSILON = 1e-6 def exists(val): - return val is not None + return val is not None def default(val, d): - return val if exists(val) else d + return val if exists(val) else d + # flash attention forwards and backwards @@ -145,243 +157,247 @@ def default(val, d): class FlashAttentionFunction(torch.autograd.Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - scale = (q.shape[-1] ** -0.5) + scale = q.shape[-1] ** -0.5 - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - return o + return o - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors - device = q.device + device = q.device - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights - mc) + exp_attn_weights = torch.exp(attn_weights - mc) - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) - p = exp_attn_weights / lc + p = exp_attn_weights / lc - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") - flash_func = FlashAttentionFunction + print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") + flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + h = self.heads + q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) + context = context if context is not None else x + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + diffusers.models.attention.CrossAttention.forward = forward_flash_attn def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") + print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + context = default(context, x) + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, 'b n h d -> b n (h d)', h=h) + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion # region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 @@ -389,1071 +405,1168 @@ def replace_unet_cross_attn_to_xformers(): # Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 -class PipelineLike(): - r""" - Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing - weighting in prompt. - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ +class PipelineLike: + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - clip_model: CLIPModel, - clip_guidance_scale: float, - clip_image_guidance_scale: float, - vgg16_model: torchvision.models.VGG, - vgg16_guidance_scale: float, - vgg16_layer_no: int, - # safety_checker: StableDiffusionSafetyChecker, - # feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + clip_model: CLIPModel, + clip_guidance_scale: float, + clip_image_guidance_scale: float, + vgg16_model: torchvision.models.VGG, + vgg16_guidance_scale: float, + vgg16_layer_no: int, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) - self.vae = vae - self.text_encoder = text_encoder - self.tokenizer = tokenizer - self.unet = unet - self.scheduler = scheduler - self.safety_checker = None + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.unet = unet + self.scheduler = scheduler + self.safety_checker = None + + # Textual Inversion + self.token_replacements = {} + + # CLIP guidance + self.clip_guidance_scale = clip_guidance_scale + self.clip_image_guidance_scale = clip_image_guidance_scale + self.clip_model = clip_model + self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) + self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) + + # VGG16 guidance + self.vgg16_guidance_scale = vgg16_guidance_scale + if self.vgg16_guidance_scale > 0.0: + return_layers = {f"{vgg16_layer_no}": "feat"} + self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter( + vgg16_model.features, return_layers=return_layers + ) + self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] # Textual Inversion - self.token_replacements = {} + def add_token_replacement(self, target_token_id, rep_token_ids): + self.token_replacements[target_token_id] = rep_token_ids - # CLIP guidance - self.clip_guidance_scale = clip_guidance_scale - self.clip_image_guidance_scale = clip_image_guidance_scale - self.clip_model = clip_model - self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) - self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) + def replace_token(self, tokens): + new_tokens = [] + for token in tokens: + if token in self.token_replacements: + new_tokens.extend(self.token_replacements[token]) + else: + new_tokens.append(token) + return new_tokens - # VGG16 guidance - self.vgg16_guidance_scale = vgg16_guidance_scale - if self.vgg16_guidance_scale > 0.0: - return_layers = {f'{vgg16_layer_no}': 'feat'} - self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) - self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets - # ControlNet - self.control_nets: List[ControlNetInfo] = [] + # region xformersとか使う部分:独自に書き換えるので関係なし - # Textual Inversion - def add_token_replacement(self, target_token_id, rep_token_ids): - self.token_replacements[target_token_id] = rep_token_ids + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) - def replace_token(self, tokens): - new_tokens = [] - for token in tokens: - if token in self.token_replacements: - new_tokens.extend(self.token_replacements[token]) - else: - new_tokens.append(token) - return new_tokens + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) - # region xformersとか使う部分:独自に書き換えるので関係なし + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) - def enable_xformers_memory_efficient_attention(self): - r""" - Enable memory efficient attention as implemented in xformers. - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - """ - self.unet.set_use_memory_efficient_attention_xformers(True) + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + # accelerateが必要になるのでとりあえず省略 + raise NotImplementedError("cpu_offload is omitted.") + # if is_accelerate_available(): + # from accelerate import cpu_offload + # else: + # raise ImportError("Please install accelerate via `pip install accelerate`") - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.unet.set_use_memory_efficient_attention_xformers(False) + # device = self.device - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) + # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + # if cpu_offloaded_model is not None: + # cpu_offload(cpu_offloaded_model, device) - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) + # endregion - def enable_sequential_cpu_offload(self): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - # accelerateが必要になるのでとりあえず省略 - raise NotImplementedError("cpu_offload is omitted.") - # if is_accelerate_available(): - # from accelerate import cpu_offload - # else: - # raise ImportError("Please install accelerate via `pip install accelerate`") - - # device = self.device - - # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: - # if cpu_offloaded_model is not None: - # cpu_offload(cpu_offloaded_model, device) -# endregion - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_prompts=None, - clip_guide_images=None, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - `None` if cancelled by `is_cancelled_callback`, - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - num_images_per_prompt = 1 # fixed - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - vae_batch_size = batch_size if vae_batch_size is None else ( - int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_prompts=None, + clip_guide_images=None, + **kwargs, ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + num_images_per_prompt = 1 # fixed - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""]*batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - # CLIP guidanceで使用するembeddingsを取得する - if self.clip_guidance_scale > 0: - clip_text_input = prompt_tokens - if clip_text_input.shape[1] > self.tokenizer.model_max_length: - # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) - clip_text_input = torch.cat([clip_text_input[:, :self.tokenizer.model_max_length-1], - clip_text_input[:, -1].unsqueeze(1)], dim=1) - print("trimmed", clip_text_input.shape) - - for i, clip_prompt in enumerate(clip_prompts): - if clip_prompt is not None: # clip_promptがあれば上書きする - clip_text_input[i] = self.tokenizer(clip_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, - truncation=True, return_tensors="pt",).input_ids.to(self.device) - - text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) - text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK - - if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets: - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - - if self.clip_image_guidance_scale > 0: - clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images, dim=0) - - clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) - image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - if len(image_embeddings_clip) == 1: - image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) - elif self.vgg16_guidance_scale > 0: - size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?) - clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images, dim=0) - - clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) - image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat'] - if len(image_embeddings_vgg16) == 1: - image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) - else: - # ControlNetのhintにguide imageを流用する - # 前処理はControlNet側で行う - pass - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8,) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype,).to(self.device) + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype,) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - timesteps = self.scheduler.timesteps.to(self.device) + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - init_latents = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size] - if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) - init_latents = 0.18215 * init_latents + # get prompt text embeddings - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - - for i, t in enumerate(tqdm(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - if self.control_nets: - noise_pred = original_control_net.call_unet_and_control_net( - i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample - else: - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond - noise_pred = noise_pred_uncond + guidance_scale * \ - (noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond) - - # perform clip guidance - if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: - text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[ - 1] if do_classifier_free_guidance else text_embeddings) + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: - noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, - text_embeddings_clip, self.clip_guidance_scale, NUM_CUTOUTS, USE_CUTOUTS,) - if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: - noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, - image_embeddings_clip, self.clip_image_guidance_scale, NUM_CUTOUTS, USE_CUTOUTS,) - if self.vgg16_guidance_scale > 0 and clip_guide_images is not None: - noise_pred, latents = self.cond_fn_vgg16(latents, t, i, text_embeddings_for_guidance, noise_pred, - image_embeddings_vgg16, self.vgg16_guidance_scale) + clip_text_input = prompt_tokens + if clip_text_input.shape[1] > self.tokenizer.model_max_length: + # TODO 75文字を超えたら警告を出す? + print("trim text input", clip_text_input.shape) + clip_text_input = torch.cat( + [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 + ) + print("trimmed", clip_text_input.shape) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + for i, clip_prompt in enumerate(clip_prompts): + if clip_prompt is not None: # clip_promptがあれば上書きする + clip_text_input[i] = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.to(self.device) - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None + if ( + self.clip_image_guidance_scale > 0 + or self.vgg16_guidance_scale > 0 + and clip_guide_images is not None + or self.control_nets + ): + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] - if return_latents: - return (latents, False) + if self.clip_image_guidance_scale > 0: + clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) - latents = 1 / 0.18215 * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents).sample - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample) - image = torch.cat(images) + clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + if len(image_embeddings_clip) == 1: + image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) + elif self.vgg16_guidance_scale > 0: + size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?) + clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) - image = (image / 2 + 0.5).clamp(0, 1) + clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)["feat"] + if len(image_embeddings_vgg16) == 1: + image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) + else: + # ControlNetのhintにguide imageを流用する + # 前処理はControlNet側で行う + pass - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, - clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), - ) - else: - has_nsfw_concept = None + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] + if init_image is None: + # get the initial random noise unless the user supplied it - # if not return_dict: - return (image, has_nsfw_concept) + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) - def text2img( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + init_latent_dist = self.vae.encode( + init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = 0.18215 * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.control_nets: + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + ).sample + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # perform clip guidance + if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(num_latent_input)[1] if do_classifier_free_guidance else text_embeddings + ) + + if self.clip_guidance_scale > 0: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + self.clip_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + image_embeddings_clip, + self.clip_image_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.vgg16_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn_vgg16( + latents, t, i, text_embeddings_for_guidance, noise_pred, image_embeddings_vgg16, self.vgg16_guidance_scale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return (latents, False) + + latents = 1 / 0.18215 * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode(latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + # if not return_dict: + return (image, has_nsfw_concept) + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, - ) + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) - def img2img( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for image-to-image generation. - Args: - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - init_image=init_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, + def img2img( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, - ) + ): + r""" + Function for image-to-image generation. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) - def inpaint( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for inpaint. - Args: - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - init_image=init_image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, + def inpaint( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, - ) + ): + r""" + Function for inpaint. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) - # CLIP guidance StableDiffusion - # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py + # CLIP guidance StableDiffusion + # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py - # バッチを分解して1件ずつ処理する - def cond_fn(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, - num_cutouts, use_cutouts=True, ): - if len(latents) == 1: - return self.cond_fn1(latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, - num_cutouts, use_cutouts) + # バッチを分解して1件ずつ処理する + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + if len(latents) == 1: + return self.cond_fn1( + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) - noise_pred = [] - cond_latents = [] - for i in range(len(latents)): - lat1 = latents[i].unsqueeze(0) - tem1 = text_embeddings[i].unsqueeze(0) - npo1 = noise_pred_original[i].unsqueeze(0) - gem1 = guide_embeddings_clip[i].unsqueeze(0) - npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) - noise_pred.append(npr1) - cond_latents.append(cla1) + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings_clip[i].unsqueeze(0) + npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) + noise_pred.append(npr1) + cond_latents.append(cla1) - noise_pred = torch.cat(noise_pred) - cond_latents = torch.cat(cond_latents) - return noise_pred, cond_latents + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents - @torch.enable_grad() - def cond_fn1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, - num_cutouts, use_cutouts=True, ): - latents = latents.detach().requires_grad_() + @torch.enable_grad() + def cond_fn1( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) - if use_cutouts: - image = self.make_cutouts(image, num_cutouts) - else: - image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) - image = self.normalize(image).to(latents.dtype) + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) + image = self.normalize(image).to(latents.dtype) - image_embeddings_clip = self.clip_model.get_image_features(image) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + image_embeddings_clip = self.clip_model.get_image_features(image) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - if use_cutouts: - dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) - dists = dists.view([num_cutouts, sample.shape[0], -1]) - loss = dists.sum(2).mean(0).sum() * clip_guidance_scale - else: - # バッチサイズが複数だと正しく動くかわからない - loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) + dists = dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + # バッチサイズが複数だと正しく動くかわからない + loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale - grads = -torch.autograd.grad(loss, latents)[0] + grads = -torch.autograd.grad(loss, latents)[0] - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents - # バッチを分解して一件ずつ処理する - def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): - if len(latents) == 1: - return self.cond_fn_vgg16_b1(latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale) + # バッチを分解して一件ずつ処理する + def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + if len(latents) == 1: + return self.cond_fn_vgg16_b1( + latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale + ) - noise_pred = [] - cond_latents = [] - for i in range(len(latents)): - lat1 = latents[i].unsqueeze(0) - tem1 = text_embeddings[i].unsqueeze(0) - npo1 = noise_pred_original[i].unsqueeze(0) - gem1 = guide_embeddings[i].unsqueeze(0) - npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) - noise_pred.append(npr1) - cond_latents.append(cla1) + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings[i].unsqueeze(0) + npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) + noise_pred.append(npr1) + cond_latents.append(cla1) - noise_pred = torch.cat(noise_pred) - cond_latents = torch.cat(cond_latents) - return noise_pred, cond_latents + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents - # 1件だけ処理する - @torch.enable_grad() - def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): - latents = latents.detach().requires_grad_() + # 1件だけ処理する + @torch.enable_grad() + def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + latents = latents.detach().requires_grad_() - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) - image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) - image = self.vgg16_normalize(image).to(latents.dtype) + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) + image = self.vgg16_normalize(image).to(latents.dtype) - image_embeddings = self.vgg16_feat_model(image)['feat'] + image_embeddings = self.vgg16_feat_model(image)["feat"] - # バッチサイズが複数だと正しく動くかわからない - loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + # バッチサイズが複数だと正しく動くかわからない + loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので - grads = -torch.autograd.grad(loss, latents)[0] - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents + grads = -torch.autograd.grad(loss, latents)[0] + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents class MakeCutouts(torch.nn.Module): - def __init__(self, cut_size, cut_power=1.0): - super().__init__() + def __init__(self, cut_size, cut_power=1.0): + super().__init__() - self.cut_size = cut_size - self.cut_power = cut_power + self.cut_size = cut_size + self.cut_power = cut_power - def forward(self, pixel_values, num_cutouts): - sideY, sideX = pixel_values.shape[2:4] - max_size = min(sideX, sideY) - min_size = min(sideX, sideY, self.cut_size) - cutouts = [] - for _ in range(num_cutouts): - size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) - offsetx = torch.randint(0, sideX - size + 1, ()) - offsety = torch.randint(0, sideY - size + 1, ()) - cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size] - cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) - return torch.cat(cutouts) + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) def spherical_dist_loss(x, y): - x = torch.nn.functional.normalize(x, dim=-1) - y = torch.nn.functional.normalize(y, dim=-1) - return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + x = torch.nn.functional.normalize(x, dim=-1) + y = torch.nn.functional.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) re_attention = re.compile( @@ -1477,151 +1590,151 @@ re_attention = re.compile( def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ - res = [] - round_brackets = [] - square_brackets = [] + res = [] + round_brackets = [] + square_brackets = [] - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) - if len(res) == 0: - res = [["", 1.0]] + if len(res) == 0: + res = [["", 1.0]] - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 - return res + return res def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] - token = pipe.replace_token(token) + token = pipe.replace_token(token) - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2): min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] - return tokens, weights + return tokens, weights def get_unweighted_text_embeddings( @@ -1633,56 +1746,56 @@ def get_unweighted_text_embeddings( pad: int, no_boseos_middle: Optional[bool] = True, ): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2): (i + 1) * (chunk_length - 2) + 2].clone() + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos - if clip_skip is None or clip_skip == 1: - text_embedding = pipe.text_encoder(text_input_chunk)[0] - else: - enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out['hidden_states'][-clip_skip] - text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) + if clip_skip is None or clip_skip == 1: + text_embedding = pipe.text_encoder(text_input_chunk)[0] + else: + enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = pipe.text_encoder(text_input)[0] + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) else: - enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out['hidden_states'][-clip_skip] - text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings + if clip_skip is None or clip_skip == 1: + text_embeddings = pipe.text_encoder(text_input)[0] + else: + enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings def get_weighted_text_embeddings( @@ -1696,84 +1809,69 @@ def get_weighted_text_embeddings( clip_skip=None, **kwargs, ): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - Args: - pipe (`DiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `1`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `1`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) - else: - prompt_tokens = [ - token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids - ] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] - for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + max_length = max(max_length, max([len(token) for token in uncond_tokens])) - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - pad = pipe.tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, max_length, bos, eos, @@ -1781,86 +1879,100 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( pipe, - uncond_tokens, + prompt_tokens, pipe.tokenizer.model_max_length, clip_skip, - eos, pad, + eos, + pad, no_boseos_middle=no_boseos_middle, ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - # →全体でいいんじゃないかな - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - return text_embeddings, uncond_embeddings, prompt_tokens - return text_embeddings, None, prompt_tokens + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings, prompt_tokens + return text_embeddings, None, prompt_tokens def preprocess_guide_image(image): - image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # cond_fnと合わせる - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) # nchw - image = torch.from_numpy(image) - return image # 0 to 1 + image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # cond_fnと合わせる + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 # VGG16の入力は任意サイズでよいので入力画像を適宜リサイズする def preprocess_vgg16_guide_image(image, size): - image = image.resize(size, resample=Image.NEAREST) # cond_fnと合わせる - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) # nchw - image = torch.from_numpy(image) - return image # 0 to 1 + image = image.resize(size, resample=Image.NEAREST) # cond_fnと合わせる + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask # endregion @@ -1873,924 +1985,1083 @@ def preprocess_mask(mask): class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] + # バッチ分割が必要なデータ + width: int + height: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt + return_latents: bool + base: BatchDataBase + ext: BatchDataExt def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - - # モデルを読み込む - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - use_stable_diffusion_format = os.path.isfile(args.ckpt) - if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) - else: - print("load Diffusers pretrained models") - loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) - text_encoder = loading_pipe.text_encoder - vae = loading_pipe.vae - unet = loading_pipe.unet - tokenizer = loading_pipe.tokenizer - del loading_pipe - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") - - # # 置換するCLIPを読み込む - # if args.replace_clip_l14_336: - # text_encoder = load_clip_l14_336(dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") - - if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") - clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) - else: - clip_model = None - - if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") - vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) - else: - vgg16_model = None - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - replace_unet_modules(unet, not args.xformers, args.xformers) - - # tokenizerを読み込む - print("loading tokenizer") - if use_stable_diffusion_format: - tokenizer = train_util.load_tokenizer(args) - - # schedulerを用意する - sched_init_args = {} - scheduler_num_noises_per_step = 1 - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - elif args.sampler == 'lms' or args.sampler == 'k_lms': - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - elif args.sampler == 'euler' or args.sampler == 'k_euler': - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - elif args.sampler == 'euler_a' or args.sampler == 'k_euler_a': - scheduler_cls = EulerAncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args['algorithm_type'] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - elif args.sampler == 'dpm_2' or args.sampler == 'k_dpm_2': - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - elif args.sampler == 'dpm_2_a' or args.sampler == 'k_dpm_2_a': - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - - if args.v_parameterization: - sched_init_args['prediction_type'] = 'v_prediction' - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == 'randn': - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args) - - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") - scheduler.config.clip_sample = True - - # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - - # custom pipelineをコピったやつを生成する - vae.to(dtype).to(device) - text_encoder.to(dtype).to(device) - unet.to(dtype).to(device) - if clip_model is not None: - clip_model.to(dtype).to(device) - if vgg16_model is not None: - vgg16_model.to(dtype).to(device) - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs) - else: - raise ValueError("No weight. Weight is required.") - if network is None: - return - - network.apply_to(text_encoder, unet) - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - networks.append(network) - else: - networks = [] - - # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) - - if args.opt_channels_last: - print(f"set optimizing: channels last") - text_encoder.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if clip_model is not None: - clip_model.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - if vgg16_model is not None: - vgg16_model.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.unet.to(memory_format=torch.channels_last) - cn.net.to(memory_format=torch.channels_last) - - pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip, - clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale, - vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer) - pipe.set_control_nets(control_nets) - print("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - embeds = next(iter(data.values())) - if type(embeds) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") - - num_vectors_per_token = embeds.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens = tokenizer.add_tokens(token_strings) - assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(token_ids[0], token_ids) - - token_ids_embeds.append((token_ids, embeds)) - - text_encoder.resize_token_embeddings(len(tokenizer)) - token_embeds = text_encoder.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds: - for token_id, embed in zip(token_ids, embeds): - token_embeds[token_id] = embed - - # promptを取得する - if args.from_file is not None: - print(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 else: - paths = glob.glob(os.path.join(path, "*.png")) + glob.glob(os.path.join(path, "*.jpg")) + \ - glob.glob(os.path.join(path, "*.jpeg")) + glob.glob(os.path.join(path, "*.webp")) - paths.sort() + dtype = torch.float32 - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) + highres_fix = args.highres_fix_scale is not None + assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - return images + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, 'filename'): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized + # モデルを読み込む + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] - if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") - else: - init_images = None + use_stable_diffusion_format = os.path.isfile(args.ckpt) + if use_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + print("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe - if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") - for img in init_images: - if 'prompt' in img.text: - prompt = img.text['prompt'] - if 'negative-prompt' in img.text: - prompt += " --n " + img.text['negative-prompt'] - prompt_list.append(prompt) + # # 置換するCLIPを読み込む + # if args.replace_clip_l14_336: + # text_encoder = load_clip_l14_336(dtype) + # print(f"large clip {CLIP_ID_L14_336} is loaded") - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l + if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: + print("prepare clip model") + clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) + else: + clip_model = None - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l + if args.vgg16_guidance_scale > 0.0: + print("prepare resnet model") + vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) + else: + vgg16_model = None - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - if init_images is not None: - print(f"resize img2img source images to {args.W}*{args.H}") - init_images = resize_images(init_images, (args.W, args.H)) - if mask_images is not None: - print(f"resize img2img mask images to {args.W}*{args.H}") - mask_images = resize_images(mask_images, (args.W, args.H)) + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + replace_unet_modules(unet, not args.xformers, args.xformers) - if networks and mask_images: - # mask を領域情報として流用する、現在は1枚だけ対応 - # TODO 複数のnetwork classの混在時の考慮 - print("use mask as region") - # import cv2 - # for i in range(3): - # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) - # cv2.waitKey() - # cv2.destroyAllWindows() - networks[0].__class__.set_regions(networks, np.array(mask_images[0])) - mask_images = None + # tokenizerを読み込む + print("loading tokenizer") + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) + # schedulerを用意する + sched_init_args = {} + scheduler_num_noises_per_step = 1 + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 - print(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") - guide_images = None - else: - guide_images = None + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" - # seed指定時はseedを決めておく - if args.seed is not None: - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7fffffff) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None + # samplerの乱数をあらかじめ指定するための処理 - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 512 - if args.H is None: - args.H = 512 + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises - for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7fffffff) + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - print("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - width_1st = int(ext.width * args.highres_fix_scale + .5) - height_1st = int(ext.height * args.highres_fix_scale + .5) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 + self.sampler_noise_index += 1 + return noise - ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale, - ext.negative_scale, ext.strength, ext.network_muls) - batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) - images_1st = process_batch(batch_1st, True, True) + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager - # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") - if args.highres_fix_latents_upscaling: - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True) - images_1st = images_1st.to(org_dtype) + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - if not args.highres_fix_latents_upscaling: - image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) - # このバッチの情報を取り出す - return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \ - (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) - prompts = [] - negative_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step)] - seeds = [] - clip_prompts = [] + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + print("set clip_sample to True") + scheduler.config.clip_sample = True - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - if mask_image is not None: - mask_images = [] + # custom pipelineをコピったやつを生成する + vae.to(dtype).to(device) + text_encoder.to(dtype).to(device) + unet.to(dtype).to(device) + if clip_model is not None: + clip_model.to(dtype).to(device) + if vgg16_model is not None: + vgg16_model.to(dtype).to(device) + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) + else: + networks = [] + + # ControlNetの処理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if clip_model is not None: + clip_model.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + if vgg16_model is not None: + vgg16_model.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.unet.to(memory_format=torch.channels_last) + cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + args.clip_skip, + clip_model, + args.clip_guidance_scale, + args.clip_image_guidance_scale, + vgg16_model, + args.vgg16_guidance_scale, + args.vgg16_guidance_layer, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + embeds = next(iter(data.values())) + if type(embeds) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") + + num_vectors_per_token = embeds.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + assert ( + min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 + ), f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_ids_embeds.append((token_ids, embeds)) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - - images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, - output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, return_latents=return_latents, - clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] - if highres_1st and not args.highres_fix_save_1st: # return images or latents return images - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(zip(images, prompts, negative_prompts, seeds, clip_prompts)): - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + else: + init_images = None - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' meta data") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - print("\nType prompt:") - try: - prompt = input() - except EOFError: - break - - valid = len(prompt.strip().split(' --')[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - prompt = prompt_list[prompt_index] - - # parse prompt - width = args.W - height = args.H - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - prompt_args = prompt.strip().split(' --') - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r'w (\d+)', parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue - - m = re.match(r'h (\d+)', parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue - - m = re.match(r's (\d+)', parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue - - m = re.match(r'd ([\d,]+)', parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(',')] - print(f"seeds: {seeds}") - continue - - m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue - - m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == 'none': - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue - - m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - print(f"strength: {strength}") - continue - - m = re.match(r'n (.+)', parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r'c (.+)', parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - if seeds is not None: - # 数が足りないなら繰り返す - if len(seeds) < args.images_per_prompt: - seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) - seeds = seeds[:args.images_per_prompt] - else: - if predefined_seeds is not None: - seeds = predefined_seeds[-args.images_per_prompt:] - predefined_seeds = predefined_seeds[:-args.images_per_prompt] - elif args.iter_same_seed: - seeds = [iter_seed] * args.images_per_prompt - else: - seeds = [random.randint(0, 0x7fffffff) for _ in range(args.images_per_prompt)] - if args.interactive: - print(f"seed: {seeds}") - - init_image = mask_image = guide_image = None - for seed in seeds: # images_per_promptの数だけ - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # 32単位に丸めたやつにresizeされるので踏襲する - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - print(f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます") + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c:p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: - if prev_image is None: - print("Generate 1st image without guide image.") - else: - print("Use previous image as guide image.") - guide_image = prev_image + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + if init_images is not None: + print(f"resize img2img source images to {args.W}*{args.H}") + init_images = resize_images(init_images, (args.W, args.H)) + if mask_images is not None: + print(f"resize img2img mask images to {args.W}*{args.H}") + mask_images = resize_images(mask_images, (args.W, args.H)) - b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), - BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None)) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - process_batch(batch_data, highres_fix) - batch_data.clear() + if networks and mask_images: + # mask を領域情報として流用する、現在は1枚だけ対応 + # TODO 複数のnetwork classの混在時の考慮 + print("use mask as region") + # import cv2 + # for i in range(3): + # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) + # cv2.waitKey() + # cv2.destroyAllWindows() + networks[0].__class__.set_regions(networks, np.array(mask_images[0])) + mask_images = None - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) - global_step += 1 + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + guide_images = None + else: + guide_images = None - prompt_index += 1 + # seed指定時はseedを決めておく + if args.seed is not None: + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 512 + if args.H is None: + args.H = 512 - print("done!") + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + width_1st = int(ext.width * args.highres_fix_scale + 0.5) + height_1st = int(ext.height * args.highres_fix_scale + 0.5) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + ext_1st = BatchDataExt( + width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls + ) + batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + if args.highres_fix_latents_upscaling: + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + if not args.highres_fix_latents_upscaling: + image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + (width, height, steps, scale, negative_scale, strength, network_muls), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + )[0] + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + prompt = input() + except EOFError: + break + + valid = len(prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + prompt = prompt_list[prompt_index] + + # parse prompt + width = args.W + height = args.H + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + prompt_args = prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + if seeds is not None: + # 数が足りないなら繰り返す + if len(seeds) < args.images_per_prompt: + seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) + seeds = seeds[: args.images_per_prompt] + else: + if predefined_seeds is not None: + seeds = predefined_seeds[-args.images_per_prompt :] + predefined_seeds = predefined_seeds[: -args.images_per_prompt] + elif args.iter_same_seed: + seeds = [iter_seed] * args.images_per_prompt + else: + seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] + if args.interactive: + print(f"seed: {seeds}") + + init_image = mask_image = guide_image = None + for seed in seeds: # images_per_promptの数だけ + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # 32単位に丸めたやつにresizeされるので踏襲する + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: + if prev_image is None: + print("Generate 1st image without guide image.") + else: + print("Use previous image as guide image.") + guide_image = prev_image + + b1 = BatchData( + False, + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataExt( + width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument("--from_file", type=str, default=None, - help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む") - parser.add_argument("--interactive", action='store_true', help='interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)') - parser.add_argument("--no_preview", action='store_true', help='do not show generated image in interactive mode / 対話モードで画像を表示しない') - parser.add_argument("--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像") - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action='store_true', help="sequential output file name / 生成画像のファイル名を連番にする") - parser.add_argument("--use_original_file_name", action='store_true', - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける") - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument("--vae_batch_size", type=float, default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率") - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument('--sampler', type=str, default='ddim', - choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', - 'dpmsolver++', 'dpmsingle', - 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], - help=f'sampler (scheduler) type / サンプラー(スケジューラ)の種類') - parser.add_argument("--scale", type=float, default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale") - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--tokenizer_cache_dir", type=str, default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)") - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument("--seed", type=int, default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed") - parser.add_argument("--iter_same_seed", action='store_true', - help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)') - parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する') - parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') - parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') - parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') - parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannels lastを指定し最適化する') - parser.add_argument("--network_module", type=str, default=None, nargs='*', - help='additional network module to use / 追加ネットワークを使う時そのモジュール名') - parser.add_argument("--network_weights", type=str, default=None, nargs='*', - help='additional network weights to load / 追加ネットワークの重み') - parser.add_argument("--network_mul", type=float, default=None, nargs='*', - help='additional network multiplier / 追加ネットワークの効果の倍率') - parser.add_argument("--network_args", type=str, default=None, nargs='*', - help='additional argmuments for network (key=value) / ネットワークへの追加の引数') - parser.add_argument("--network_show_meta", action='store_true', - help='show metadata of network model / ネットワークモデルのメタデータを表示する') - parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*', - help='Embeddings files of Textual Inversion / Textual Inversionのembeddings') - parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') - parser.add_argument("--max_embeddings_multiples", type=int, default=None, - help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') - parser.add_argument("--clip_guidance_scale", type=float, default=0.0, - help='enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしてこのscaleを適用する(サンプラーはDDIM、PNDM、LMSのみ)') - parser.add_argument("--clip_image_guidance_scale", type=float, default=0.0, - help='enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしてこのscaleを適用する') - parser.add_argument("--vgg16_guidance_scale", type=float, default=0.0, - help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する') - parser.add_argument("--vgg16_guidance_layer", type=int, default=20, - help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)') - parser.add_argument("--guide_image_path", type=str, default=None, nargs="*", - help="image to CLIP guidance / CLIP guided SDでガイドに使う画像") - parser.add_argument("--highres_fix_scale", type=float, default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする") - parser.add_argument("--highres_fix_steps", type=int, default=28, - help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") - parser.add_argument("--highres_fix_save_1st", action='store_true', - help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") - parser.add_argument("--highres_fix_latents_upscaling", action='store_true', - help="use latents upscaling for highres fix / highres fixでlatentで拡大する") - parser.add_argument("--negative_scale", type=float, default=None, - help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する") + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--clip_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしてこのscaleを適用する(サンプラーはDDIM、PNDM、LMSのみ)", + ) + parser.add_argument( + "--clip_image_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしてこのscaleを適用する", + ) + parser.add_argument( + "--vgg16_guidance_scale", + type=float, + default=0.0, + help="enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する", + ) + parser.add_argument( + "--vgg16_guidance_layer", + type=int, + default=20, + help="layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + ) + parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + ) - parser.add_argument("--control_net_models", type=str, default=None, nargs='*', - help='ControlNet models to use / 使用するControlNetのモデル名') - parser.add_argument("--control_net_preps", type=str, default=None, nargs='*', - help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名') - parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み') - parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*', - help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率') + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + ) + parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - main(args) + args = parser.parse_args() + main(args) From e203270e315cfa4b6d786b77fe8dd86b8245aa54 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 24 Mar 2023 20:46:42 +0900 Subject: [PATCH 11/28] support TI embeds trained by WebUI(?) --- gen_img_diffusers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 94ec8179..690d111e 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2300,7 +2300,10 @@ def main(args): else: data = torch.load(embeds_file, map_location="cpu") + if "string_to_param" in data: + data = data["string_to_param"] embeds = next(iter(data.values())) + if type(embeds) != torch.Tensor: raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") From 5ec90990de870a4579721db947c2f74b9ce3ed69 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 01:41:24 +0900 Subject: [PATCH 12/28] =?UTF-8?q?=E3=83=87=E3=83=BC=E3=82=BF=E3=82=BB?= =?UTF-8?q?=E3=83=83=E3=83=88=E3=81=ABepoch=E3=80=81step=E3=81=8C=E9=80=9A?= =?UTF-8?q?=E9=81=94=E3=81=95=E3=82=8C=E3=81=AA=E3=81=84=E3=83=90=E3=82=B0?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_network.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 02a2d925..ef10921f 100644 --- a/train_network.py +++ b/train_network.py @@ -8,6 +8,7 @@ import random import time import json import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -24,9 +25,18 @@ from library.config_util import ( BlueprintGenerator, ) - -def collate_fn(examples): - return examples[0] +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + # print("self.current_step:%d"%self.current_step) + # print("dataset_lengh:%d"%len(dataset)) + print("id(self)(collate):%d"%id(self)) + return examples[0] # TODO 他のスクリプトと共通化する @@ -101,6 +111,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -186,11 +200,12 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -498,17 +513,18 @@ def train(args): loss_list = [] loss_total = 0.0 + del train_dataset_group for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 metadata["ss_epoch"] = str(epoch + 1) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 292cdb8379f00f87e9f1391a0ff2508a3540bd13 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 01:44:25 +0900 Subject: [PATCH 13/28] =?UTF-8?q?=E3=83=87=E3=83=BC=E3=82=BF=E3=82=BB?= =?UTF-8?q?=E3=83=83=E3=83=88=E3=81=ABepoch=E3=80=81step=E3=81=8C=E9=80=9A?= =?UTF-8?q?=E9=81=94=E3=81=95=E3=82=8C=E3=81=AA=E3=81=84=E3=83=90=E3=82=B0?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/config_util.py | 11 ++++++----- library/train_util.py | 1 + train_network.py | 23 +++++++++++++++++------ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 9c8c90c2..6817d9a3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -499,11 +499,12 @@ def load_user_config(file: str) -> dict: def blueprint_args_conflict(args,blueprint:Blueprint): # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする - for b in blueprint.dataset_group.datasets: - for t in b.subsets: - if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) - args.persistent_data_loader_workers = False + # for b in blueprint.dataset_group.datasets: + # for t in b.subsets: + # if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): + # print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) + # # args.persistent_data_loader_workers = False + return # for config test if __name__ == "__main__": diff --git a/library/train_util.py b/library/train_util.py index d1df9c58..223e403b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -517,6 +517,7 @@ class BaseDataset(torch.utils.data.Dataset): else: caption = caption.replace(str_from, str_to) + print(self.current_step, self.max_train_steps, caption) return caption def get_input_ids(self, caption): diff --git a/train_network.py b/train_network.py index 02a2d925..e148a92a 100644 --- a/train_network.py +++ b/train_network.py @@ -8,6 +8,7 @@ import random import time import json import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -24,9 +25,15 @@ from library.config_util import ( BlueprintGenerator, ) - -def collate_fn(examples): - return examples[0] +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] # TODO 他のスクリプトと共通化する @@ -101,6 +108,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -190,7 +201,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -501,14 +512,14 @@ def train(args): for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 metadata["ss_epoch"] = str(epoch + 1) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 4dc1124f9339af9886f4a73db49e9e7c9cf17a23 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 02:19:55 +0900 Subject: [PATCH 14/28] =?UTF-8?q?lora=E4=BB=A5=E5=A4=96=E3=82=82=E5=AF=BE?= =?UTF-8?q?=E5=BF=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 15 ++++++++------- library/train_util.py | 11 +++++++++++ train_db.py | 15 ++++++++------- train_network.py | 13 +------------ train_textual_inversion.py | 16 +++++++++------- 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index ff580435..96ec3d96 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -6,6 +6,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -21,10 +22,6 @@ from library.config_util import ( ) -def collate_fn(examples): - return examples[0] - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -65,6 +62,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -188,7 +189,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -259,14 +260,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 for m in training_models: m.train() loss_total = 0 for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/library/train_util.py b/library/train_util.py index 223e403b..994201fc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2987,3 +2987,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion + +# colalte_fn用 epoch,stepはmultiprocessing.Value +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] \ No newline at end of file diff --git a/train_db.py b/train_db.py index 87fe771b..50d50345 100644 --- a/train_db.py +++ b/train_db.py @@ -8,6 +8,7 @@ import itertools import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -23,10 +24,6 @@ from library.config_util import ( ) -def collate_fn(examples): - return examples[0] - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) @@ -60,6 +57,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + if args.no_token_padding: train_dataset_group.disable_token_padding() @@ -153,7 +154,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -233,8 +234,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -243,6 +243,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") diff --git a/train_network.py b/train_network.py index dd1fb748..79d118d0 100644 --- a/train_network.py +++ b/train_network.py @@ -25,17 +25,6 @@ from library.config_util import ( BlueprintGenerator, ) -class collater_class: - def __init__(self,epoch,step): - self.current_epoch=epoch - self.current_step=step - def __call__(self, examples): - dataset = torch.utils.data.get_worker_info().dataset - dataset.set_current_epoch(self.current_epoch.value) - dataset.set_current_step(self.current_step.value) - return examples[0] - - # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -110,7 +99,7 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = collater_class(current_epoch,current_step) + collater = train_util.collater_class(current_epoch,current_step) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 63b63426..4f2e2724 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -4,6 +4,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -71,10 +72,6 @@ imagenet_style_templates_small = [ ] -def collate_fn(examples): - return examples[0] - - def train(args): if args.output_name is None: args.output_name = args.token_string @@ -186,6 +183,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: print("use template for training captions. is object: {args.use_object_template}") @@ -251,7 +252,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -335,13 +336,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 text_encoder.train() loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 5a3d564a3028057b1d7671b4a570bd37f13aa8d6 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 02:26:08 +0900 Subject: [PATCH 15/28] =?UTF-8?q?print=E5=89=8A=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 994201fc..1ba38adc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -517,7 +517,6 @@ class BaseDataset(torch.utils.data.Dataset): else: caption = caption.replace(str_from, str_to) - print(self.current_step, self.max_train_steps, caption) return caption def get_input_ids(self, caption): From a4b34a9c3ce36c93ae12d5b4b98ea9f93e1606a3 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 03:26:55 +0900 Subject: [PATCH 16/28] =?UTF-8?q?blueprint=5Fargs=5Fconflict=E3=81=AF?= =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AA=E3=81=9F=E3=82=81=E5=89=8A=E9=99=A4?= =?UTF-8?q?=E3=80=81shuffle=E3=81=8C=E6=AF=8E=E5=9B=9E=E8=A1=8C=E3=82=8F?= =?UTF-8?q?=E3=82=8C=E3=82=8B=E4=B8=8D=E5=85=B7=E5=90=88=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 1 - library/config_util.py | 9 --------- library/train_util.py | 3 ++- train_db.py | 1 - train_network.py | 1 - train_textual_inversion.py | 1 - 6 files changed, 2 insertions(+), 14 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 96ec3d96..f387179a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/library/config_util.py b/library/config_util.py index 6817d9a3..b1543f63 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -497,15 +497,6 @@ def load_user_config(file: str) -> dict: return config -def blueprint_args_conflict(args,blueprint:Blueprint): - # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする - # for b in blueprint.dataset_group.datasets: - # for t in b.subsets: - # if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - # print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) - # # args.persistent_data_loader_workers = False - return - # for config test if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/library/train_util.py b/library/train_util.py index 1ba38adc..af9637de 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -437,8 +437,9 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: + self.shuffle_buckets() self.current_epoch = epoch - self.shuffle_buckets() def set_current_step(self, step): self.current_step = step diff --git a/train_db.py b/train_db.py index 50d50345..3a3d2df8 100644 --- a/train_db.py +++ b/train_db.py @@ -54,7 +54,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/train_network.py b/train_network.py index 79d118d0..a7dbd374 100644 --- a/train_network.py +++ b/train_network.py @@ -94,7 +94,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4f2e2724..149308b4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -180,7 +180,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) From 4c06bfad60be71f6f13f2d14413694c2e0be7813 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Sun, 26 Mar 2023 00:01:29 +0000 Subject: [PATCH 17/28] Fix for TypeError from bf16 precision: Thanks to mgz-dev --- library/custom_train_functions.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b080b40c..fd4f6156 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,21 +1,18 @@ import torch import argparse -import numpy as np - def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): - alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() - sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) - sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) alpha = sqrt_alphas_cumprod sigma = sqrt_one_minus_alphas_cumprod all_snr = (alpha / sigma) ** 2 - all_snr.to(loss.device) snr = torch.stack([all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper loss = loss * snr_weight return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") + parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") From c9b157b536678015e42fbaa11aee0b0cdefc97e5 Mon Sep 17 00:00:00 2001 From: mgz-dev <49577754+mgz-dev@users.noreply.github.com> Date: Sat, 25 Mar 2023 19:56:14 -0500 Subject: [PATCH 18/28] update resize_lora.py (fix out of bounds and index) Fix error where index may go out of bounds when using certain dynamic parameters. Fix index and rank issue (previously some parts of code was incorrectly using python index position rather than rank, which is -1 dim). --- networks/resize_lora.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 2bd86599..7b740634 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -11,6 +11,8 @@ import numpy as np MIN_SV = 1e-6 +# Model save and load functions + def load_state_dict(file_name, dtype): if model_util.is_safetensors(file_name): sd = load_file(file_name) @@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) +# Indexing functions + def index_sv_cumulative(S, target): original_sum = float(torch.sum(S)) cumulative_sums = torch.cumsum(S, dim=0)/original_sum index = int(torch.searchsorted(cumulative_sums, target)) + 1 - if index >= len(S): - index = len(S) - 1 + index = max(1, min(index, len(S)-1)) return index @@ -54,8 +57,16 @@ def index_sv_fro(S, target): s_fro_sq = float(torch.sum(S_squared)) sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - if index >= len(S): - index = len(S) - 1 + index = max(1, min(index, len(S)-1)) + + return index + + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv/target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S)-1)) return index @@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device): return weight +# Calculate new rank + def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict = {} if dynamic_method=="sv_ratio": # Calculate new dim and alpha based off ratio - max_sv = S[0] - min_sv = max_sv/dynamic_param - new_rank = max(torch.sum(S > min_sv).item(),1) + new_rank = index_sv_ratio(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) elif dynamic_method=="sv_cumulative": # Calculate new dim and alpha based off cumulative sum - new_rank = index_sv_cumulative(S, dynamic_param) - new_rank = max(new_rank, 1) + new_rank = index_sv_cumulative(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) elif dynamic_method=="sv_fro": # Calculate new dim and alpha based off sqrt sum of squares - new_rank = index_sv_fro(S, dynamic_param) - new_rank = min(max(new_rank, 1), len(S)-1) + new_rank = index_sv_fro(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) else: new_rank = rank @@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict["new_alpha"] = new_alpha param_dict["sum_retained"] = (s_rank)/s_sum param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0]/S[new_rank] + param_dict["max_ratio"] = S[0]/S[new_rank - 1] return param_dict From 14891523ced05bd442f10996360be405c827aae8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Mar 2023 22:17:03 +0900 Subject: [PATCH 19/28] fix seed for each dataset to make shuffling same --- library/config_util.py | 4 +++ library/train_util.py | 55 +++++++++++++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index b1543f63..97bbb4a8 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -4,6 +4,7 @@ from dataclasses import ( dataclass, ) import functools +import random from textwrap import dedent, indent import json from pathlib import Path @@ -428,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu print(info) # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): print(f"[Dataset {i}]") dataset.make_buckets() + dataset.set_seed(seed) return DatasetGroup(datasets) diff --git a/library/train_util.py b/library/train_util.py index 6a5679d3..2d93b126 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -277,7 +277,7 @@ class BaseSubset: caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, token_warmup_min: int, - token_warmup_step: Union[float,int], + token_warmup_step: Union[float, int], ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -419,6 +419,7 @@ class BaseDataset(torch.utils.data.Dataset): self.current_step: int = 0 self.max_train_steps: int = 0 + self.seed: int = 0 # augmentation self.aug_helper = AugHelper() @@ -435,8 +436,11 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_seed(self, seed): + self.seed = seed + def set_current_epoch(self, epoch): - if not self.current_epoch == epoch: + if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする self.shuffle_buckets() self.current_epoch = epoch @@ -476,12 +480,15 @@ class BaseDataset(torch.utils.data.Dataset): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(",")] + print(subset.token_warmup_min, subset.token_warmup_step) if subset.token_warmup_step < 1: subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: - tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens_len = ( + math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + + subset.token_warmup_min + ) tokens = tokens[:tokens_len] def dropout_tags(tokens): @@ -667,6 +674,9 @@ class BaseDataset(torch.utils.data.Dataset): self._length = len(self.buckets_indices) def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() @@ -1073,7 +1083,7 @@ class DreamBoothDataset(BaseDataset): self.register_image(info, subset) n += info.num_repeats else: - info.num_repeats += 1 + info.num_repeats += 1 # rewrite registered info n += 1 if n >= num_train_images: break @@ -1134,6 +1144,8 @@ class FineTuningDataset(BaseDataset): # path情報を作る if os.path.exists(image_key): abs_path = image_key + elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" else: npz_path = os.path.join(subset.image_dir, image_key + ".npz") if os.path.exists(npz_path): @@ -1330,9 +1342,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") + print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します") + + epoch = 1 + steps = 1 + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) - train_dataset.set_current_epoch(1) k = 0 indices = list(range(len(train_dataset))) random.shuffle(indices) @@ -1358,6 +1374,15 @@ def debug_dataset(train_dataset, show_input_ids=False): cv2.destroyAllWindows() if k == 27: break + if k == ord("e"): + epoch += 1 + steps = len(train_dataset) * (epoch - 1) + train_dataset.set_current_epoch(epoch) + print(f"epoch: {epoch}") + + steps += 1 + train_dataset.set_current_step(steps) + if k == 27 or (example["images"] is None and i >= 8): break @@ -2001,7 +2026,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - + def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: @@ -2089,7 +2114,7 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) - + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに @@ -3025,13 +3050,15 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion -# colalte_fn用 epoch,stepはmultiprocessing.Value + +# collate_fn用 epoch,stepはmultiprocessing.Value class collater_class: - def __init__(self,epoch,step): - self.current_epoch=epoch - self.current_step=step + def __init__(self, epoch, step): + self.current_epoch = epoch + self.current_step = step + def __call__(self, examples): dataset = torch.utils.data.get_worker_info().dataset dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) - return examples[0] \ No newline at end of file + return examples[0] From 066b1bb57e58603ce21acb6c3c7aaddc19338153 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Mar 2023 20:47:11 +0900 Subject: [PATCH 20/28] fix do not mean in batch dim when min_snr_gamma --- fine_tune.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 8ae1bb29..0f42741b 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -21,7 +21,8 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight + def train(args): train_util.verify_training_args(args) @@ -62,9 +63,9 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - current_epoch = Value('i',0) - current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + collater = train_util.collater_class(current_epoch, current_step) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) @@ -196,7 +197,9 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 @@ -260,7 +263,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch+1 + current_epoch.value = epoch + 1 for m in training_models: m.train() @@ -308,10 +311,14 @@ def train(args): else: target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + # do not mean over batch dimension for snr weight + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = loss.mean() # mean over batch dimension + else: + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -406,7 +413,6 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") From 43a08b406164c1a902d896eeefe8d0c9c9aec41c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Mar 2023 20:47:27 +0900 Subject: [PATCH 21/28] add ja comment --- library/custom_train_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index fd4f6156..dde0bdd4 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -15,4 +15,4 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") + parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨") From 238f01bc9c02f70588dd522df31d0282be492ec4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Mar 2023 20:48:21 +0900 Subject: [PATCH 22/28] fix images are used twice, update debug dataset --- library/train_util.py | 92 ++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2d93b126..55b5101b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -481,8 +481,7 @@ class BaseDataset(torch.utils.data.Dataset): else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: tokens = [t.strip() for t in caption.strip().split(",")] - print(subset.token_warmup_min, subset.token_warmup_step) - if subset.token_warmup_step < 1: + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( @@ -1342,50 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します") + print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") epoch = 1 - steps = 1 - train_dataset.set_current_epoch(epoch) - train_dataset.set_current_step(steps) + while True: + print(f"epoch: {epoch}") - k = 0 - indices = list(range(len(train_dataset))) - random.shuffle(indices) - for i, idx in enumerate(indices): - example = train_dataset[idx] - if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate( - zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) - ): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') - if show_input_ids: - print(f"input ids: {iid}") - if example["images"] is not None: - im = example["images"][j] - print(f"image size: {im.size()}") - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - if os.name == "nt": # only windows - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == ord("e"): - epoch += 1 - steps = len(train_dataset) * (epoch - 1) - train_dataset.set_current_epoch(epoch) - print(f"epoch: {epoch}") - - steps += 1 - train_dataset.set_current_step(steps) + steps = (epoch - 1) * len(train_dataset) + 1 + indices = list(range(len(train_dataset))) + random.shuffle(indices) - if k == 27 or (example["images"] is None and i >= 8): + k = 0 + for i, idx in enumerate(indices): + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) + print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27 or k == ord("s") or k == ord("e"): + break + steps += 1 + + if k == ord("e"): + break + if k == 27 or (example["images"] is None and i >= 8): + k = 27 + break + if k == 27: break + epoch += 1 + def glob_images(directory, base="*"): img_paths = [] @@ -1394,8 +1398,8 @@ def glob_images(directory, base="*"): img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) else: img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - # img_paths = list(set(img_paths)) # 重複を排除 - # img_paths.sort() + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() return img_paths @@ -1407,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive): else: for ext in IMAGE_EXTENSIONS: image_paths += list(dir_path.glob("*" + ext)) - # image_paths = list(set(image_paths)) # 重複を排除 - # image_paths.sort() + image_paths = list(set(image_paths)) # 重複を排除 + image_paths.sort() return image_paths From 895b0b6ca73ad51efc6d1043ba60d3b5fd91a009 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Mar 2023 21:22:32 +0900 Subject: [PATCH 23/28] Fix saving issue if epoch/step not in checkpoint --- library/model_util.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index d1020c05..3d8e7539 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p key_count = len(state_dict.keys()) new_ckpt = {'state_dict': state_dict} - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] + # epoch and global_step are sometimes not int + try: + if 'epoch' in checkpoint: + epochs += checkpoint['epoch'] + if 'global_step' in checkpoint: + steps += checkpoint['global_step'] + except: + pass new_ckpt['epoch'] = epochs new_ckpt['global_step'] = steps From 5fa20b53481abf26e035b6fc3ec26f597944706d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Mar 2023 21:37:10 +0900 Subject: [PATCH 24/28] update readme --- README.md | 98 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index daf72c6e..cd1de7b9 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,39 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 27 Mar. 2023, 2023/3/27: + - Fix issues when `--persistent_data_loader_workers` is specified. + - The batch members of the bucket are not shuffled. + - `--caption_dropout_every_n_epochs` does not work. + - These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue. + - Fix an issue that images are loaded twice in Windows environment. + - Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work! + - Add `--min_snr_gamma` option to training scripts, 5 is recommended by paper. + + - Add tag warmup. Details are in [#322](https://github.com/kohya-ss/sd-scripts/pull/322). Thanks to u-haru! + - Add `token_warmup_min` and `token_warmup_step` to dataset settings. + - Gradually increase the number of tokens from `token_warmup_min` to `token_warmup_step`. + - For example, if `token_warmup_min` is `3` and `token_warmup_step` is `10`, the first step will use the first 3 tokens, and the 10th step will use all tokens. + - Fix a bug in `resize_lora.py`. Thanks to mgz-dev! [#328](https://github.com/kohya-ss/sd-scripts/pull/328) + - Add `--debug_dataset` option to step to the next step with `S` key and to the next epoch with `E` key. + - Fix other bugs. + + - `--persistent_data_loader_workers` を指定した時の各種不具合を修正しました。 + - `--caption_dropout_every_n_epochs` が効かない。 + - バケットのバッチメンバーがシャッフルされない。 + - エポックの遷移が正しく認識されないために発生していました。ご指摘いただいたu-haru氏に感謝します。 + - Windows環境で画像が二重に読み込まれる不具合を修正しました。 + - Min-SNR Weighting strategyを追加しました。 詳細は [#308](https://github.com/kohya-ss/sd-scripts/pull/308) をご参照ください。AI-Casanova氏の素晴らしい貢献に感謝します。 + - `--min_snr_gamma` オプションを学習スクリプトに追加しました。論文では5が推奨されています。 + - タグのウォームアップを追加しました。詳細は [#322](https://github.com/kohya-ss/sd-scripts/pull/322) をご参照ください。u-haru氏に感謝します。 + - データセット設定に `token_warmup_min` と `token_warmup_step` を追加しました。 + - `token_warmup_min` で指定した数のトークン(カンマ区切りの文字列)から、`token_warmup_step` で指定したステップまで、段階的にトークンを増やしていきます。 + - たとえば `token_warmup_min`に `3` を、`token_warmup_step` に `10` を指定すると、最初のステップでは最初から3個のトークンが使われ、10ステップ目では全てのトークンが使われます。 + - `resize_lora.py` の不具合を修正しました。mgz-dev氏に感謝します。[#328](https://github.com/kohya-ss/sd-scripts/pull/328) + - `--debug_dataset` オプションで、`S`キーで次のステップへ、`E`キーで次のエポックへ進めるようにしました。 + - その他の不具合を修正しました。 + + - 21 Mar. 2023, 2023/3/21: - Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls. - Please start with`2` or `4` depending on the size of VRAM. @@ -143,50 +176,49 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Windows以外の環境向けに、画像ファイルの大文字の拡張子をサポートしました。 - `resize_lora.py` を dynamic rank (rankが各LoRAモジュールで異なる場合、`conv_dim` が `network_dim` と異なる場合も含む)の時に正しく動作しない不具合を修正しました。toshiaki氏に感謝します。 +## Sample image generation during traiing + A prompt file might look like this, for example - - Sample image generation: - A prompt file might look like this, for example +``` +# prompt 1 +masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 - ``` - # prompt 1 - masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 +# prompt 2 +masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 +``` - # prompt 2 - masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 - ``` + Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. + * `--n` Negative prompt up to the next option. + * `--w` Specifies the width of the generated image. + * `--h` Specifies the height of the generated image. + * `--d` Specifies the seed of the generated image. + * `--l` Specifies the CFG scale of the generated image. + * `--s` Specifies the number of steps in the generation. - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. + The prompt weighting such as `( )` and `[ ]` are working. - The prompt weighting such as `( )` and `[ ]` are working. +## サンプル画像生成 +プロンプトファイルは例えば以下のようになります。 - - サンプル画像生成: - プロンプトファイルは例えば以下のようになります。 +``` +# prompt 1 +masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 - ``` - # prompt 1 - masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 +# prompt 2 +masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 +``` - # prompt 2 - masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 - ``` + `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 - `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 + * `--n` Negative prompt up to the next option. + * `--w` Specifies the width of the generated image. + * `--h` Specifies the height of the generated image. + * `--d` Specifies the seed of the generated image. + * `--l` Specifies the CFG scale of the generated image. + * `--s` Specifies the number of steps in the generation. - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. - - `( )` や `[ ]` などの重みづけは動作しません。 + `( )` や `[ ]` などの重みづけも動作します。 Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 From 99eaf1fd65c2f51a9491a9a15a56587d6ea71f5b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Mar 2023 21:38:01 +0900 Subject: [PATCH 25/28] fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd1de7b9..2f83a66c 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Windows以外の環境向けに、画像ファイルの大文字の拡張子をサポートしました。 - `resize_lora.py` を dynamic rank (rankが各LoRAモジュールで異なる場合、`conv_dim` が `network_dim` と異なる場合も含む)の時に正しく動作しない不具合を修正しました。toshiaki氏に感謝します。 -## Sample image generation during traiing +## Sample image generation during training A prompt file might look like this, for example ``` From 0138a917d8dc8b238ad9c74581e6ee27489da035 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Tue, 28 Mar 2023 08:43:41 +0900 Subject: [PATCH 26/28] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 2f83a66c..613025be 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- __There is an issue with the training script crashing when `max_data_loader_n_workers` is 0. Please temporarily set it to a value greater than 0.__ +- __現在、 `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合があります。一時的に1以上の値を設定してください。__ - 27 Mar. 2023, 2023/3/27: - Fix issues when `--persistent_data_loader_workers` is specified. - The batch members of the bucket are not shuffled. From 4f70e5dca6c0d922b6009f2839a5f2341dc8659f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 28 Mar 2023 19:42:47 +0900 Subject: [PATCH 27/28] fix to work with num_workers=0 --- fine_tune.py | 3 ++- library/train_util.py | 12 ++++++++++-- train_db.py | 19 +++++++++++-------- train_network.py | 3 ++- train_textual_inversion.py | 3 ++- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 0f42741b..637a729a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -65,7 +65,8 @@ def train(args): current_epoch = Value("i", 0) current_step = Value("i", 0) - collater = train_util.collater_class(current_epoch, current_step) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/library/train_util.py b/library/train_util.py index 55b5101b..e1a8e922 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3057,12 +3057,20 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # collate_fn用 epoch,stepはmultiprocessing.Value class collater_class: - def __init__(self, epoch, step): + def __init__(self, epoch, step, dataset): self.current_epoch = epoch self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing def __call__(self, examples): - dataset = torch.utils.data.get_worker_info().dataset + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) return examples[0] diff --git a/train_db.py b/train_db.py index f441d5d6..b3eead94 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,8 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight + def train(args): train_util.verify_training_args(args) @@ -57,9 +58,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - current_epoch = Value('i',0) - current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) if args.no_token_padding: train_dataset_group.disable_token_padding() @@ -161,7 +163,9 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 @@ -234,7 +238,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch+1 + current_epoch.value = epoch + 1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -298,8 +302,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 20ad2c4d..423649ee 100644 --- a/train_network.py +++ b/train_network.py @@ -101,7 +101,8 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 681bc628..f279370a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -186,7 +186,8 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: From 472f516e7c3a0223e702b5e178cdae8e99e31f1b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 28 Mar 2023 19:44:43 +0900 Subject: [PATCH 28/28] update readme --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 613025be..c63b7de0 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History -- __There is an issue with the training script crashing when `max_data_loader_n_workers` is 0. Please temporarily set it to a value greater than 0.__ -- __現在、 `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合があります。一時的に1以上の値を設定してください。__ +- 28 Mar. 2023, 2023/3/28: + - Fix an issue that the training script crashes when `max_data_loader_n_workers` is 0. + - `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合を修正しました。 + - 27 Mar. 2023, 2023/3/27: - Fix issues when `--persistent_data_loader_workers` is specified. - The batch members of the bucket are not shuffled.