From 04739d7cab1b5884144dcab341f51813b6aadfe2 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 00:38:06 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 859 ++++++++++++++++++++---------------------- 1 file changed, 417 insertions(+), 442 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 72425826..14f5e88c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -15,6 +15,7 @@ import random import re import gc from accelerate import PartialState +from accelerate.utils import gather_object import diffusers import numpy as np @@ -81,7 +82,17 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" 高速化のためのモジュール入れ替え """ +def get_batches(items, batch_size): + num_batches = (len(items) + batch_size - 1) // batch_size + batches = [] + for i in range(num_batches): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, len(items)) + batch = items[start_index:end_index] + batches.append(batch) + + return batches def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") @@ -2442,458 +2453,422 @@ def main(args): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + raw_prompts = [] + if distributed_state.is_main_process: + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + else: + distributed_state.wait_for_everyone() + raw_prompts = gather_object(raw_prompts) + if distributed_state.is_main_process: # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0 or len(raw_prompts) > 1: - - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue - + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + if pi == 0 or len(raw_prompts) > 1: + + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) else: - logger.error("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - logger.warning( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + logger.error("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data.clear() - distributed_state.wait_for_everyone() - - batch_data.append(b1) - if len(batch_data) == args.batch_size*distributed_state.num_processes: - logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data.clear() - distributed_state.wait_for_everyone() - - - global_step += 1 - - prompt_index += 1 - + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + batch_data.extend(b1) + global_step += 1 + + prompt_index += 1 + else: + distributed_state.wait_for_everyone() + batch_data = gather_object(batch_data) + logger.info(f"Total prompts: {len(batch_data)}") if len(batch_data) > 0: + data_loader = get_batches(items=batch_data, batch_size=args.batch_size) + logger.info(f"Total batches: {len(batch_data)}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data.clear() - distributed_state.wait_for_everyone() - - + for i in range(len(data_loader)): + logger.info(f"Loading Batch {i+1} of {len(data_loader)}") + batch_data_split.append(data_loader[i]) + if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader) + continue + with torch.no_grad(): + with distributed_state.split_between_processes(batch_data_split) as batch_list: + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") + logger.info(f"batch_list:") + for i in range(len(batch_list[0])): + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") + prev_image = process_batch(batch_list[0], highres_fix)[0] + batch_data_split.clear() + distributed_state.wait_for_everyone() logger.info("done!")