Adding an sample_sampler dpm++_sde_k

This commit is contained in:
jerrisk
2024-02-02 13:30:10 +08:00
parent cd19df49cd
commit 9ce4c6e5f2

View File

@@ -39,6 +39,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
DPMSolverSDEScheduler,
StableDiffusionPipeline,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
@@ -3072,6 +3073,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
"dpm++_sde_k"
],
help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
)
@@ -4556,6 +4558,10 @@ def get_my_scheduler(
scheduler_cls = KDPM2DiscreteScheduler
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
elif sample_sampler == "dpm++_sde_k":
scheduler_cls = DPMSolverSDEScheduler
sched_init_args["noise_sampler_seed"] = 0
sched_init_args["steps_offset"] = 1
else:
scheduler_cls = DDIMScheduler
@@ -4569,7 +4575,8 @@ def get_my_scheduler(
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
)
if sample_sampler == "dpm++_sde_k":
scheduler.config.use_karras_sigmas=True
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
@@ -4889,4 +4896,4 @@ class LossRecorder:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
return self.loss_total / len(self.loss_list)