Merge pull request #906 from shirayu/accept_scheduler_designation_in_training

Accept a per-prompt sampler designation when generating sample images during training
Authored by Kohya S on 2023-12-03 20:46:16 +09:00; committed by GitHub.


@@ -4447,11 +4447,118 @@ SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
 SCHEDLER_SCHEDULE = "scaled_linear"
 
+
+def get_my_scheduler(
+    *,
+    sample_sampler: str,
+    v_parameterization: bool,
+):
+    sched_init_args = {}
+    if sample_sampler == "ddim":
+        scheduler_cls = DDIMScheduler
+    elif sample_sampler == "ddpm":  # ddpm gives broken results, so it is excluded from the options
+        scheduler_cls = DDPMScheduler
+    elif sample_sampler == "pndm":
+        scheduler_cls = PNDMScheduler
+    elif sample_sampler == "lms" or sample_sampler == "k_lms":
+        scheduler_cls = LMSDiscreteScheduler
+    elif sample_sampler == "euler" or sample_sampler == "k_euler":
+        scheduler_cls = EulerDiscreteScheduler
+    elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
+        scheduler_cls = EulerAncestralDiscreteScheduler
+    elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
+        scheduler_cls = DPMSolverMultistepScheduler
+        sched_init_args["algorithm_type"] = sample_sampler
+    elif sample_sampler == "dpmsingle":
+        scheduler_cls = DPMSolverSinglestepScheduler
+    elif sample_sampler == "heun":
+        scheduler_cls = HeunDiscreteScheduler
+    elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
+        scheduler_cls = KDPM2DiscreteScheduler
+    elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
+        scheduler_cls = KDPM2AncestralDiscreteScheduler
+    else:
+        scheduler_cls = DDIMScheduler
+
+    if v_parameterization:
+        sched_init_args["prediction_type"] = "v_prediction"
+
+    scheduler = scheduler_cls(
+        num_train_timesteps=SCHEDULER_TIMESTEPS,
+        beta_start=SCHEDULER_LINEAR_START,
+        beta_end=SCHEDULER_LINEAR_END,
+        beta_schedule=SCHEDLER_SCHEDULE,
+        **sched_init_args,
+    )
+
+    # force clip_sample to True
+    if (
+        hasattr(scheduler.config, "clip_sample")
+        and scheduler.config.clip_sample is False
+    ):
+        # print("set clip_sample to True")
+        scheduler.config.clip_sample = True
+
+    return scheduler
+
+
 def sample_images(*args, **kwargs):
     return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
 
+
+def line_to_prompt_dict(line: str) -> dict:
+    # subset of gen_img_diffusers
+    prompt_args = line.split(" --")
+    prompt_dict = {}
+    prompt_dict['prompt'] = prompt_args[0]
+
+    for parg in prompt_args:
+        try:
+            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict['width'] = int(m.group(1))
+                continue
+
+            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict['height'] = int(m.group(1))
+                continue
+
+            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict['seed'] = int(m.group(1))
+                continue
+
+            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+            if m:  # steps
+                prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1))))
+                continue
+
+            m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+            if m:  # scale
+                prompt_dict['scale'] = float(m.group(1))
+                continue
+
+            m = re.match(r"n (.+)", parg, re.IGNORECASE)
+            if m:  # negative prompt
+                prompt_dict['negative_prompt'] = m.group(1)
+                continue
+
+            m = re.match(r"ss (.+)", parg, re.IGNORECASE)
+            if m:  # sampler name
+                prompt_dict['sample_sampler'] = m.group(1)
+                continue
+
+            m = re.match(r"cn (.+)", parg, re.IGNORECASE)
+            if m:  # controlnet image
+                prompt_dict['controlnet_image'] = m.group(1)
+                continue
+        except ValueError as ex:
+            print(f"Exception in parsing: {parg}")
+            print(ex)
+
+    return prompt_dict
+
+
 def sample_images_common(
     pipe_class,
     accelerator,
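For reference, a minimal usage sketch of the new get_my_scheduler factory (not part of the commit; it assumes the diffusers scheduler classes and the SCHEDULER_* constants above are in scope):

    # "euler_a" resolves to EulerAncestralDiscreteScheduler built with the shared betas
    scheduler = get_my_scheduler(sample_sampler="euler_a", v_parameterization=False)
    print(type(scheduler).__name__)  # EulerAncestralDiscreteScheduler

    # v-prediction models get prediction_type="v_prediction" injected into the constructor
    v_sched = get_my_scheduler(sample_sampler="ddim", v_parameterization=True)
    print(v_sched.config.prediction_type)  # v_prediction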
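Likewise, a sketch of line_to_prompt_dict picking up the new --ss option from a prompt-file line (illustrative values, not from the commit):

    line = "a cat on a table --w 768 --h 512 --d 42 --s 28 --l 7.0 --ss euler_a"
    print(line_to_prompt_dict(line))
    # {'prompt': 'a cat on a table', 'width': 768, 'height': 512,
    #  'seed': 42, 'sample_steps': 28, 'scale': 7.0, 'sample_sampler': 'euler_a'}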
@@ -4504,56 +4611,19 @@ def sample_images_common(
     with open(args.sample_prompts, "r", encoding="utf-8") as f:
         prompts = json.load(f)
 
-    # prepare the scheduler
-    sched_init_args = {}
-    if args.sample_sampler == "ddim":
-        scheduler_cls = DDIMScheduler
-    elif args.sample_sampler == "ddpm":  # ddpm gives broken results, so it is excluded from the options
-        scheduler_cls = DDPMScheduler
-    elif args.sample_sampler == "pndm":
-        scheduler_cls = PNDMScheduler
-    elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
-        scheduler_cls = LMSDiscreteScheduler
-    elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
-        scheduler_cls = EulerDiscreteScheduler
-    elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
-        scheduler_cls = EulerAncestralDiscreteScheduler
-    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
-        scheduler_cls = DPMSolverMultistepScheduler
-        sched_init_args["algorithm_type"] = args.sample_sampler
-    elif args.sample_sampler == "dpmsingle":
-        scheduler_cls = DPMSolverSinglestepScheduler
-    elif args.sample_sampler == "heun":
-        scheduler_cls = HeunDiscreteScheduler
-    elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
-        scheduler_cls = KDPM2DiscreteScheduler
-    elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
-        scheduler_cls = KDPM2AncestralDiscreteScheduler
-    else:
-        scheduler_cls = DDIMScheduler
-
-    if args.v_parameterization:
-        sched_init_args["prediction_type"] = "v_prediction"
-
-    scheduler = scheduler_cls(
-        num_train_timesteps=SCHEDULER_TIMESTEPS,
-        beta_start=SCHEDULER_LINEAR_START,
-        beta_end=SCHEDULER_LINEAR_END,
-        beta_schedule=SCHEDLER_SCHEDULE,
-        **sched_init_args,
-    )
-
-    # force clip_sample to True
-    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-        # print("set clip_sample to True")
-        scheduler.config.clip_sample = True
+    schedulers: dict = {}
+    default_scheduler = get_my_scheduler(
+        sample_sampler=args.sample_sampler,
+        v_parameterization=args.v_parameterization,
+    )
+    schedulers[args.sample_sampler] = default_scheduler
 
     pipeline = pipe_class(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=default_scheduler,
         safety_checker=None,
         feature_extractor=None,
         requires_safety_checker=False,
@@ -4569,78 +4639,34 @@ def sample_images_common(
     with torch.no_grad():
         # with accelerator.autocast():
-        for i, prompt in enumerate(prompts):
+        for i, prompt_dict in enumerate(prompts):
             if not accelerator.is_main_process:
                 continue
 
-            if isinstance(prompt, dict):
-                negative_prompt = prompt.get("negative_prompt")
-                sample_steps = prompt.get("sample_steps", 30)
-                width = prompt.get("width", 512)
-                height = prompt.get("height", 512)
-                scale = prompt.get("scale", 7.5)
-                seed = prompt.get("seed")
-                controlnet_image = prompt.get("controlnet_image")
-                prompt = prompt.get("prompt")
-            else:
-                # prompt = prompt.strip()
-                # if len(prompt) == 0 or prompt[0] == "#":
-                #     continue
-
-                # subset of gen_img_diffusers
-                prompt_args = prompt.split(" --")
-                prompt = prompt_args[0]
-                negative_prompt = None
-                sample_steps = 30
-                width = height = 512
-                scale = 7.5
-                seed = None
-                controlnet_image = None
-                for parg in prompt_args:
-                    try:
-                        m = re.match(r"w (\d+)", parg, re.IGNORECASE)
-                        if m:
-                            width = int(m.group(1))
-                            continue
-
-                        m = re.match(r"h (\d+)", parg, re.IGNORECASE)
-                        if m:
-                            height = int(m.group(1))
-                            continue
-
-                        m = re.match(r"d (\d+)", parg, re.IGNORECASE)
-                        if m:
-                            seed = int(m.group(1))
-                            continue
-
-                        m = re.match(r"s (\d+)", parg, re.IGNORECASE)
-                        if m:  # steps
-                            sample_steps = max(1, min(1000, int(m.group(1))))
-                            continue
-
-                        m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
-                        if m:  # scale
-                            scale = float(m.group(1))
-                            continue
-
-                        m = re.match(r"n (.+)", parg, re.IGNORECASE)
-                        if m:  # negative prompt
-                            negative_prompt = m.group(1)
-                            continue
-
-                        m = re.match(r"cn (.+)", parg, re.IGNORECASE)
-                        if m:  # controlnet image
-                            controlnet_image = m.group(1)
-                            continue
-                    except ValueError as ex:
-                        print(f"Exception in parsing: {parg}")
-                        print(ex)
+            if isinstance(prompt_dict, str):
+                prompt_dict = line_to_prompt_dict(prompt_dict)
+
+            assert isinstance(prompt_dict, dict)
+            negative_prompt = prompt_dict.get("negative_prompt")
+            sample_steps = prompt_dict.get("sample_steps", 30)
+            width = prompt_dict.get("width", 512)
+            height = prompt_dict.get("height", 512)
+            scale = prompt_dict.get("scale", 7.5)
+            seed = prompt_dict.get("seed")
+            controlnet_image = prompt_dict.get("controlnet_image")
+            prompt: str = prompt_dict.get("prompt", "")
+            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
 
             if seed is not None:
                 torch.manual_seed(seed)
                 torch.cuda.manual_seed(seed)
 
+            scheduler = schedulers.get(sampler_name)
+            if scheduler is None:
+                scheduler = get_my_scheduler(
+                    sample_sampler=sampler_name,
+                    v_parameterization=args.v_parameterization,
+                )
+                schedulers[sampler_name] = scheduler
+            pipeline.scheduler = scheduler
+
             if prompt_replacement is not None:
                 prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
 
             if negative_prompt is not None:
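With this change, each entry in the sample-prompt file can name its own sampler: plain lines are parsed by line_to_prompt_dict (with the new --ss option), dict entries carry a "sample_sampler" key, and schedulers are built once per name, cached in the schedulers dict, and assigned to pipeline.scheduler before each generation. A hypothetical mixed prompts list, as the loop above would consume it:

    prompts = [
        "a cat on a sofa --w 640 --h 448 --ss euler_a",  # parsed via line_to_prompt_dict
        {"prompt": "a dog in a park", "sample_steps": 25, "sample_sampler": "dpmsolver++"},
    ]
    # Prompts without a sampler fall back to args.sample_sampler (the default_scheduler).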
@@ -4658,6 +4684,7 @@ def sample_images_common(
print(f"width: {width}") print(f"width: {width}")
print(f"sample_steps: {sample_steps}") print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}") print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
with accelerator.autocast(): with accelerator.autocast():
latents = pipeline( latents = pipeline(
prompt=prompt, prompt=prompt,