From 9ce4c6e5f29a6fbb1aef0330dbf4ce7f1398b747 Mon Sep 17 00:00:00 2001 From: jerrisk <15950515760@163.com> Date: Fri, 2 Feb 2024 13:30:10 +0800 Subject: [PATCH] Adding an sample_sampler dpm++_sde_k --- library/train_util.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba428e50..1b0d5016 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -39,6 +39,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti import transformers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( + DPMSolverSDEScheduler, StableDiffusionPipeline, DDPMScheduler, EulerAncestralDiscreteScheduler, @@ -3072,6 +3073,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "k_euler_a", "k_dpm_2", "k_dpm_2_a", + "dpm++_sde_k" ], help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", ) @@ -4556,6 +4558,10 @@ def get_my_scheduler( scheduler_cls = KDPM2DiscreteScheduler elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": scheduler_cls = KDPM2AncestralDiscreteScheduler + elif sample_sampler == "dpm++_sde_k": + scheduler_cls = DPMSolverSDEScheduler + sched_init_args["noise_sampler_seed"] = 0 + sched_init_args["steps_offset"] = 1 else: scheduler_cls = DDIMScheduler @@ -4569,7 +4575,8 @@ def get_my_scheduler( beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args, ) - + if sample_sampler == "dpm++_sde_k": + scheduler.config.use_karras_sigmas=True # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: # print("set clip_sample to True") @@ -4889,4 +4896,4 @@ class LossRecorder: @property def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) + return self.loss_total / len(self.loss_list) \ No newline at end of file