Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-02-01 01:28:21 +08:00
committed by GitHub
parent 2cdaa33147
commit d179a48696

View File

@@ -1846,6 +1846,8 @@ def main(args):
pipe.set_control_nets(control_nets) pipe.set_control_nets(control_nets)
logger.info(f"pipeline on {device} is ready.") logger.info(f"pipeline on {device} is ready.")
distributed_state.wait_for_everyone() distributed_state.wait_for_everyone()
pipes = gather_objects([pipe])
unets = gather_objects([unet])
if args.diffusers_xformers: if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
@@ -2498,7 +2500,8 @@ def main(args):
negative_scale = args.negative_scale negative_scale = args.negative_scale
steps = args.steps steps = args.steps
seed = None seed = None
seeds = None if pi == 0:
seeds = None
strength = 0.8 if args.strength is None else args.strength strength = 0.8 if args.strength is None else args.strength
negative_prompt = "" negative_prompt = ""
clip_prompt = None clip_prompt = None
@@ -2578,7 +2581,11 @@ def main(args):
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed if m: # seed
seeds = [int(d) for d in m.group(1).split(",")] if pi > 0 and len(raw_prompts) > 1: #Bypass on 2nd loop for dynamic prompts
continue
logger.info(f"{m}")
seeds = m.group(1).split(",")
seeds = [int(float(d.strip())) for d in seeds]
logger.info(f"seeds: {seeds}") logger.info(f"seeds: {seeds}")
continue continue
@@ -2744,7 +2751,8 @@ def main(args):
if ds_depth_1 is not None: if ds_depth_1 is not None:
if ds_depth_1 < 0: if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3 ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) for unet in unets:
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# override Gradual Latent # override Gradual Latent
if gl_timesteps is not None: if gl_timesteps is not None:
@@ -2768,7 +2776,8 @@ def main(args):
us_strength, us_strength,
us_target_x, us_target_x,
) )
pipe.set_gradual_latent(gradual_latent) for pipe in pipes:
pipe.set_gradual_latent(gradual_latent)
# prepare seed # prepare seed
if seeds is not None: # given in prompt if seeds is not None: # given in prompt
@@ -2776,8 +2785,12 @@ def main(args):
if len(seeds) > 1: if len(seeds) > 1:
seed = seeds.pop(0) seed = seeds.pop(0)
elif len(seeds) == 1: elif len(seeds) == 1:
seed = seeds.pop(0) if seeds[0] == -1:
seeds = None seeds = None
else:
seed = seeds.pop(0)
else: else:
if predefined_seeds is not None: if predefined_seeds is not None:
if len(predefined_seeds) > 0: if len(predefined_seeds) > 0:
@@ -2883,11 +2896,11 @@ def main(args):
batch_separated_list.append(split_into_batches) batch_separated_list.append(split_into_batches)
if distributed_state.num_processes > 1: if distributed_state.num_processes > 1:
templist = [] templist = []
for i in range(distributed_state.num_processes): for i in range(distributed_state.num_processes):
templist.append(batch_separated_list[i :: distributed_state.num_processes]) templist.append(batch_separated_list[i :: distributed_state.num_processes])
batch_separated_list = [] batch_separated_list = []
for sub_batch_list in templist: for sub_batch_list in templist:
batch_separated_list.extend(sub_batch_list) batch_separated_list.extend(sub_batch_list)
distributed_state.wait_for_everyone() distributed_state.wait_for_everyone()
batch_data = gather_object(batch_separated_list) batch_data = gather_object(batch_separated_list)
del extinfo del extinfo