mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
suppress waning for scheduler args #748
This commit is contained in:
@@ -1309,7 +1309,10 @@ def main(args):
|
|||||||
|
|
||||||
# schedulerを用意する
|
# schedulerを用意する
|
||||||
sched_init_args = {}
|
sched_init_args = {}
|
||||||
|
has_steps_offset = True
|
||||||
|
has_clip_sample = True
|
||||||
scheduler_num_noises_per_step = 1
|
scheduler_num_noises_per_step = 1
|
||||||
|
|
||||||
if args.sampler == "ddim":
|
if args.sampler == "ddim":
|
||||||
scheduler_cls = DDIMScheduler
|
scheduler_cls = DDIMScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_ddim
|
scheduler_module = diffusers.schedulers.scheduling_ddim
|
||||||
@@ -1319,32 +1322,48 @@ def main(args):
|
|||||||
elif args.sampler == "pndm":
|
elif args.sampler == "pndm":
|
||||||
scheduler_cls = PNDMScheduler
|
scheduler_cls = PNDMScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_pndm
|
scheduler_module = diffusers.schedulers.scheduling_pndm
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "lms" or args.sampler == "k_lms":
|
elif args.sampler == "lms" or args.sampler == "k_lms":
|
||||||
scheduler_cls = LMSDiscreteScheduler
|
scheduler_cls = LMSDiscreteScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_lms_discrete
|
scheduler_module = diffusers.schedulers.scheduling_lms_discrete
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "euler" or args.sampler == "k_euler":
|
elif args.sampler == "euler" or args.sampler == "k_euler":
|
||||||
scheduler_cls = EulerDiscreteScheduler
|
scheduler_cls = EulerDiscreteScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
|
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
|
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
|
||||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
|
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
|
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
|
||||||
scheduler_cls = DPMSolverMultistepScheduler
|
scheduler_cls = DPMSolverMultistepScheduler
|
||||||
sched_init_args["algorithm_type"] = args.sampler
|
sched_init_args["algorithm_type"] = args.sampler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
|
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "dpmsingle":
|
elif args.sampler == "dpmsingle":
|
||||||
scheduler_cls = DPMSolverSinglestepScheduler
|
scheduler_cls = DPMSolverSinglestepScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
|
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
|
||||||
|
has_clip_sample = False
|
||||||
|
has_steps_offset = False
|
||||||
elif args.sampler == "heun":
|
elif args.sampler == "heun":
|
||||||
scheduler_cls = HeunDiscreteScheduler
|
scheduler_cls = HeunDiscreteScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_heun_discrete
|
scheduler_module = diffusers.schedulers.scheduling_heun_discrete
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
|
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
|
||||||
scheduler_cls = KDPM2DiscreteScheduler
|
scheduler_cls = KDPM2DiscreteScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
|
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
|
||||||
|
has_clip_sample = False
|
||||||
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
|
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
|
||||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||||
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
|
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
|
||||||
scheduler_num_noises_per_step = 2
|
scheduler_num_noises_per_step = 2
|
||||||
|
has_clip_sample = False
|
||||||
|
|
||||||
|
# 警告を出さないようにする
|
||||||
|
if has_steps_offset:
|
||||||
|
sched_init_args["steps_offset"] = 1
|
||||||
|
if has_clip_sample:
|
||||||
|
sched_init_args["clip_sample"] = False
|
||||||
|
|
||||||
# samplerの乱数をあらかじめ指定するための処理
|
# samplerの乱数をあらかじめ指定するための処理
|
||||||
|
|
||||||
@@ -1397,10 +1416,11 @@ def main(args):
|
|||||||
**sched_init_args,
|
**sched_init_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# clip_sample=Trueにする
|
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
||||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
# # clip_sample=Trueにする
|
||||||
print("set clip_sample to True")
|
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||||
scheduler.config.clip_sample = True
|
# print("set clip_sample to True")
|
||||||
|
# scheduler.config.clip_sample = True
|
||||||
|
|
||||||
# deviceを決定する
|
# deviceを決定する
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
||||||
|
|||||||
Reference in New Issue
Block a user