diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 9f437606..aa3c5965 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2445,7 +2445,7 @@ def main(args): global_step = 0 batch_data = [] extinfo = [] - while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): + while args.interactive or (prompt_index < len(prompt_list) and distributed_state.is_main_process): if len(prompt_list) == 0: # interactive valid = False @@ -2873,6 +2873,7 @@ def main(args): prompt_index += 1 batch_data = gather_object(batch_data) + extinfo = gather_object(extinfo) logger.info(f"batch_data line 2878: {len(batch_data)}") batch_separated_list = [] logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}")