Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-24 19:28:38 +08:00
committed by GitHub
parent 0770c7ba1b
commit e5cf6b6d06

View File

@@ -2816,12 +2816,22 @@ def main(args):
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
process_batch(batch_data, highres_fix)
batch_data_split = np.array_split(batch_data, distributed_state.num_processes)
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}")
prev_image = process_batch(batch_list, highres_fix)[0]
accelerator.wait_for_everyone()
batch_data.clear()
batch_data.append(b1)
if len(batch_data) == args.batch_size:
prev_image = process_batch(batch_data, highres_fix)[0]
if len(batch_data) == args.batch_size*distributed_state.num_processes:
batch_data_split = np.array_split(batch_data, distributed_state.num_processes)
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}")
prev_image = process_batch(batch_list, highres_fix)[0]
accelerator.wait_for_everyone()
batch_data.clear()
global_step += 1
@@ -2829,7 +2839,12 @@ def main(args):
prompt_index += 1
if len(batch_data) > 0:
process_batch(batch_data, highres_fix)
batch_data_split = np.array_split(batch_data, distributed_state.num_processes)
with torch.no_grad():
with distributed_state.split_between_processes(batch_data_split) as batch_list:
logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}")
prev_image = process_batch(batch_list, highres_fix)[0]
accelerator.wait_for_everyone()
batch_data.clear()
logger.info("done!")