From 55179ad8094617d60a4f3810c41ff1a3e5fcdbf3 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 16:29:58 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 79663609..260a7ffe 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,12 +2916,13 @@ def main(args): if len(ext_separated_list_of_batches) > 0: for batch_list in ext_separated_list_of_batches: - with distributed_state.split_between_processes(batch_list) as batches: - for j in range(len(batches)): - logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") - for i in range(len(batches[j])): - logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") - prev_image = process_batch(batch_list[j], highres_fix)[0] + with torch.no_grad(): + with distributed_state.split_between_processes(batch_list) as batches: + for j in range(len(batches)): + logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") + for i in range(len(batches[j])): + logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") + prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)):