diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py
index 7948efaa..6d7ad988 100644
--- a/accel_sdxl_gen_img.py
+++ b/accel_sdxl_gen_img.py
@@ -2816,12 +2816,22 @@ def main(args):
                     ),
                 )
                 if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
-                    process_batch(batch_data, highres_fix)
+                    batch_data_split = np.array_split(batch_data, distributed_state.num_processes)  # one chunk per process; NOTE(review): chunks are ndarrays - confirm process_batch accepts them
+                    with torch.no_grad():
+                        with distributed_state.split_between_processes(batch_data_split) as batch_list:  # this process's share
+                            logger.info(f"Loading batch of {len(batch_list)} onto device {distributed_state.device}")
+                            process_batch(batch_list, highres_fix)  # result discarded, as before this change
+                    distributed_state.wait_for_everyone()  # sync all processes before clearing the batch
                     batch_data.clear()
 
                 batch_data.append(b1)
-                if len(batch_data) == args.batch_size:
-                    prev_image = process_batch(batch_data, highres_fix)[0]
+                if len(batch_data) == args.batch_size * distributed_state.num_processes:  # enough for one full batch per process
+                    batch_data_split = np.array_split(batch_data, distributed_state.num_processes)  # one chunk per process
+                    with torch.no_grad():
+                        with distributed_state.split_between_processes(batch_data_split) as batch_list:  # this process's share
+                            logger.info(f"Loading batch of {len(batch_list)} onto device {distributed_state.device}")
+                            prev_image = process_batch(batch_list, highres_fix)[0]
+                    distributed_state.wait_for_everyone()  # sync all processes before clearing the batch
                     batch_data.clear()
 
                 global_step += 1
@@ -2829,7 +2839,12 @@ def main(args):
             prompt_index += 1
 
         if len(batch_data) > 0:
-            process_batch(batch_data, highres_fix)
+            batch_data_split = np.array_split(batch_data, distributed_state.num_processes)  # one chunk per process; may be empty when fewer items than processes
+            with torch.no_grad():
+                with distributed_state.split_between_processes(batch_data_split) as batch_list:  # this process's share
+                    logger.info(f"Loading batch of {len(batch_list)} onto device {distributed_state.device}")
+                    process_batch(batch_list, highres_fix)  # result discarded, as before this change
+            distributed_state.wait_for_everyone()  # sync all processes before clearing the batch
             batch_data.clear()
 
     logger.info("done!")