mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
suppress waning for scheduler args #748
This commit is contained in:
@@ -1309,7 +1309,10 @@ def main(args):
|
||||
|
||||
# schedulerを用意する
|
||||
sched_init_args = {}
|
||||
has_steps_offset = True
|
||||
has_clip_sample = True
|
||||
scheduler_num_noises_per_step = 1
|
||||
|
||||
if args.sampler == "ddim":
|
||||
scheduler_cls = DDIMScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_ddim
|
||||
@@ -1319,32 +1322,48 @@ def main(args):
|
||||
elif args.sampler == "pndm":
|
||||
scheduler_cls = PNDMScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_pndm
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "lms" or args.sampler == "k_lms":
|
||||
scheduler_cls = LMSDiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_lms_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "euler" or args.sampler == "k_euler":
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
sched_init_args["algorithm_type"] = args.sampler
|
||||
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "dpmsingle":
|
||||
scheduler_cls = DPMSolverSinglestepScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
|
||||
has_clip_sample = False
|
||||
has_steps_offset = False
|
||||
elif args.sampler == "heun":
|
||||
scheduler_cls = HeunDiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_heun_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
|
||||
scheduler_num_noises_per_step = 2
|
||||
has_clip_sample = False
|
||||
|
||||
# 警告を出さないようにする
|
||||
if has_steps_offset:
|
||||
sched_init_args["steps_offset"] = 1
|
||||
if has_clip_sample:
|
||||
sched_init_args["clip_sample"] = False
|
||||
|
||||
# samplerの乱数をあらかじめ指定するための処理
|
||||
|
||||
@@ -1397,10 +1416,11 @@ def main(args):
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
# clip_sample=Trueにする
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
||||
# # clip_sample=Trueにする
|
||||
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
# print("set clip_sample to True")
|
||||
# scheduler.config.clip_sample = True
|
||||
|
||||
# deviceを決定する
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
||||
|
||||
Reference in New Issue
Block a user