From 63508160911936b97a356b09c855c05583fbf51b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:10:53 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 67 ++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 20122cef..b30edc11 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2420,9 +2420,6 @@ def main(args): else: fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - #logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") - if negative_prompt is not None: - # logger.info(f"Negative Prompt: {negative_prompt}\n") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: @@ -2443,7 +2440,7 @@ def main(args): # 画像生成のプロンプトが一周するまでのループ prompt_index = 0 global_step = 0 - batch_data = [] + prompt_data_list = [] extinfo = [] while args.interactive or (prompt_index < len(prompt_list) and distributed_state.is_main_process): if len(prompt_list) == 0: @@ -2867,26 +2864,24 @@ def main(args): num_sub_prompts, ), ) - batch_data.append(b1) + prompt_data_list.append(b1) extinfo.append(b1.ext) global_step += 1 prompt_index += 1 - batch_data = gather_object(batch_data) + prompt_data_list = gather_object(prompt_data_list) extinfo = gather_object(extinfo) - logger.info(f"batch_data line 2878: {len(batch_data)}") - batch_separated_list = [] - logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}") - if len(batch_data) > 0 and distributed_state.is_main_process: + + ext_separated_list_of_batches = [] + if len(prompt_data_list) > 0 and distributed_state.is_main_process: unique_extinfo = list(set(extinfo)) - logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}") + logger.info(f"Device {distributed_state.device}, prompt_data_list line 2880: {len(prompt_data_list)}, len(unique_extinfo): {len(unique_extinfo)}") # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): - logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}") templist = [] - res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] + res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: - templist.append(batch_data[index]) + templist.append(prompt_data_list[index]) split_into_batches = get_batches(items=templist, batch_size=args.batch_size) if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch @@ -2899,28 +2894,34 @@ def main(args): split_into_batches = [] n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - batch_separated_list.append(split_into_batches) - logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") + ext_separated_list_of_batches.append(split_into_batches) if distributed_state.num_processes > 1: - logger.info(f"batch_separated_list: {len(batch_separated_list)}") - temp_list = [] - for ext_batch in batch_separated_list: + + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] for i in range(distributed_state.num_processes): - temp_list.append(ext_batch[i :: distributed_state.num_processes]) - logger.info(f"templist: {len(temp_list)}") - batch_separated_list = [] - for sub_batch_list in temp_list: - batch_separated_list.append(sub_batch_list) - logger.info(f"batch_separated_list: {len(batch_separated_list)}") - logger.info(f"sub_batch_list: {len(sub_batch_list)}") + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + ext_separated_list_of_batches[x] = [] + for batches in temp_list: + ext_separated_list_of_batches[x].extend(batches) + logger.info(f"templist: {len(temp_list)}, ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") + + logger.info(f"ext_separated_list_of_batches: {len(ext_separated_list_of_batches)}") + count_prompts = 0 + for sub_batch_list in ext_separated_list_of_batches: + logger.info(f" sub_batch_list: {len(sub_batch_list)}") + for batches in sub_batch_list: + logger.info(f" batches: {len(batches)}") + count_prompts += len(batches) + logger.info(f"count_prompts: {count_prompts}") + distributed_state.wait_for_everyone() - batch_separated_list = gather_object(batch_separated_list) - logger.info(f"batch_data line 2911: {len(batch_data)}") + ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo - if len(batch_separated_list) > 0: - for batch_list in batch_separated_list: + if len(ext_separated_list_of_batches) > 0: + for batch_list in ext_separated_list_of_batches: with distributed_state.split_between_processes(batch_list) as batches: for j in range(len(batches)): logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") @@ -2932,17 +2933,17 @@ def main(args): distributed_state.wait_for_everyone() #for i in range(len(data_loader)): # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") - # batch_data_split.append(data_loader[i]) + # prompt_data_list_split.append(data_loader[i]) # if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): # continue # with torch.no_grad(): - # with distributed_state.split_between_processes(batch_data_split) as batch_list: + # with distributed_state.split_between_processes(prompt_data_list_split) as batch_list: # logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") # logger.info(f"batch_list:") # for i in range(len(batch_list[0])): # logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") # prev_image = process_batch(batch_list[0], highres_fix)[0] - # batch_data_split.clear() + # prompt_data_list_split.clear() # distributed_state.wait_for_everyone() logger.info("done!")