mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Update accel_sdxl_gen_img.py
This commit is contained in:
@@ -1846,6 +1846,8 @@ def main(args):
|
|||||||
pipe.set_control_nets(control_nets)
|
pipe.set_control_nets(control_nets)
|
||||||
logger.info(f"pipeline on {device} is ready.")
|
logger.info(f"pipeline on {device} is ready.")
|
||||||
distributed_state.wait_for_everyone()
|
distributed_state.wait_for_everyone()
|
||||||
|
pipes = gather_objects([pipe])
|
||||||
|
unets = gather_objects([unet])
|
||||||
|
|
||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
@@ -2498,7 +2500,8 @@ def main(args):
|
|||||||
negative_scale = args.negative_scale
|
negative_scale = args.negative_scale
|
||||||
steps = args.steps
|
steps = args.steps
|
||||||
seed = None
|
seed = None
|
||||||
seeds = None
|
if pi == 0:
|
||||||
|
seeds = None
|
||||||
strength = 0.8 if args.strength is None else args.strength
|
strength = 0.8 if args.strength is None else args.strength
|
||||||
negative_prompt = ""
|
negative_prompt = ""
|
||||||
clip_prompt = None
|
clip_prompt = None
|
||||||
@@ -2578,7 +2581,11 @@ def main(args):
|
|||||||
|
|
||||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||||
if m: # seed
|
if m: # seed
|
||||||
seeds = [int(d) for d in m.group(1).split(",")]
|
if pi > 0 and len(raw_prompts) > 1: #Bypass on 2nd loop for dynamic prompts
|
||||||
|
continue
|
||||||
|
logger.info(f"{m}")
|
||||||
|
seeds = m.group(1).split(",")
|
||||||
|
seeds = [int(float(d.strip())) for d in seeds]
|
||||||
logger.info(f"seeds: {seeds}")
|
logger.info(f"seeds: {seeds}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -2744,7 +2751,8 @@ def main(args):
|
|||||||
if ds_depth_1 is not None:
|
if ds_depth_1 is not None:
|
||||||
if ds_depth_1 < 0:
|
if ds_depth_1 < 0:
|
||||||
ds_depth_1 = args.ds_depth_1 or 3
|
ds_depth_1 = args.ds_depth_1 or 3
|
||||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
for unet in unets:
|
||||||
|
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||||
|
|
||||||
# override Gradual Latent
|
# override Gradual Latent
|
||||||
if gl_timesteps is not None:
|
if gl_timesteps is not None:
|
||||||
@@ -2768,7 +2776,8 @@ def main(args):
|
|||||||
us_strength,
|
us_strength,
|
||||||
us_target_x,
|
us_target_x,
|
||||||
)
|
)
|
||||||
pipe.set_gradual_latent(gradual_latent)
|
for pipe in pipes:
|
||||||
|
pipe.set_gradual_latent(gradual_latent)
|
||||||
|
|
||||||
# prepare seed
|
# prepare seed
|
||||||
if seeds is not None: # given in prompt
|
if seeds is not None: # given in prompt
|
||||||
@@ -2776,8 +2785,12 @@ def main(args):
|
|||||||
if len(seeds) > 1:
|
if len(seeds) > 1:
|
||||||
seed = seeds.pop(0)
|
seed = seeds.pop(0)
|
||||||
elif len(seeds) == 1:
|
elif len(seeds) == 1:
|
||||||
seed = seeds.pop(0)
|
if seeds[0] == -1:
|
||||||
seeds = None
|
seeds = None
|
||||||
|
else:
|
||||||
|
seed = seeds.pop(0)
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if predefined_seeds is not None:
|
if predefined_seeds is not None:
|
||||||
if len(predefined_seeds) > 0:
|
if len(predefined_seeds) > 0:
|
||||||
@@ -2883,11 +2896,11 @@ def main(args):
|
|||||||
batch_separated_list.append(split_into_batches)
|
batch_separated_list.append(split_into_batches)
|
||||||
if distributed_state.num_processes > 1:
|
if distributed_state.num_processes > 1:
|
||||||
templist = []
|
templist = []
|
||||||
for i in range(distributed_state.num_processes):
|
for i in range(distributed_state.num_processes):
|
||||||
templist.append(batch_separated_list[i :: distributed_state.num_processes])
|
templist.append(batch_separated_list[i :: distributed_state.num_processes])
|
||||||
batch_separated_list = []
|
batch_separated_list = []
|
||||||
for sub_batch_list in templist:
|
for sub_batch_list in templist:
|
||||||
batch_separated_list.extend(sub_batch_list)
|
batch_separated_list.extend(sub_batch_list)
|
||||||
distributed_state.wait_for_everyone()
|
distributed_state.wait_for_everyone()
|
||||||
batch_data = gather_object(batch_separated_list)
|
batch_data = gather_object(batch_separated_list)
|
||||||
del extinfo
|
del extinfo
|
||||||
|
|||||||
Reference in New Issue
Block a user