mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add multiplier, steps range
This commit is contained in:
@@ -661,21 +661,28 @@ class PipelineLike:
|
||||
if self.control_nets:
|
||||
# guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
||||
if self.control_net_enabled:
|
||||
for control_net in self.control_nets:
|
||||
for control_net, _ in self.control_nets:
|
||||
with torch.no_grad():
|
||||
control_net.set_cond_image(clip_guide_images)
|
||||
else:
|
||||
for control_net in self.control_nets:
|
||||
for control_net, _ in self.control_nets:
|
||||
control_net.set_cond_image(None)
|
||||
|
||||
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# # disable control net if ratio is set
|
||||
# if self.control_nets and self.control_net_enabled:
|
||||
# pass # TODO
|
||||
# disable control net if ratio is set
|
||||
if self.control_nets and self.control_net_enabled:
|
||||
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
|
||||
if not enabled or ratio >= 1.0:
|
||||
continue
|
||||
if ratio < i / len(timesteps):
|
||||
print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
|
||||
control_net.set_cond_image(None)
|
||||
each_control_net_enabled[j] = False
|
||||
|
||||
# predict the noise residual
|
||||
# TODO Diffusers' ControlNet
|
||||
@@ -1567,7 +1574,7 @@ def main(args):
|
||||
upscaler.to(dtype).to(device)
|
||||
|
||||
# ControlNetの処理
|
||||
control_nets: List[ControlNetLLLite] = []
|
||||
control_nets: List[Tuple[ControlNetLLLite, float]] = []
|
||||
# if args.control_net_models:
|
||||
# for i, model in enumerate(args.control_net_models):
|
||||
# prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
||||
@@ -1595,12 +1602,19 @@ def main(args):
|
||||
break
|
||||
assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}"
|
||||
|
||||
control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim)
|
||||
multiplier = (
|
||||
1.0
|
||||
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
|
||||
else args.control_net_multipliers[i]
|
||||
)
|
||||
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
||||
|
||||
control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier)
|
||||
control_net.apply_to()
|
||||
control_net.load_state_dict(state_dict)
|
||||
control_net.to(dtype).to(device)
|
||||
control_net.set_batch_cond_only(False, False)
|
||||
control_nets.append(control_net)
|
||||
control_nets.append((control_net, ratio))
|
||||
|
||||
if args.opt_channels_last:
|
||||
print(f"set optimizing: channels last")
|
||||
@@ -2623,14 +2637,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
# parser.add_argument(
|
||||
# "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
|
||||
# )
|
||||
# parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率")
|
||||
# parser.add_argument(
|
||||
# "--control_net_ratios",
|
||||
# type=float,
|
||||
# default=None,
|
||||
# nargs="*",
|
||||
# help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control_net_ratios",
|
||||
type=float,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||
)
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user