This commit is contained in:
JudyBeaH
2025-09-30 09:58:45 +05:30
committed by GitHub

View File

@@ -48,6 +48,7 @@ from diffusers.optimization import (
)
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
DPMSolverSDEScheduler,
StableDiffusionPipeline,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
@@ -3534,6 +3535,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
"dpm++_sde_k"
],
help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
)
@@ -5342,6 +5344,10 @@ def get_my_scheduler(
scheduler_cls = KDPM2DiscreteScheduler
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
elif sample_sampler == "dpm++_sde_k":
scheduler_cls = DPMSolverSDEScheduler
sched_init_args["noise_sampler_seed"] = 0
sched_init_args["steps_offset"] = 1
else:
scheduler_cls = DDIMScheduler
@@ -5355,7 +5361,8 @@ def get_my_scheduler(
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
)
if sample_sampler == "dpm++_sde_k":
scheduler.config.use_karras_sigmas=True
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# logger.info("set clip_sample to True")
@@ -5729,4 +5736,4 @@ class LossRecorder:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
return self.loss_total / len(self.loss_list)