add feature to sample images during sdxl training

This commit is contained in:
Kohya S
2023-07-02 16:42:19 +09:00
parent 227a62e4c4
commit 64cf922841
5 changed files with 1402 additions and 32 deletions

View File

@@ -3695,7 +3695,12 @@ SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
def sample_images(
def sample_images(*args, **kwargs):
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
def sample_images_common(
pipe_class,
accelerator,
args: argparse.Namespace,
epoch,
@@ -3790,7 +3795,7 @@ def sample_images(
# print("set clip_sample to True")
scheduler.config.clip_sample = True
pipeline = StableDiffusionLongPromptWeightingPipeline(
pipeline = pipe_class(
text_encoder=text_encoder,
vae=vae,
unet=unet,
@@ -3801,9 +3806,8 @@ def sample_images(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
pipeline.to(device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)