Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-26 00:38:06 +08:00
committed by GitHub
parent b7781f9b25
commit 04739d7cab

View File

@@ -15,6 +15,7 @@ import random
import re
import gc
from accelerate import PartialState
from accelerate.utils import gather_object
import diffusers
import numpy as np
@@ -81,7 +82,17 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
高速化のためのモジュール入れ替え
"""
def get_batches(items, batch_size):
num_batches = (len(items) + batch_size - 1) // batch_size
batches = []
for i in range(num_batches):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(items))
batch = items[start_index:end_index]
batches.append(batch)
return batches
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")
@@ -2442,458 +2453,422 @@ def main(args):
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
raw_prompts = []
if distributed_state.is_main_process:
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
else:
distributed_state.wait_for_everyone()
raw_prompts = gather_object(raw_prompts)
if distributed_state.is_main_process:
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
width = args.W
height = args.H
original_width = args.original_width
original_height = args.original_height
original_width_negative = args.original_width_negative
original_height_negative = args.original_height_negative
crop_top = args.crop_top
crop_left = args.crop_left
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seed = None
seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
network_muls = None
# Deep Shrink
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
# Gradual Latent
gl_timesteps = None # means no override
gl_ratio = args.gradual_latent_ratio
gl_every_n_steps = args.gradual_latent_every_n_steps
gl_ratio_step = args.gradual_latent_ratio_step
gl_s_noise = args.gradual_latent_s_noise
gl_unsharp_params = args.gradual_latent_unsharp_params
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
logger.info(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
logger.info(f"height: {height}")
continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m:
original_width = int(m.group(1))
logger.info(f"original width: {original_width}")
continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m:
original_height = int(m.group(1))
logger.info(f"original height: {original_height}")
continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
logger.info(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
logger.info(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
logger.info(f"crop top: {crop_top}")
continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m:
crop_left = int(m.group(1))
logger.info(f"crop left: {crop_left}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
logger.info(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
logger.info(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
if m: # negative scale
if m.group(1).lower() == "none":
negative_scale = None
else:
negative_scale = float(m.group(1))
logger.info(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
logger.info(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
logger.info(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
logger.info(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # network multiplies
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
logger.info(f"network mul: {network_muls}")
continue
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
width = args.W
height = args.H
original_width = args.original_width
original_height = args.original_height
original_width_negative = args.original_width_negative
original_height_negative = args.original_height_negative
crop_top = args.crop_top
crop_left = args.crop_left
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seed = None
seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
network_muls = None
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink ratio: {ds_ratio}")
continue
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# override Gradual Latent
if gl_timesteps is not None:
if gl_timesteps < 0:
gl_timesteps = args.gradual_latent_timesteps or 650
if gl_unsharp_params is not None:
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
gradual_latent = GradualLatent(
gl_ratio,
gl_timesteps,
gl_every_n_steps,
gl_ratio_step,
gl_s_noise,
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
if len(seeds) > 0:
seed = seeds.pop(0)
else:
if predefined_seeds is not None:
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
gl_timesteps = None # means no override
gl_ratio = args.gradual_latent_ratio
gl_every_n_steps = args.gradual_latent_every_n_steps
gl_ratio_step = args.gradual_latent_ratio_step
gl_s_noise = args.gradual_latent_s_noise
gl_unsharp_params = args.gradual_latent_unsharp_params
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
logger.info(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
logger.info(f"height: {height}")
continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m:
original_width = int(m.group(1))
logger.info(f"original width: {original_width}")
continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m:
original_height = int(m.group(1))
logger.info(f"original height: {original_height}")
continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
logger.info(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
logger.info(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
logger.info(f"crop top: {crop_top}")
continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m:
crop_left = int(m.group(1))
logger.info(f"crop left: {crop_left}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
logger.info(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
logger.info(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
if m: # negative scale
if m.group(1).lower() == "none":
negative_scale = None
else:
negative_scale = float(m.group(1))
logger.info(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
logger.info(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
logger.info(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
logger.info(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # network multiplies
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
logger.info(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink ratio: {ds_ratio}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# override Gradual Latent
if gl_timesteps is not None:
if gl_timesteps < 0:
gl_timesteps = args.gradual_latent_timesteps or 650
if gl_unsharp_params is not None:
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:
logger.error("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
gradual_latent = GradualLatent(
gl_ratio,
gl_timesteps,
gl_every_n_steps,
gl_ratio_step,
gl_s_noise,
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
if len(seeds) > 0:
seed = seeds.pop(0)
else:
seed = None # 前のを消す
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
if args.interactive:
logger.info(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
if init_images is not None:
init_image = init_images[global_step % len(init_images)]
# img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する
# 32単位に丸めたやつにresizeされるので踏襲する
if not highres_fix:
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
logger.warning(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if mask_images is not None:
mask_image = mask_images[global_step % len(mask_images)]
if guide_images is not None:
if control_nets: # 複数件の場合あり
c = len(control_nets)
p = global_step % (len(guide_images) // c)
guide_image = guide_images[p * c : p * c + c]
if predefined_seeds is not None:
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
else:
logger.error("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
else:
seed = None # 前のを消す
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
if args.interactive:
logger.info(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
if init_images is not None:
init_image = init_images[global_step % len(init_images)]
# img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する
# 32単位に丸めたやつにresizeされるので踏襲する
if not highres_fix:
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
logger.warning(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if mask_images is not None:
mask_image = mask_images[global_step % len(mask_images)]
if guide_images is not None:
if control_nets: # 複数件の場合あり
c = len(control_nets)
p = global_step % (len(guide_images) // c)
guide_image = guide_images[p * c : p * c + c]
else:
guide_image = guide_images[global_step % len(guide_images)]
if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
guide_image = guide_images[global_step % len(guide_images)]
if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
num_sub_prompts = None
b1 = BatchData(
False,
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
),
BatchDataExt(
width,
height,
original_width,
original_height,
original_width_negative,
original_height_negative,
crop_left,
crop_top,
steps,
scale,
negative_scale,
strength,
tuple(network_muls) if network_muls else None,
num_sub_prompts,
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.")
logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.")
batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
batch_index = []
for i in range(len(batch_data)):
logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
batch_index.append(batch_data[i])
if (i+1) % args.batch_size == 0:
batch_data_split.append(batch_index.copy())
batch_index.clear()
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")
logger.info(f"batch_list:")
for i in range(len(batch_list[0])):
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}")
prev_image = process_batch(batch_list[0], highres_fix)[0]
batch_data.clear()
distributed_state.wait_for_everyone()
batch_data.append(b1)
if len(batch_data) == args.batch_size*distributed_state.num_processes:
logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.")
batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
batch_index = []
for i in range(len(batch_data)):
logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
batch_index.append(batch_data[i])
if (i+1) % args.batch_size == 0:
batch_data_split.append(batch_index.copy())
batch_index.clear()
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")
logger.info(f"batch_list:")
for i in range(len(batch_list[0])):
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}")
prev_image = process_batch(batch_list[0], highres_fix)[0]
batch_data.clear()
distributed_state.wait_for_everyone()
global_step += 1
prompt_index += 1
num_sub_prompts = None
b1 = BatchData(
False,
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
),
BatchDataExt(
width,
height,
original_width,
original_height,
original_width_negative,
original_height_negative,
crop_left,
crop_top,
steps,
scale,
negative_scale,
strength,
tuple(network_muls) if network_muls else None,
num_sub_prompts,
),
)
batch_data.extend(b1)
global_step += 1
prompt_index += 1
else:
distributed_state.wait_for_everyone()
batch_data = gather_object(batch_data)
logger.info(f"Total prompts: {len(batch_data)}")
if len(batch_data) > 0:
data_loader = get_batches(items=batch_data, batch_size=args.batch_size)
logger.info(f"Total batches: {len(batch_data)}")
batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
batch_index = []
for i in range(len(batch_data)):
logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
batch_index.append(batch_data[i])
if (i+1) % args.batch_size == 0:
batch_data_split.append(batch_index.copy())
batch_index.clear()
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")
logger.info(f"batch_list:")
for i in range(len(batch_list[0])):
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}")
prev_image = process_batch(batch_list[0], highres_fix)[0]
batch_data.clear()
distributed_state.wait_for_everyone()
for i in range(len(data_loader)):
logger.info(f"Loading Batch {i+1} of {len(data_loader)}")
batch_data_split.append(data_loader[i])
if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader)
continue
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")
logger.info(f"batch_list:")
for i in range(len(batch_list[0])):
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}")
prev_image = process_batch(batch_list[0], highres_fix)[0]
batch_data_split.clear()
distributed_state.wait_for_everyone()
logger.info("done!")