diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bbe55f33..98d33413 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2454,10 +2454,6 @@ def main(args): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"Total raw prompts: {len(raw_prompts)} for {distributed_state.local_process_index}") # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): @@ -2846,14 +2842,9 @@ def main(args): prompt_index += 1 distributed_state.wait_for_everyone() batch_data = gather_object(batch_data) - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"Total prompts: {len(batch_data)} for {distributed_state.local_process_index}") + if len(batch_data) > 0: data_loader = get_batches(items=batch_data, batch_size=args.batch_size) - logger.info(f"Total batches: {len(data_loader)}") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] with distributed_state.split_between_processes(data_loader) as batch_list: for j in range(len(batch_list)): logger.info(f"Loading batch {j}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:")