diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 632707cb..72425826 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2831,12 +2831,12 @@ def main(args): batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] for i in range(len(batch_data)): - if distributed_state.is_main_process: - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") + batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + batch_data_split.append(batch_index.copy()) + batch_index.clear() + with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:")