diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 79663609..260a7ffe 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,12 +2916,13 @@ def main(args): if len(ext_separated_list_of_batches) > 0: for batch_list in ext_separated_list_of_batches: - with distributed_state.split_between_processes(batch_list) as batches: - for j in range(len(batches)): - logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") - for i in range(len(batches[j])): - logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") - prev_image = process_batch(batch_list[j], highres_fix)[0] + with torch.no_grad(): + with distributed_state.split_between_processes(batch_list) as batches: + for j in range(len(batches)): + logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") + for i in range(len(batches[j])): + logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") + prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)):