diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index de636c6e..a53b992b 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2843,9 +2843,9 @@ def main(args): logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - distributed_state.wait_for_everyone() + prev_image = process_batch(batch_list[0], highres_fix)[0] batch_data.clear() + distributed_state.wait_for_everyone() batch_data.append(b1) if len(batch_data) == args.batch_size*distributed_state.num_processes: @@ -2876,8 +2876,9 @@ def main(args): for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - distributed_state.wait_for_everyone() batch_data.clear() + distributed_state.wait_for_everyone() + global_step += 1 @@ -2899,8 +2900,9 @@ def main(args): for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - distributed_state.wait_for_everyone() batch_data.clear() + distributed_state.wait_for_everyone() + logger.info("done!")