Added --sample_at_first to generate sample images before training

This commit is contained in:
Yuta Hayashibe
2023-10-29 21:44:57 +09:00
parent 96d877be90
commit fea810b437
2 changed files with 26 additions and 6 deletions

View File

@@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument( parser.add_argument(
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
) )
parser.add_argument(
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
)
parser.add_argument( parser.add_argument(
"--sample_every_n_epochs", "--sample_every_n_epochs",
type=int, type=int,
@@ -4429,15 +4432,19 @@ def sample_images_common(
""" """
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
""" """
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: if steps == 0:
return if not args.sample_at_first:
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return return
else: else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts): if not os.path.isfile(args.sample_prompts):

View File

@@ -477,6 +477,19 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1 current_epoch.value = epoch + 1
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
for m in training_models: for m in training_models:
m.train() m.train()