mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #906 from shirayu/accept_scheduler_designation_in_training
Accept sampler designation in sampling of training
This commit is contained in:
@@ -4447,11 +4447,118 @@ SCHEDULER_LINEAR_END = 0.0120
|
|||||||
SCHEDULER_TIMESTEPS = 1000
|
SCHEDULER_TIMESTEPS = 1000
|
||||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||||
|
|
||||||
|
def get_my_scheduler(
|
||||||
|
*,
|
||||||
|
sample_sampler: str,
|
||||||
|
v_parameterization: bool,
|
||||||
|
):
|
||||||
|
sched_init_args = {}
|
||||||
|
if sample_sampler == "ddim":
|
||||||
|
scheduler_cls = DDIMScheduler
|
||||||
|
elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
||||||
|
scheduler_cls = DDPMScheduler
|
||||||
|
elif sample_sampler == "pndm":
|
||||||
|
scheduler_cls = PNDMScheduler
|
||||||
|
elif sample_sampler == "lms" or sample_sampler == "k_lms":
|
||||||
|
scheduler_cls = LMSDiscreteScheduler
|
||||||
|
elif sample_sampler == "euler" or sample_sampler == "k_euler":
|
||||||
|
scheduler_cls = EulerDiscreteScheduler
|
||||||
|
elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
|
||||||
|
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||||
|
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
|
||||||
|
scheduler_cls = DPMSolverMultistepScheduler
|
||||||
|
sched_init_args["algorithm_type"] = sample_sampler
|
||||||
|
elif sample_sampler == "dpmsingle":
|
||||||
|
scheduler_cls = DPMSolverSinglestepScheduler
|
||||||
|
elif sample_sampler == "heun":
|
||||||
|
scheduler_cls = HeunDiscreteScheduler
|
||||||
|
elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
|
||||||
|
scheduler_cls = KDPM2DiscreteScheduler
|
||||||
|
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
|
||||||
|
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||||
|
else:
|
||||||
|
scheduler_cls = DDIMScheduler
|
||||||
|
|
||||||
|
if v_parameterization:
|
||||||
|
sched_init_args["prediction_type"] = "v_prediction"
|
||||||
|
|
||||||
|
scheduler = scheduler_cls(
|
||||||
|
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||||
|
beta_start=SCHEDULER_LINEAR_START,
|
||||||
|
beta_end=SCHEDULER_LINEAR_END,
|
||||||
|
beta_schedule=SCHEDLER_SCHEDULE,
|
||||||
|
**sched_init_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# clip_sample=Trueにする
|
||||||
|
if (
|
||||||
|
hasattr(scheduler.config, "clip_sample")
|
||||||
|
and scheduler.config.clip_sample is False
|
||||||
|
):
|
||||||
|
# print("set clip_sample to True")
|
||||||
|
scheduler.config.clip_sample = True
|
||||||
|
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
def sample_images(*args, **kwargs):
|
def sample_images(*args, **kwargs):
|
||||||
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def line_to_prompt_dict(line: str) -> dict:
|
||||||
|
# subset of gen_img_diffusers
|
||||||
|
prompt_args = line.split(" --")
|
||||||
|
prompt_dict = {}
|
||||||
|
prompt_dict['prompt'] = prompt_args[0]
|
||||||
|
|
||||||
|
for parg in prompt_args:
|
||||||
|
try:
|
||||||
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['width'] = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['height'] = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['seed'] = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m: # steps
|
||||||
|
prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1))))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # scale
|
||||||
|
prompt_dict['scale'] = float(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # negative prompt
|
||||||
|
prompt_dict['negative_prompt'] = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['sample_sampler'] = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
prompt_dict['controlnet_image'] = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except ValueError as ex:
|
||||||
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
|
print(ex)
|
||||||
|
|
||||||
|
return prompt_dict
|
||||||
|
|
||||||
def sample_images_common(
|
def sample_images_common(
|
||||||
pipe_class,
|
pipe_class,
|
||||||
accelerator,
|
accelerator,
|
||||||
@@ -4504,56 +4611,19 @@ def sample_images_common(
|
|||||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||||
prompts = json.load(f)
|
prompts = json.load(f)
|
||||||
|
|
||||||
# schedulerを用意する
|
schedulers: dict = {}
|
||||||
sched_init_args = {}
|
default_scheduler = get_my_scheduler(
|
||||||
if args.sample_sampler == "ddim":
|
sample_sampler=args.sample_sampler,
|
||||||
scheduler_cls = DDIMScheduler
|
v_parameterization=args.v_parameterization,
|
||||||
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
|
||||||
scheduler_cls = DDPMScheduler
|
|
||||||
elif args.sample_sampler == "pndm":
|
|
||||||
scheduler_cls = PNDMScheduler
|
|
||||||
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
|
|
||||||
scheduler_cls = LMSDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
|
|
||||||
scheduler_cls = EulerDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
|
|
||||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
|
||||||
scheduler_cls = DPMSolverMultistepScheduler
|
|
||||||
sched_init_args["algorithm_type"] = args.sample_sampler
|
|
||||||
elif args.sample_sampler == "dpmsingle":
|
|
||||||
scheduler_cls = DPMSolverSinglestepScheduler
|
|
||||||
elif args.sample_sampler == "heun":
|
|
||||||
scheduler_cls = HeunDiscreteScheduler
|
|
||||||
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
|
|
||||||
scheduler_cls = KDPM2DiscreteScheduler
|
|
||||||
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
|
|
||||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
|
||||||
else:
|
|
||||||
scheduler_cls = DDIMScheduler
|
|
||||||
|
|
||||||
if args.v_parameterization:
|
|
||||||
sched_init_args["prediction_type"] = "v_prediction"
|
|
||||||
|
|
||||||
scheduler = scheduler_cls(
|
|
||||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
|
||||||
beta_start=SCHEDULER_LINEAR_START,
|
|
||||||
beta_end=SCHEDULER_LINEAR_END,
|
|
||||||
beta_schedule=SCHEDLER_SCHEDULE,
|
|
||||||
**sched_init_args,
|
|
||||||
)
|
)
|
||||||
|
schedulers[args.sample_sampler] = default_scheduler
|
||||||
# clip_sample=Trueにする
|
|
||||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
|
||||||
# print("set clip_sample to True")
|
|
||||||
scheduler.config.clip_sample = True
|
|
||||||
|
|
||||||
pipeline = pipe_class(
|
pipeline = pipe_class(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
scheduler=scheduler,
|
scheduler=default_scheduler,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
@@ -4569,78 +4639,34 @@ def sample_images_common(
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# with accelerator.autocast():
|
# with accelerator.autocast():
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt_dict in enumerate(prompts):
|
||||||
if not accelerator.is_main_process:
|
if not accelerator.is_main_process:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt_dict, str):
|
||||||
negative_prompt = prompt.get("negative_prompt")
|
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||||
sample_steps = prompt.get("sample_steps", 30)
|
|
||||||
width = prompt.get("width", 512)
|
|
||||||
height = prompt.get("height", 512)
|
|
||||||
scale = prompt.get("scale", 7.5)
|
|
||||||
seed = prompt.get("seed")
|
|
||||||
controlnet_image = prompt.get("controlnet_image")
|
|
||||||
prompt = prompt.get("prompt")
|
|
||||||
else:
|
|
||||||
# prompt = prompt.strip()
|
|
||||||
# if len(prompt) == 0 or prompt[0] == "#":
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# subset of gen_img_diffusers
|
assert isinstance(prompt_dict, dict)
|
||||||
prompt_args = prompt.split(" --")
|
negative_prompt = prompt_dict.get("negative_prompt")
|
||||||
prompt = prompt_args[0]
|
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||||
negative_prompt = None
|
width = prompt_dict.get("width", 512)
|
||||||
sample_steps = 30
|
height = prompt_dict.get("height", 512)
|
||||||
width = height = 512
|
scale = prompt_dict.get("scale", 7.5)
|
||||||
scale = 7.5
|
seed = prompt_dict.get("seed")
|
||||||
seed = None
|
controlnet_image = prompt_dict.get("controlnet_image")
|
||||||
controlnet_image = None
|
prompt: str = prompt_dict.get("prompt", "")
|
||||||
for parg in prompt_args:
|
sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||||
try:
|
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
width = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
height = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
seed = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m: # steps
|
|
||||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
|
||||||
if m: # scale
|
|
||||||
scale = float(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
negative_prompt = m.group(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
controlnet_image = m.group(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
except ValueError as ex:
|
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
|
||||||
print(ex)
|
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scheduler = schedulers.get(sampler_name)
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,)
|
||||||
|
schedulers[sampler_name] = scheduler
|
||||||
|
pipeline.scheduler = scheduler
|
||||||
|
|
||||||
if prompt_replacement is not None:
|
if prompt_replacement is not None:
|
||||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
@@ -4658,6 +4684,7 @@ def sample_images_common(
|
|||||||
print(f"width: {width}")
|
print(f"width: {width}")
|
||||||
print(f"sample_steps: {sample_steps}")
|
print(f"sample_steps: {sample_steps}")
|
||||||
print(f"scale: {scale}")
|
print(f"scale: {scale}")
|
||||||
|
print(f"sample_sampler: {sampler_name}")
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
latents = pipeline(
|
latents = pipeline(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|||||||
Reference in New Issue
Block a user