From e5cf6b6d065385a6ff2bd4ab601243ad8c134d54 Mon Sep 17 00:00:00 2001
From: DKnight54 <126916963+DKnight54@users.noreply.github.com>
Date: Fri, 24 Jan 2025 19:28:38 +0800
Subject: [PATCH] Update accel_sdxl_gen_img.py

---
 accel_sdxl_gen_img.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py
index 7948efaa..6d7ad988 100644
--- a/accel_sdxl_gen_img.py
+++ b/accel_sdxl_gen_img.py
@@ -2816,12 +2816,22 @@ def main(args):
                         ),
                     )
                 if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
-                    process_batch(batch_data, highres_fix)
+                    batch_data_split = np.array_split(batch_data, distributed_state.num_processes)
+                    with torch.no_grad():
+                        with distributed_state.split_between_processes(batch_data_split) as batch_list:
+                            logger.info(f"Loading batch of {len(batch_list)} onto device {distributed_state.device}")
+                            prev_image = process_batch(batch_list, highres_fix)[0]
+                    accelerator.wait_for_everyone()
                     batch_data.clear()
 
                 batch_data.append(b1)
-                if len(batch_data) == args.batch_size:
-                    prev_image = process_batch(batch_data, highres_fix)[0]
+                if len(batch_data) == args.batch_size*distributed_state.num_processes:
+                    batch_data_split = np.array_split(batch_data, distributed_state.num_processes)
+                    with torch.no_grad():
+                        with distributed_state.split_between_processes(batch_data_split) as batch_list:
+                            logger.info(f"Loading batch of {len(batch_list)} onto device {distributed_state.device}")
+                            prev_image = process_batch(batch_list, highres_fix)[0]
+                    accelerator.wait_for_everyone()
                     batch_data.clear()
 
                 global_step += 1
@@ -2829,7 +2839,12 @@ def main(args):
             prompt_index += 1
 
     if len(batch_data) > 0:
-        process_batch(batch_data, highres_fix)
+        batch_data_split = np.array_split(batch_data, distributed_state.num_processes)
+        with torch.no_grad():
+            with distributed_state.split_between_processes(batch_data_split) as batch_list:
+                logger.info(f"Loading batch of {len(batch_list)} onto device {distributed_state.device}")
+                prev_image = process_batch(batch_list, highres_fix)[0]
+        accelerator.wait_for_everyone()
         batch_data.clear()
 
     logger.info("done!")