mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
Adding an sample_sampler dpm++_sde_k
This commit is contained in:
@@ -39,6 +39,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
|
||||
import transformers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import (
|
||||
DPMSolverSDEScheduler,
|
||||
StableDiffusionPipeline,
|
||||
DDPMScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
@@ -3072,6 +3073,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"k_euler_a",
|
||||
"k_dpm_2",
|
||||
"k_dpm_2_a",
|
||||
"dpm++_sde_k"
|
||||
],
|
||||
help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
|
||||
)
|
||||
@@ -4556,6 +4558,10 @@ def get_my_scheduler(
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
elif sample_sampler == "dpm++_sde_k":
|
||||
scheduler_cls = DPMSolverSDEScheduler
|
||||
sched_init_args["noise_sampler_seed"] = 0
|
||||
sched_init_args["steps_offset"] = 1
|
||||
else:
|
||||
scheduler_cls = DDIMScheduler
|
||||
|
||||
@@ -4569,7 +4575,8 @@ def get_my_scheduler(
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
if sample_sampler == "dpm++_sde_k":
|
||||
scheduler.config.use_karras_sigmas=True
|
||||
# clip_sample=Trueにする
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
# print("set clip_sample to True")
|
||||
@@ -4889,4 +4896,4 @@ class LossRecorder:
|
||||
|
||||
@property
|
||||
def moving_average(self) -> float:
|
||||
return self.loss_total / len(self.loss_list)
|
||||
return self.loss_total / len(self.loss_list)
|
||||
Reference in New Issue
Block a user