diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0f2b3bd0..480f3dc6 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2907,6 +2907,7 @@ def main(args): resorted_list = [] for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) + templist = [] for list_of_prompts in resorted_list: templist.extend(list_of_prompts) @@ -2945,6 +2946,7 @@ def main(args): ext_separated_list_of_batches[x] = holder[:] ''' distributed_state.wait_for_everyone() + ''' if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" for x in range(len(ext_separated_list_of_batches)): @@ -2954,6 +2956,7 @@ def main(args): for z in range(len(ext_separated_list_of_batches[x][y])): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) + ''' ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") @@ -2967,19 +2970,21 @@ def main(args): batchlogstr += f"\nImage: {batch[i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batch[i].base.prompt}\nNegative Prompt: {batch[i].base.negative_prompt}\nSeed: {batch[i].base.seed}" logger.info(batchlogstr) coll_image, coll_metadata, coll_filename = process_batch(batch, distributed_state, highres_fix) - logger.info(f"coll_image: {len(coll_image)}") - logger.info(f"coll_metadata: {len(coll_metadata)}") - logger.info(f"coll_filename: {len(coll_filename)}") + #logger.info(f"coll_image: {len(coll_image)}") + #logger.info(f"coll_metadata: {len(coll_metadata)}") + #logger.info(f"coll_filename: {len(coll_filename)}") distributed_state.wait_for_everyone() all_images = gather_object(coll_image) all_metadatas = gather_object(coll_metadata) all_filenames = gather_object(coll_filename) prev_image = all_images[0] + ''' if distributed_state.is_main_process: for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): logger.info(f"Saving image: {filename}") image.save(os.path.join(args.outdir, filename), pnginfo=metadata) + ''' distributed_state.wait_for_everyone() #for i in range(len(data_loader)):