mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Added --sample_at_first to generate sample images before training
This commit is contained in:
@@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_every_n_epochs",
|
||||
type=int,
|
||||
@@ -4429,6 +4432,10 @@ def sample_images_common(
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
"""
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
|
||||
@@ -477,6 +477,19 @@ def train(args):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
)
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user