diff --git a/library/train_util.py b/library/train_util.py index b9d08f25..309c2488 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -48,6 +48,7 @@ from diffusers.optimization import ( ) from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( + DPMSolverSDEScheduler, StableDiffusionPipeline, DDPMScheduler, EulerAncestralDiscreteScheduler, @@ -3534,6 +3535,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "k_euler_a", "k_dpm_2", "k_dpm_2_a", + "dpm++_sde_k" ], help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", ) @@ -5342,6 +5344,10 @@ def get_my_scheduler( scheduler_cls = KDPM2DiscreteScheduler elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": scheduler_cls = KDPM2AncestralDiscreteScheduler + elif sample_sampler == "dpm++_sde_k": + scheduler_cls = DPMSolverSDEScheduler + sched_init_args["noise_sampler_seed"] = 0 + sched_init_args["steps_offset"] = 1 else: scheduler_cls = DDIMScheduler @@ -5355,7 +5361,8 @@ def get_my_scheduler( beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args, ) - + if sample_sampler == "dpm++_sde_k": + scheduler.config.use_karras_sigmas=True # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: # logger.info("set clip_sample to True") @@ -5729,4 +5736,4 @@ class LossRecorder: @property def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) + return self.loss_total / len(self.loss_list) \ No newline at end of file