diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 87f377bd..0d50f20e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1513,13 +1513,12 @@ def main(args): logger.info(f"preferred device: {device}, {distributed_state.is_main_process}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype - ) - unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + + logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) distributed_state.wait_for_everyone() # xformers、Hypernetwork対応 @@ -2897,8 +2896,14 @@ def main(args): res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: templist.append(prompt_data_list[index]) + if distributed_state.num_processes > 1: + resorted_list = [] + for i in range(distributed_state.num_processes): + resorted_list.append(templist[i :: distributed_state.num_processes]) + for list in resorted_list: + templist.extend(list) split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() - ''' + sublist = [] if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch @@ -2917,7 +2922,7 @@ def main(args): sublist.reverse() n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - ''' + ext_separated_list_of_batches.append(split_into_batches) ''' logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")