mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Added --sample_at_first to generate sample images before training
This commit is contained in:
@@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sample_every_n_epochs",
|
"--sample_every_n_epochs",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -4429,15 +4432,19 @@ def sample_images_common(
|
|||||||
"""
|
"""
|
||||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||||
"""
|
"""
|
||||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
if steps == 0:
|
||||||
return
|
if not args.sample_at_first:
|
||||||
if args.sample_every_n_epochs is not None:
|
|
||||||
# sample_every_n_steps は無視する
|
|
||||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||||
return
|
return
|
||||||
|
if args.sample_every_n_epochs is not None:
|
||||||
|
# sample_every_n_steps は無視する
|
||||||
|
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||||
|
return
|
||||||
|
|
||||||
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||||
if not os.path.isfile(args.sample_prompts):
|
if not os.path.isfile(args.sample_prompts):
|
||||||
|
|||||||
@@ -477,6 +477,19 @@ def train(args):
|
|||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
# For --sample_at_first
|
||||||
|
sdxl_train_util.sample_images(
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
epoch,
|
||||||
|
global_step,
|
||||||
|
accelerator.device,
|
||||||
|
vae,
|
||||||
|
[tokenizer1, tokenizer2],
|
||||||
|
[text_encoder1, text_encoder2],
|
||||||
|
unet,
|
||||||
|
)
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user