diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 2885f78e..db093707 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2487,231 +2487,231 @@ def main(args): logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: - if distributed_state.is_main_process: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2827,12 +2827,11 @@ def main(args): if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - logger.info(f"{batch_data}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] for i in range(len(batch_data)): if distributed_state.is_main_process: - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index.copy()) @@ -2857,7 +2856,7 @@ def main(args): test_batch_data_split = [] test_batch_index = [] for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}}") batch_index.append(batch_data[i]) test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: @@ -2891,7 +2890,7 @@ def main(args): batch_index = [] if distributed_state.is_main_process: for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index.copy()) @@ -2900,7 +2899,7 @@ def main(args): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list: {batch_list}") + logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0]