mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Make a function get_my_scheduler()
This commit is contained in:
@@ -4381,6 +4381,59 @@ SCHEDULER_LINEAR_END = 0.0120
|
|||||||
SCHEDULER_TIMESTEPS = 1000
|
SCHEDULER_TIMESTEPS = 1000
|
||||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||||
|
|
||||||
|
def get_my_scheduler(
|
||||||
|
*,
|
||||||
|
sample_sampler: str,
|
||||||
|
v_parameterization: bool,
|
||||||
|
):
|
||||||
|
sched_init_args = {}
|
||||||
|
if sample_sampler == "ddim":
|
||||||
|
scheduler_cls = DDIMScheduler
|
||||||
|
elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
||||||
|
scheduler_cls = DDPMScheduler
|
||||||
|
elif sample_sampler == "pndm":
|
||||||
|
scheduler_cls = PNDMScheduler
|
||||||
|
elif sample_sampler == "lms" or sample_sampler == "k_lms":
|
||||||
|
scheduler_cls = LMSDiscreteScheduler
|
||||||
|
elif sample_sampler == "euler" or sample_sampler == "k_euler":
|
||||||
|
scheduler_cls = EulerDiscreteScheduler
|
||||||
|
elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
|
||||||
|
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||||
|
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
|
||||||
|
scheduler_cls = DPMSolverMultistepScheduler
|
||||||
|
sched_init_args["algorithm_type"] = sample_sampler
|
||||||
|
elif sample_sampler == "dpmsingle":
|
||||||
|
scheduler_cls = DPMSolverSinglestepScheduler
|
||||||
|
elif sample_sampler == "heun":
|
||||||
|
scheduler_cls = HeunDiscreteScheduler
|
||||||
|
elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
|
||||||
|
scheduler_cls = KDPM2DiscreteScheduler
|
||||||
|
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
|
||||||
|
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||||
|
else:
|
||||||
|
scheduler_cls = DDIMScheduler
|
||||||
|
|
||||||
|
if v_parameterization:
|
||||||
|
sched_init_args["prediction_type"] = "v_prediction"
|
||||||
|
|
||||||
|
scheduler = scheduler_cls(
|
||||||
|
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||||
|
beta_start=SCHEDULER_LINEAR_START,
|
||||||
|
beta_end=SCHEDULER_LINEAR_END,
|
||||||
|
beta_schedule=SCHEDLER_SCHEDULE,
|
||||||
|
**sched_init_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# clip_sample=Trueにする
|
||||||
|
if (
|
||||||
|
hasattr(scheduler.config, "clip_sample")
|
||||||
|
and scheduler.config.clip_sample is False
|
||||||
|
):
|
||||||
|
# print("set clip_sample to True")
|
||||||
|
scheduler.config.clip_sample = True
|
||||||
|
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
def sample_images(*args, **kwargs):
|
def sample_images(*args, **kwargs):
|
||||||
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
@@ -4438,50 +4491,11 @@ def sample_images_common(
|
|||||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||||
prompts = json.load(f)
|
prompts = json.load(f)
|
||||||
|
|
||||||
# schedulerを用意する
|
scheduler = get_my_scheduler(
|
||||||
sched_init_args = {}
|
sample_sampler=args.scheduler,
|
||||||
if args.sample_sampler == "ddim":
|
v_parameterization=args.v_parameterization,
|
||||||
scheduler_cls = DDIMScheduler
|
|
||||||
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
|
||||||
scheduler_cls = DDPMScheduler
|
|
||||||
elif args.sample_sampler == "pndm":
|
|
||||||
scheduler_cls = PNDMScheduler
|
|
||||||
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
|
|
||||||
scheduler_cls = LMSDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
|
|
||||||
scheduler_cls = EulerDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
|
|
||||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
|
||||||
scheduler_cls = DPMSolverMultistepScheduler
|
|
||||||
sched_init_args["algorithm_type"] = args.sample_sampler
|
|
||||||
elif args.sample_sampler == "dpmsingle":
|
|
||||||
scheduler_cls = DPMSolverSinglestepScheduler
|
|
||||||
elif args.sample_sampler == "heun":
|
|
||||||
scheduler_cls = HeunDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
|
|
||||||
scheduler_cls = KDPM2DiscreteScheduler
|
|
||||||
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
|
|
||||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
|
||||||
else:
|
|
||||||
scheduler_cls = DDIMScheduler
|
|
||||||
|
|
||||||
if args.v_parameterization:
|
|
||||||
sched_init_args["prediction_type"] = "v_prediction"
|
|
||||||
|
|
||||||
scheduler = scheduler_cls(
|
|
||||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
|
||||||
beta_start=SCHEDULER_LINEAR_START,
|
|
||||||
beta_end=SCHEDULER_LINEAR_END,
|
|
||||||
beta_schedule=SCHEDLER_SCHEDULE,
|
|
||||||
**sched_init_args,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# clip_sample=Trueにする
|
|
||||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
|
||||||
# print("set clip_sample to True")
|
|
||||||
scheduler.config.clip_sample = True
|
|
||||||
|
|
||||||
pipeline = pipe_class(
|
pipeline = pipe_class(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
|
|||||||
Reference in New Issue
Block a user