suppress waning for scheduler args #748

This commit is contained in:
Kohya S
2023-08-11 21:31:55 +09:00
parent 3307ccb2dc
commit 8415014de6

View File

@@ -1309,7 +1309,10 @@ def main(args):
# schedulerを用意する # schedulerを用意する
sched_init_args = {} sched_init_args = {}
has_steps_offset = True
has_clip_sample = True
scheduler_num_noises_per_step = 1 scheduler_num_noises_per_step = 1
if args.sampler == "ddim": if args.sampler == "ddim":
scheduler_cls = DDIMScheduler scheduler_cls = DDIMScheduler
scheduler_module = diffusers.schedulers.scheduling_ddim scheduler_module = diffusers.schedulers.scheduling_ddim
@@ -1319,32 +1322,48 @@ def main(args):
elif args.sampler == "pndm": elif args.sampler == "pndm":
scheduler_cls = PNDMScheduler scheduler_cls = PNDMScheduler
scheduler_module = diffusers.schedulers.scheduling_pndm scheduler_module = diffusers.schedulers.scheduling_pndm
has_clip_sample = False
elif args.sampler == "lms" or args.sampler == "k_lms": elif args.sampler == "lms" or args.sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler scheduler_cls = LMSDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_lms_discrete scheduler_module = diffusers.schedulers.scheduling_lms_discrete
has_clip_sample = False
elif args.sampler == "euler" or args.sampler == "k_euler": elif args.sampler == "euler" or args.sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler scheduler_cls = EulerDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_discrete scheduler_module = diffusers.schedulers.scheduling_euler_discrete
has_clip_sample = False
elif args.sampler == "euler_a" or args.sampler == "k_euler_a": elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler scheduler_cls = EulerAncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
has_clip_sample = False
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sampler sched_init_args["algorithm_type"] = args.sampler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
has_clip_sample = False
elif args.sampler == "dpmsingle": elif args.sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler scheduler_cls = DPMSolverSinglestepScheduler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
has_clip_sample = False
has_steps_offset = False
elif args.sampler == "heun": elif args.sampler == "heun":
scheduler_cls = HeunDiscreteScheduler scheduler_cls = HeunDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_heun_discrete scheduler_module = diffusers.schedulers.scheduling_heun_discrete
has_clip_sample = False
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler scheduler_cls = KDPM2DiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
has_clip_sample = False
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler scheduler_cls = KDPM2AncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
scheduler_num_noises_per_step = 2 scheduler_num_noises_per_step = 2
has_clip_sample = False
# 警告を出さないようにする
if has_steps_offset:
sched_init_args["steps_offset"] = 1
if has_clip_sample:
sched_init_args["clip_sample"] = False
# samplerの乱数をあらかじめ指定するための処理 # samplerの乱数をあらかじめ指定するための処理
@@ -1397,10 +1416,11 @@ def main(args):
**sched_init_args, **sched_init_args,
) )
# clip_sample=Trueにする # ↓以下は結局PipeでFalseに設定されるので意味がなかった
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: # # clip_sample=Trueにする
print("set clip_sample to True") # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
scheduler.config.clip_sample = True # print("set clip_sample to True")
# scheduler.config.clip_sample = True
# deviceを決定する # deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない