mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add multiplier, steps range
This commit is contained in:
@@ -33,7 +33,7 @@ TRANSFORMER_MAX_BLOCK_INDEX = None
|
|||||||
|
|
||||||
|
|
||||||
class LLLiteModule(torch.nn.Module):
|
class LLLiteModule(torch.nn.Module):
|
||||||
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None):
|
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||||
@@ -41,6 +41,7 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
self.cond_emb_dim = cond_emb_dim
|
self.cond_emb_dim = cond_emb_dim
|
||||||
self.org_module = [org_module]
|
self.org_module = [org_module]
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
self.multiplier = multiplier
|
||||||
|
|
||||||
if self.is_conv2d:
|
if self.is_conv2d:
|
||||||
in_dim = org_module.in_channels
|
in_dim = org_module.in_channels
|
||||||
@@ -119,6 +120,10 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||||
"""
|
"""
|
||||||
|
if cond_image is None:
|
||||||
|
self.cond_emb = None
|
||||||
|
return
|
||||||
|
|
||||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||||
cx = self.conditioning1(cond_image)
|
cx = self.conditioning1(cond_image)
|
||||||
@@ -141,6 +146,9 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
学習用の便利forward。元のモジュールのforwardを呼び出す
|
学習用の便利forward。元のモジュールのforwardを呼び出す
|
||||||
/ convenient forward for training. call the forward of the original module
|
/ convenient forward for training. call the forward of the original module
|
||||||
"""
|
"""
|
||||||
|
if self.multiplier == 0.0 or self.cond_emb is None:
|
||||||
|
return self.org_forward(x)
|
||||||
|
|
||||||
cx = self.cond_emb
|
cx = self.cond_emb
|
||||||
|
|
||||||
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
||||||
@@ -160,11 +168,13 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
if self.dropout is not None and self.training:
|
if self.dropout is not None and self.training:
|
||||||
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
||||||
|
|
||||||
cx = self.up(cx)
|
cx = self.up(cx) * self.multiplier
|
||||||
|
|
||||||
# residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
||||||
if self.batch_cond_only:
|
if self.batch_cond_only:
|
||||||
cx = torch.zeros_like(x)[1::2] + cx
|
zx = torch.zeros_like(x)
|
||||||
|
zx[1::2] += cx
|
||||||
|
cx = zx
|
||||||
|
|
||||||
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||||
return x
|
return x
|
||||||
@@ -181,6 +191,7 @@ class ControlNetLLLite(torch.nn.Module):
|
|||||||
mlp_dim: int = 16,
|
mlp_dim: int = 16,
|
||||||
dropout: Optional[float] = None,
|
dropout: Optional[float] = None,
|
||||||
varbose: Optional[bool] = False,
|
varbose: Optional[bool] = False,
|
||||||
|
multiplier: Optional[float] = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# self.unets = [unet]
|
# self.unets = [unet]
|
||||||
@@ -264,6 +275,7 @@ class ControlNetLLLite(torch.nn.Module):
|
|||||||
child_module,
|
child_module,
|
||||||
mlp_dim,
|
mlp_dim,
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
|
multiplier=multiplier,
|
||||||
)
|
)
|
||||||
modules.append(module)
|
modules.append(module)
|
||||||
return modules
|
return modules
|
||||||
@@ -291,6 +303,10 @@ class ControlNetLLLite(torch.nn.Module):
|
|||||||
for module in self.unet_modules:
|
for module in self.unet_modules:
|
||||||
module.set_batch_cond_only(cond_only, zeros)
|
module.set_batch_cond_only(cond_only, zeros)
|
||||||
|
|
||||||
|
def set_multiplier(self, multiplier):
|
||||||
|
for module in self.unet_modules:
|
||||||
|
module.multiplier = multiplier
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|||||||
@@ -661,21 +661,28 @@ class PipelineLike:
|
|||||||
if self.control_nets:
|
if self.control_nets:
|
||||||
# guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
# guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
||||||
if self.control_net_enabled:
|
if self.control_net_enabled:
|
||||||
for control_net in self.control_nets:
|
for control_net, _ in self.control_nets:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
control_net.set_cond_image(clip_guide_images)
|
control_net.set_cond_image(clip_guide_images)
|
||||||
else:
|
else:
|
||||||
for control_net in self.control_nets:
|
for control_net, _ in self.control_nets:
|
||||||
control_net.set_cond_image(None)
|
control_net.set_cond_image(None)
|
||||||
|
|
||||||
|
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
# # disable control net if ratio is set
|
# disable control net if ratio is set
|
||||||
# if self.control_nets and self.control_net_enabled:
|
if self.control_nets and self.control_net_enabled:
|
||||||
# pass # TODO
|
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
|
||||||
|
if not enabled or ratio >= 1.0:
|
||||||
|
continue
|
||||||
|
if ratio < i / len(timesteps):
|
||||||
|
print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
|
||||||
|
control_net.set_cond_image(None)
|
||||||
|
each_control_net_enabled[j] = False
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
# TODO Diffusers' ControlNet
|
# TODO Diffusers' ControlNet
|
||||||
@@ -1567,7 +1574,7 @@ def main(args):
|
|||||||
upscaler.to(dtype).to(device)
|
upscaler.to(dtype).to(device)
|
||||||
|
|
||||||
# ControlNetの処理
|
# ControlNetの処理
|
||||||
control_nets: List[ControlNetLLLite] = []
|
control_nets: List[Tuple[ControlNetLLLite, float]] = []
|
||||||
# if args.control_net_models:
|
# if args.control_net_models:
|
||||||
# for i, model in enumerate(args.control_net_models):
|
# for i, model in enumerate(args.control_net_models):
|
||||||
# prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
# prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
||||||
@@ -1595,12 +1602,19 @@ def main(args):
|
|||||||
break
|
break
|
||||||
assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}"
|
assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}"
|
||||||
|
|
||||||
control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim)
|
multiplier = (
|
||||||
|
1.0
|
||||||
|
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
|
||||||
|
else args.control_net_multipliers[i]
|
||||||
|
)
|
||||||
|
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
||||||
|
|
||||||
|
control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier)
|
||||||
control_net.apply_to()
|
control_net.apply_to()
|
||||||
control_net.load_state_dict(state_dict)
|
control_net.load_state_dict(state_dict)
|
||||||
control_net.to(dtype).to(device)
|
control_net.to(dtype).to(device)
|
||||||
control_net.set_batch_cond_only(False, False)
|
control_net.set_batch_cond_only(False, False)
|
||||||
control_nets.append(control_net)
|
control_nets.append((control_net, ratio))
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
print(f"set optimizing: channels last")
|
print(f"set optimizing: channels last")
|
||||||
@@ -2623,14 +2637,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
# parser.add_argument(
|
# parser.add_argument(
|
||||||
# "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
|
# "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
|
||||||
# )
|
# )
|
||||||
# parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率")
|
parser.add_argument(
|
||||||
# parser.add_argument(
|
"--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率"
|
||||||
# "--control_net_ratios",
|
)
|
||||||
# type=float,
|
parser.add_argument(
|
||||||
# default=None,
|
"--control_net_ratios",
|
||||||
# nargs="*",
|
type=float,
|
||||||
# help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
default=None,
|
||||||
# )
|
nargs="*",
|
||||||
|
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||||
|
)
|
||||||
# # parser.add_argument(
|
# # parser.add_argument(
|
||||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||||
# )
|
# )
|
||||||
|
|||||||
Reference in New Issue
Block a user