diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 2c40f1a0..63d5c138 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -76,6 +76,7 @@ from diffusers import ( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, + DPMSolverSDEScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, @@ -2323,7 +2324,18 @@ def main(args): scheduler_cls = KDPM2AncestralDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete scheduler_num_noises_per_step = 2 - + elif args.sampler == "dpmpp_sde": + scheduler_cls = DPMSolverSDEScheduler + sched_init_args["algorithm_type"] = "dpmsolver++" + sched_init_args["use_karras_sigmas"] = True + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_sde + elif args.sampler == "dpmpp_2m_sde": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = "dpmsolver++" + sched_init_args["use_karras_sigmas"] = True + sched_init_args["solver_order"] = 2 + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + if args.v_parameterization: sched_init_args["prediction_type"] = "v_prediction" @@ -3568,6 +3580,8 @@ def setup_parser() -> argparse.ArgumentParser: "k_euler_a", "k_dpm_2", "k_dpm_2_a", + "dpmpp_sde", + "dpmpp_2m_sde", ], help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", )