diff --git a/library/train_util.py b/library/train_util.py index 51610e70..d6a5221a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4381,6 +4381,59 @@ SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = "scaled_linear" +def get_my_scheduler( + *, + sample_sampler: str, + v_parameterization: bool, +): + sched_init_args = {} + if sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif sample_sampler == "lms" or sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif sample_sampler == "euler" or sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = sample_sampler + elif sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is False + ): + # print("set clip_sample to True") + scheduler.config.clip_sample = True + + return scheduler + def sample_images(*args, **kwargs): return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) @@ -4438,50 +4491,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulerを用意する - sched_init_args = {} - if args.sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif args.sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sample_sampler - elif args.sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler - - if args.v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, + scheduler = get_my_scheduler( + sample_sampler=args.scheduler, + v_parameterization=args.v_parameterization, ) - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - scheduler.config.clip_sample = True - pipeline = pipe_class( text_encoder=text_encoder, vae=vae,