mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add feature to sample images during sdxl training
This commit is contained in:
@@ -3695,7 +3695,12 @@ SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def sample_images(
|
||||
def sample_images(*args, **kwargs):
|
||||
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||
|
||||
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
@@ -3790,7 +3795,7 @@ def sample_images(
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
|
||||
pipeline = StableDiffusionLongPromptWeightingPipeline(
|
||||
pipeline = pipe_class(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
@@ -3801,9 +3806,8 @@ def sample_images(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||
pipeline.to(device)
|
||||
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user