From fea810b437e0b4ea448e0ffa7d5933437bac6cae Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:44:57 +0900 Subject: [PATCH] Added --sample_at_first to generate sample images before training --- library/train_util.py | 19 +++++++++++++------ sdxl_train.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0f503341..926e956c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" ) + parser.add_argument( + "--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する" + ) parser.add_argument( "--sample_every_n_epochs", type=int, @@ -4429,15 +4432,19 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: + if steps == 0: + if not args.sample_at_first: return else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): diff --git a/sdxl_train.py b/sdxl_train.py index 47bc6a42..a25da42d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -477,6 +477,19 @@ def train(args): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + ) + for m in training_models: m.train()