Add new samplers

This commit is contained in:
shigabeev
2024-09-09 07:14:29 +00:00
parent cf67f84708
commit cb8e5e6a69
2 changed files with 19 additions and 5 deletions

View File

@@ -31,6 +31,7 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
DPMSolverSDEScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
@@ -1628,7 +1629,18 @@ def main(args):
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
scheduler_num_noises_per_step = 2
has_clip_sample = False
elif args.sampler == "k_dpm++_sde":
scheduler_cls = DPMSolverSDEScheduler
sched_init_args["algorithm_type"] = "dpmsolver++"
sched_init_args["use_karras_sigmas"] = True
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_sde
elif args.sampler == "k_dpm++_2m_sde":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = "dpmsolver++"
sched_init_args["use_karras_sigmas"] = True
sched_init_args["solver_order"] = 2
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
if args.v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"
@@ -3041,6 +3053,8 @@ def setup_parser() -> argparse.ArgumentParser:
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
"k_dpm++_sde",
"k_dpm++_2m_sde"
],
help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類",
)

View File

@@ -2324,12 +2324,12 @@ def main(args):
scheduler_cls = KDPM2AncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
scheduler_num_noises_per_step = 2
elif args.sampler == "dpmpp_sde":
elif args.sampler == "k_dpm++_sde":
scheduler_cls = DPMSolverSDEScheduler
sched_init_args["algorithm_type"] = "dpmsolver++"
sched_init_args["use_karras_sigmas"] = True
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_sde
elif args.sampler == "dpmpp_2m_sde":
elif args.sampler == "k_dpm++_2m_sde":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = "dpmsolver++"
sched_init_args["use_karras_sigmas"] = True
@@ -3580,8 +3580,8 @@ def setup_parser() -> argparse.ArgumentParser:
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
"dpmpp_sde",
"dpmpp_2m_sde",
"k_dpm++_sde",
"k_dpm++_2m_sde"
],
help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類",
)