diff --git a/gen_img.py b/gen_img.py index 4fe89871..29d0f328 100644 --- a/gen_img.py +++ b/gen_img.py @@ -31,6 +31,7 @@ from diffusers import ( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, + DPMSolverSDEScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, @@ -1628,7 +1629,18 @@ def main(args): scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete scheduler_num_noises_per_step = 2 has_clip_sample = False - + elif args.sampler == "k_dpm++_sde": + scheduler_cls = DPMSolverSDEScheduler + sched_init_args["algorithm_type"] = "dpmsolver++" + sched_init_args["use_karras_sigmas"] = True + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_sde + elif args.sampler == "k_dpm++_2m_sde": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = "dpmsolver++" + sched_init_args["use_karras_sigmas"] = True + sched_init_args["solver_order"] = 2 + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + if args.v_parameterization: sched_init_args["prediction_type"] = "v_prediction" @@ -3041,6 +3053,8 @@ def setup_parser() -> argparse.ArgumentParser: "k_euler_a", "k_dpm_2", "k_dpm_2_a", + "k_dpm++_sde", + "k_dpm++_2m_sde" ], help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", ) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 63d5c138..9d91d1b7 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2324,12 +2324,12 @@ def main(args): scheduler_cls = KDPM2AncestralDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete scheduler_num_noises_per_step = 2 - elif args.sampler == "dpmpp_sde": + elif args.sampler == "k_dpm++_sde": scheduler_cls = DPMSolverSDEScheduler sched_init_args["algorithm_type"] = "dpmsolver++" sched_init_args["use_karras_sigmas"] = True scheduler_module = diffusers.schedulers.scheduling_dpmsolver_sde - elif args.sampler == "dpmpp_2m_sde": + elif args.sampler == "k_dpm++_2m_sde": scheduler_cls = DPMSolverMultistepScheduler sched_init_args["algorithm_type"] = "dpmsolver++" sched_init_args["use_karras_sigmas"] = True @@ -3580,8 +3580,8 @@ def setup_parser() -> argparse.ArgumentParser: "k_euler_a", "k_dpm_2", "k_dpm_2_a", - "dpmpp_sde", - "dpmpp_2m_sde", + "k_dpm++_sde", + "k_dpm++_2m_sde" ], help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", )