mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add multiplier, steps range
This commit is contained in:
@@ -33,7 +33,7 @@ TRANSFORMER_MAX_BLOCK_INDEX = None
|
||||
|
||||
|
||||
class LLLiteModule(torch.nn.Module):
|
||||
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None):
|
||||
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
@@ -41,6 +41,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
self.org_module = [org_module]
|
||||
self.dropout = dropout
|
||||
self.multiplier = multiplier
|
||||
|
||||
if self.is_conv2d:
|
||||
in_dim = org_module.in_channels
|
||||
@@ -119,6 +120,10 @@ class LLLiteModule(torch.nn.Module):
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
if cond_image is None:
|
||||
self.cond_emb = None
|
||||
return
|
||||
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||
cx = self.conditioning1(cond_image)
|
||||
@@ -141,6 +146,9 @@ class LLLiteModule(torch.nn.Module):
|
||||
学習用の便利forward。元のモジュールのforwardを呼び出す
|
||||
/ convenient forward for training. call the forward of the original module
|
||||
"""
|
||||
if self.multiplier == 0.0 or self.cond_emb is None:
|
||||
return self.org_forward(x)
|
||||
|
||||
cx = self.cond_emb
|
||||
|
||||
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
||||
@@ -160,11 +168,13 @@ class LLLiteModule(torch.nn.Module):
|
||||
if self.dropout is not None and self.training:
|
||||
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
||||
|
||||
cx = self.up(cx)
|
||||
cx = self.up(cx) * self.multiplier
|
||||
|
||||
# residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
||||
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
||||
if self.batch_cond_only:
|
||||
cx = torch.zeros_like(x)[1::2] + cx
|
||||
zx = torch.zeros_like(x)
|
||||
zx[1::2] += cx
|
||||
cx = zx
|
||||
|
||||
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||
return x
|
||||
@@ -181,6 +191,7 @@ class ControlNetLLLite(torch.nn.Module):
|
||||
mlp_dim: int = 16,
|
||||
dropout: Optional[float] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
multiplier: Optional[float] = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# self.unets = [unet]
|
||||
@@ -264,6 +275,7 @@ class ControlNetLLLite(torch.nn.Module):
|
||||
child_module,
|
||||
mlp_dim,
|
||||
dropout=dropout,
|
||||
multiplier=multiplier,
|
||||
)
|
||||
modules.append(module)
|
||||
return modules
|
||||
@@ -291,6 +303,10 @@ class ControlNetLLLite(torch.nn.Module):
|
||||
for module in self.unet_modules:
|
||||
module.set_batch_cond_only(cond_only, zeros)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for module in self.unet_modules:
|
||||
module.multiplier = multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
Reference in New Issue
Block a user