diff --git a/library/train_util.py b/library/train_util.py index d8577b9d..3dc33f07 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6217,6 +6217,32 @@ def get_my_scheduler( elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": scheduler_cls = DPMSolverMultistepScheduler sched_init_args["algorithm_type"] = sample_sampler + elif sample_sampler == "dpmsolver++_2m": + scheduler_cls = DPMSolverMultistepScheduler + elif sample_sampler == "dpmsolver++_2m_lu": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["use_lu_lambdas"] = True + elif sample_sampler == "dpmsolver++_2m_k": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["use_karras_sigmas"] = True + elif sample_sampler == "dpmsolver++_2m_stable": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["euler_at_final"] = True + elif sample_sampler == "dpmsolver++_2m_sde": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = "sde-dpmsolver++" + elif sample_sampler == "dpmsolver++_2m_sde_k": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = "sde-dpmsolver++" + sched_init_args["use_karras_sigmas"] = True + elif sample_sampler == "dpmsolver++_2m_sde_lu": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = "sde-dpmsolver++" + sched_init_args["use_lu_lambdas"] = True + elif sample_sampler == "dpmsolver++_2m_sde_stable": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = "sde-dpmsolver++" + sched_init_args["euler_at_final"] = True elif sample_sampler == "dpmsingle": scheduler_cls = DPMSolverSinglestepScheduler elif sample_sampler == "heun":