diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py
index 0765bab1..5d5cead2 100644
--- a/accel_sdxl_gen_img.py
+++ b/accel_sdxl_gen_img.py
@@ -1501,10 +1501,13 @@ def main(args):
     logger.info(f"preferred device: {device}")
     clean_memory_on_device(device)
     model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
-    (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
-        args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
-    )
-    unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
+    for pi in range(distributed_state.num_processes):
+        if pi == distributed_state.local_process_index:
+            logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}")
+            (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
+                args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
+            )
+            unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
     distributed_state.wait_for_everyone()
 
     # xformers、Hypernetwork対応
@@ -2443,9 +2446,7 @@ def main(args):
 
             # repeat prompt
             for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
-                if distributed_state.is_main_process:
-                    raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
-
+                raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
                 if pi == 0 or len(raw_prompts) > 1:
                     # parse prompt: if prompt is not changed, skip parsing
                     width = args.W
@@ -2555,7 +2556,7 @@ def main(args):
                         logger.info(f"scale: {scale}")
                         continue
 
-                    m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
+                    m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
                     if m:  # negative scale
                         if m.group(1).lower() == "none":
                             negative_scale = None
@@ -2844,31 +2845,29 @@ def main(args):
                         logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}")
                     prev_image = process_batch(batch_list[0], highres_fix)[0]
             distributed_state.wait_for_everyone()
-            if distributed_state.is_main_process:
-                batch_data.clear()
-            if distributed_state.is_main_process:
-                batch_data.append(b1)
+            batch_data.clear()
+
+            batch_data.append(b1)
 
             if len(batch_data) == args.batch_size*distributed_state.num_processes:
-                if distributed_state.is_main_process:
-                    logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.")
-                    batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
-                    batch_index = []
-                    test_batch_data_split = []
-                    test_batch_index = []
-                    for i in range(len(batch_data)):
-                        logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
-                        batch_index.append(batch_data[i])
-                        if (i+1) % args.batch_size == 0:
-                            batch_data_split.append(batch_index.copy())
-                            batch_index.clear()
-                        if (i+1) % 4 == 0:
-                            test_batch_data_split.append(test_batch_index.copy())
-                            test_batch_index.clear()
-                    for i in range(len(test_batch_data_split)):
-                        logger.info(f"test_batch_data_split[{i}]:")
-                        for j in range(len(test_batch_data_split[i])):
-                            logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}")
+                logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.")
+                batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
+                batch_index = []
+                test_batch_data_split = []
+                test_batch_index = []
+                for i in range(len(batch_data)):
+                    logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
+                    batch_index.append(batch_data[i])
+                    if (i+1) % args.batch_size == 0:
+                        batch_data_split.append(batch_index.copy())
+                        batch_index.clear()
+                    if (i+1) % 4 == 0:
+                        test_batch_data_split.append(test_batch_index.copy())
+                        test_batch_index.clear()
+                for i in range(len(test_batch_data_split)):
+                    logger.info(f"test_batch_data_split[{i}]:")
+                    for j in range(len(test_batch_data_split[i])):
+                        logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}")
                 with torch.no_grad():
                     with distributed_state.split_between_processes(batch_data_split) as batch_list:
                         logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")
@@ -2878,8 +2877,7 @@ def main(args):
                             logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}")
                         prev_image = process_batch(batch_list[0], highres_fix)[0]
                 distributed_state.wait_for_everyone()
-                if distributed_state.is_main_process:
-                    batch_data.clear()
+                batch_data.clear()
 
         global_step += 1
 
@@ -2888,14 +2886,12 @@ def main(args):
     if len(batch_data) > 0:
         batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
         batch_index = []
-        if distributed_state.is_main_process:
-            for i in range(len(batch_data)):
-                logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
-                batch_index.append(batch_data[i])
-                if (i+1) % args.batch_size == 0:
-                    batch_data_split.append(batch_index.copy())
-                    batch_index.clear()
-            logger.info(f"{batch_data_split}")
+        for i in range(len(batch_data)):
+            logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}")
+            batch_index.append(batch_data[i])
+            if (i+1) % args.batch_size == 0:
+                batch_data_split.append(batch_index.copy())
+                batch_index.clear()
         with torch.no_grad():
             with distributed_state.split_between_processes(batch_data_split) as batch_list:
                 logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")
@@ -2904,8 +2900,7 @@ def main(args):
                     logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}")
                 prev_image = process_batch(batch_list[0], highres_fix)[0]
         distributed_state.wait_for_everyone()
-        if distributed_state.is_main_process:
-            batch_data.clear()
+        batch_data.clear()
 
     logger.info("done!")
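
The first hunk gates model loading on `local_process_index`, apparently so each rank logs and loads its own copy before the shared `wait_for_everyone` barrier. A minimal standalone sketch of that pattern is below; `load_checkpoint` is a hypothetical stand-in for `sdxl_train_util._load_target_model`, and the per-iteration barrier shown here is the fully serialized variant (the patch instead relies on the single `wait_for_everyone` after the loop):

```python
# Sketch of rank-staggered model loading, assuming Accelerate's PartialState.
# load_checkpoint() is a hypothetical stand-in for the script's
# sdxl_train_util._load_target_model call.
from accelerate import PartialState

state = PartialState()

def load_checkpoint():
    return object()  # placeholder for the real (text encoders, VAE, U-Net) tuple

model = None
for rank in range(state.num_processes):
    if rank == state.local_process_index:
        print(f"loading model for process {rank + 1}/{state.num_processes}")
        model = load_checkpoint()
    # Barrier each turn so only one rank reads the checkpoint at a time;
    # the patch instead places a single barrier after the whole loop.
    state.wait_for_everyone()
```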
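The batching hunks all feed `distributed_state.split_between_processes`, Accelerate's context manager that hands each process its own slice of a list. A minimal sketch of the chunk-then-split flow, with plain strings standing in for the script's batch objects and `batch_size` for `args.batch_size` (run with `accelerate launch`):

```python
# Sketch of the chunk-then-split pattern the patch uses; strings stand in
# for the script's batch entries.
from accelerate import PartialState

state = PartialState()
batch_size = 2  # stand-in for args.batch_size

# Accumulate batch_size * num_processes prompts, then chunk them into
# per-device batches, mirroring how batch_data_split is built above.
prompts = [f"prompt {i}" for i in range(batch_size * state.num_processes)]
batch_data_split = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]

# With one chunk per process, each rank's batch_list holds a single batch,
# matching the batch_list[0] indexing in the patch.
with state.split_between_processes(batch_data_split) as batch_list:
    print(f"device {state.device}: {batch_list[0]}")
state.wait_for_everyone()
```

Every rank has to build the same `batch_data_split` and enter `split_between_processes` together, which is why the patch drops the `is_main_process` guards around the batch bookkeeping.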