Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-18 01:30:02 +00:00)

Commit: Update accel_sdxl_gen_img.py
@@ -1513,13 +1512,12 @@ def main(args):
     logger.info(f"preferred device: {device}, {distributed_state.is_main_process}")
     clean_memory_on_device(device)
     model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
-    for pi in range(distributed_state.num_processes):
-        if pi == distributed_state.local_process_index:
-            logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}")
-            (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
-                args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
-            )
-            unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
+    logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}")
+    (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
+        args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
+    )
+    unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
+
     distributed_state.wait_for_everyone()

     # xformers、Hypernetwork対応
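Note: the removed lines implemented staggered model loading, where each rank loads the checkpoint only on its turn behind a barrier so that at most one process reads it at a time; the replacement loads the model directly on every process. A minimal sketch of the staggered pattern, assuming distributed_state is an accelerate.PartialState (as the attributes used above suggest) and with load_model() standing in for the real loader:

from accelerate import PartialState

distributed_state = PartialState()

def load_model_staggered(load_model):
    # Each rank loads only when pi matches its index; everyone waits at the
    # barrier in between, so the checkpoint is read by one process at a time.
    model = None
    for pi in range(distributed_state.num_processes):
        if pi == distributed_state.local_process_index:
            model = load_model()
        distributed_state.wait_for_everyone()
    return model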
@@ -2897,8 +2896,14 @@ def main(args):
                 res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]]
                 for index in res:
                     templist.append(prompt_data_list[index])
+            if distributed_state.num_processes > 1:
+                resorted_list = []
+                for i in range(distributed_state.num_processes):
+                    resorted_list.append(templist[i :: distributed_state.num_processes])
+                for list in resorted_list:
+                    templist.extend(list)
             split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy()
-            '''
+
             sublist = []
             if(len(split_into_batches) % distributed_state.num_processes != 0):
                 #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
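Note: the added if-block builds one sublist per process with the stride templist[i::num_processes] and appends those sublists back onto templist, grouping items that share the same round-robin offset next to each other. A minimal sketch of that round-robin regrouping on a plain list (the helper name and example values are illustrative, not from the script):

def regroup_round_robin(items, num_processes):
    # Collect every num_processes-th item starting at offset i, then flatten,
    # so positions i, i+N, i+2N, ... end up contiguous for each offset i.
    groups = [items[i::num_processes] for i in range(num_processes)]
    flat = []
    for group in groups:
        flat.extend(group)
    return flat

# regroup_round_robin(list(range(8)), 3) -> [0, 3, 6, 1, 4, 7, 2, 5]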
@@ -2917,7 +2922,7 @@ def main(args):
                 sublist.reverse()
                 n, m = divmod(len(sublist), distributed_state.num_processes)
                 split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
-            '''
+
             ext_separated_list_of_batches.append(split_into_batches)
             '''
             logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")
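Note: the slicing kept as context here, sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] with n, m = divmod(len(sublist), num_processes), splits a list into num_processes nearly equal chunks, where the first m chunks get one extra item. A minimal sketch of that split in isolation (helper name and example values are illustrative):

def split_nearly_equal(items, num_parts):
    # The first (len(items) % num_parts) chunks receive one extra element,
    # so chunk sizes differ by at most one.
    n, m = divmod(len(items), num_parts)
    return [items[i * n + min(i, m):(i + 1) * n + min(i + 1, m)] for i in range(num_parts)]

# split_nearly_equal(list(range(7)), 3) -> [[0, 1, 2], [3, 4], [5, 6]]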