suppress waning for scheduler args #748

This commit is contained in:
Kohya S
2023-08-11 21:31:55 +09:00
parent 3307ccb2dc
commit 8415014de6

View File

@@ -1309,7 +1309,10 @@ def main(args):
# schedulerを用意する
sched_init_args = {}
has_steps_offset = True
has_clip_sample = True
scheduler_num_noises_per_step = 1
if args.sampler == "ddim":
scheduler_cls = DDIMScheduler
scheduler_module = diffusers.schedulers.scheduling_ddim
@@ -1319,32 +1322,48 @@ def main(args):
elif args.sampler == "pndm":
scheduler_cls = PNDMScheduler
scheduler_module = diffusers.schedulers.scheduling_pndm
has_clip_sample = False
elif args.sampler == "lms" or args.sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_lms_discrete
has_clip_sample = False
elif args.sampler == "euler" or args.sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
has_clip_sample = False
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
has_clip_sample = False
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sampler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
has_clip_sample = False
elif args.sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
has_clip_sample = False
has_steps_offset = False
elif args.sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_heun_discrete
has_clip_sample = False
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
has_clip_sample = False
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
scheduler_num_noises_per_step = 2
has_clip_sample = False
# 警告を出さないようにする
if has_steps_offset:
sched_init_args["steps_offset"] = 1
if has_clip_sample:
sched_init_args["clip_sample"] = False
# samplerの乱数をあらかじめ指定するための処理
@@ -1397,10 +1416,11 @@ def main(args):
**sched_init_args,
)
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
print("set clip_sample to True")
scheduler.config.clip_sample = True
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
# # clip_sample=Trueにする
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
# scheduler.config.clip_sample = True
# deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない