From 01e00ac1b085c562ccc14a1c166c43ccf90d2a83 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 19:45:44 +0900 Subject: [PATCH 1/4] Make a function get_my_scheduler() --- library/train_util.py | 98 ++++++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 51610e70..d6a5221a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4381,6 +4381,59 @@ SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = "scaled_linear" +def get_my_scheduler( + *, + sample_sampler: str, + v_parameterization: bool, +): + sched_init_args = {} + if sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif sample_sampler == "lms" or sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif sample_sampler == "euler" or sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = sample_sampler + elif sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is False + ): + # print("set clip_sample to True") + scheduler.config.clip_sample = True + + return scheduler + def sample_images(*args, **kwargs): return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) @@ -4438,50 +4491,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulerを用意する - sched_init_args = {} - if args.sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif args.sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sample_sampler - elif args.sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler - - if args.v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, + scheduler = get_my_scheduler( + sample_sampler=args.scheduler, + v_parameterization=args.v_parameterization, ) - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - scheduler.config.clip_sample = True - pipeline = pipe_class( text_encoder=text_encoder, vae=vae, From 291c29caaf2f17e4c61b523522d7453df8a1c480 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 19:57:25 +0900 Subject: [PATCH 2/4] Added a function line_to_prompt_dict() and removed duplicated initializations --- library/train_util.py | 124 +++++++++++++++++++++--------------------- 1 file changed, 61 insertions(+), 63 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d6a5221a..b93b8ea4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4439,6 +4439,55 @@ def sample_images(*args, **kwargs): return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +def line_to_prompt_dict(line: str) -> dict: + # subset of gen_img_diffusers + prompt_args = line.split(" --") + prompt_dict = {} + prompt_dict['prompt'] = prompt_args[0] + + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict['width'] = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict['height'] = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict['seed'] = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + prompt_dict['scale'] = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['negative_prompt'] = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['controlnet_image'] = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + return prompt_dict + def sample_images_common( pipe_class, accelerator, @@ -4517,73 +4566,22 @@ def sample_images_common( with torch.no_grad(): # with accelerator.autocast(): - for i, prompt in enumerate(prompts): + for i, prompt_dict in enumerate(prompts): if not accelerator.is_main_process: continue - if isinstance(prompt, dict): - negative_prompt = prompt.get("negative_prompt") - sample_steps = prompt.get("sample_steps", 30) - width = prompt.get("width", 512) - height = prompt.get("height", 512) - scale = prompt.get("scale", 7.5) - seed = prompt.get("seed") - controlnet_image = prompt.get("controlnet_image") - prompt = prompt.get("prompt") - else: - # prompt = prompt.strip() - # if len(prompt) == 0 or prompt[0] == "#": - # continue + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - controlnet_image = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue - - m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt - controlnet_image = m.group(1) - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") if seed is not None: torch.manual_seed(seed) From cf876fcdb40d46c1bd21d50106ea44ada9f45671 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 20:15:04 +0900 Subject: [PATCH 3/4] Accept --ss to set sample_sampler dynamically --- library/train_util.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b93b8ea4..949a8206 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4477,6 +4477,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict['negative_prompt'] = m.group(1) continue + m = re.match(r"ss (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['sample_sampler'] = m.group(1) + continue + m = re.match(r"cn (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict['controlnet_image'] = m.group(1) @@ -4540,17 +4545,19 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - scheduler = get_my_scheduler( - sample_sampler=args.scheduler, + schedulers: dict = {} + default_scheduler = get_my_scheduler( + sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) + schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, + scheduler=default_scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, @@ -4582,11 +4589,18 @@ def sample_images_common( seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") + sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler) if seed is not None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) + scheduler = schedulers.get(sampler_name) + if scheduler is None: + scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,) + schedulers[sampler_name] = scheduler + pipeline.scheduler = scheduler + if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: @@ -4604,6 +4618,7 @@ def sample_images_common( print(f"width: {width}") print(f"sample_steps: {sample_steps}") print(f"scale: {scale}") + print(f"sample_sampler: {sampler_name}") with accelerator.autocast(): latents = pipeline( prompt=prompt, From 40d917b0fecc2459dff8a3848d7ea0c7d6c21ccb Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:02:44 +0900 Subject: [PATCH 4/4] Removed incorrect comments --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 949a8206..edff4f49 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4478,12 +4478,12 @@ def line_to_prompt_dict(line: str) -> dict: continue m = re.match(r"ss (.+)", parg, re.IGNORECASE) - if m: # negative prompt + if m: prompt_dict['sample_sampler'] = m.group(1) continue m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt + if m: prompt_dict['controlnet_image'] = m.group(1) continue