Mirror of https://github.com/kohya-ss/sd-scripts.git

Commit: add rank dropout/module dropout
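In brief, this commit adds two dropout variants to networks/lora.py alongside the existing neuron dropout: rank dropout zeroes a random subset of the lora_dim rank channels between lora_down and lora_up (one keep/drop decision per channel, per batch sample), and module dropout skips the entire LoRA delta with probability p, returning only the original module's output. LoRAInfModule deliberately discards all three options, so inference is unaffected. Short standalone sketches follow the forward() and create_modules() hunks below.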
@@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module):
     replaces forward method of the original Linear, instead of replacing the original Linear module.
     """
 
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
         self.lora_name = lora_name
@@ -61,6 +71,8 @@ class LoRAModule(torch.nn.Module):
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
         self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
 
     def apply_to(self):
         self.org_forward = self.org_module.forward
@@ -68,18 +80,45 @@ class LoRAModule(torch.nn.Module):
         del self.org_module
 
     def forward(self, x):
-        if self.dropout:
-            return (
-                self.org_forward(x)
-                + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
-            )
-        else:
-            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
+        if self.dropout:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * self.scale
 
 
 class LoRAInfModule(LoRAModule):
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
-        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
 
         self.org_module_ref = [org_module]  # 後から参照できるように (keep a reference for later)
         self.enabled = True
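To make the rank-dropout masking concrete, here is a minimal standalone sketch of just that step, outside the class; the sizes (batch 2, 77 tokens, rank 4, 8x8 feature maps) and p=0.5 are illustrative assumptions, not values from the commit. Note that, unlike torch.nn.functional.dropout, the surviving rank channels are not rescaled by 1/(1-p) here.

    # Standalone sketch of the rank-dropout mask (illustrative sizes, p=0.5)
    import torch

    lora_dim, rank_dropout = 4, 0.5

    for lx in (
        torch.ones(2, 77, lora_dim),    # Text Encoder path: (batch, tokens, rank)
        torch.ones(2, lora_dim, 8, 8),  # Conv2d path: (batch, rank, h, w)
    ):
        # one keep/drop decision per rank channel, per batch sample
        mask = torch.rand((lx.size(0), lora_dim), device=lx.device) > rank_dropout
        if len(lx.size()) == 3:
            mask = mask.unsqueeze(1)                 # (batch, 1, rank): broadcast over tokens
        elif len(lx.size()) == 4:
            mask = mask.unsqueeze(-1).unsqueeze(-1)  # (batch, rank, 1, 1): broadcast over h, w
        print(lx.size(), "->", (lx * mask).size())   # shape unchanged; dropped channels zeroed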
@@ -395,6 +434,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_block_dims = None
         conv_block_alphas = None
 
+    # rank/module dropout
+    rank_dropout = kwargs.get("rank_dropout", None)
+    if rank_dropout is not None:
+        rank_dropout = float(rank_dropout)
+    module_dropout = kwargs.get("module_dropout", None)
+    if module_dropout is not None:
+        module_dropout = float(module_dropout)
+
     # すごく引数が多いな ( ^ω^)・・・ (that sure is a lot of arguments...)
     network = LoRANetwork(
         text_encoder,
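The kwargs.get calls above are what make the new options reachable from the training command line: in sd-scripts, free-form key=value pairs given to the trainer's --network_args option are forwarded to create_network as strings, which is why both values are cast with float(). An illustrative (not authoritative) invocation: accelerate launch train_network.py --network_module=networks.lora --network_dim=16 --network_args "rank_dropout=0.25" "module_dropout=0.1".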
@@ -403,6 +450,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         lora_dim=network_dim,
         alpha=network_alpha,
         dropout=dropout,
+        rank_dropout=rank_dropout,
+        module_dropout=module_dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         block_dims=block_dims,
@@ -679,6 +728,8 @@ class LoRANetwork(torch.nn.Module):
         lora_dim=4,
         alpha=1,
         dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
         conv_lora_dim=None,
         conv_alpha=None,
         block_dims=None,
@@ -706,18 +757,22 @@ class LoRANetwork(torch.nn.Module):
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
         self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
         elif block_dims is not None:
-            print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
+            print(f"create LoRA network from block_dims")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             print(f"block_dims: {block_dims}")
             print(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 print(f"conv_block_dims: {conv_block_dims}")
                 print(f"conv_block_alphas: {conv_block_alphas}")
         else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             if self.conv_lora_dim is not None:
                 print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
@@ -764,7 +819,16 @@ class LoRANetwork(torch.nn.Module):
                                     skipped.append(lora_name)
                                 continue
 
-                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
+                            lora = module_class(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                                dropout=dropout,
+                                rank_dropout=rank_dropout,
+                                module_dropout=module_dropout,
+                            )
                             loras.append(lora)
             return loras, skipped
 
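To close, a self-contained sketch of the module-dropout path wired through a toy forward(): with probability module_dropout the LoRA branch is skipped entirely and only the wrapped module's output is returned. All names and sizes below are local to the sketch, not part of the commit.

    import torch

    module_dropout = 0.3
    org = torch.nn.Linear(8, 8)                    # stands in for the wrapped module
    lora_down = torch.nn.Linear(8, 4, bias=False)
    lora_up = torch.nn.Linear(4, 8, bias=False)
    multiplier = scale = 1.0

    def forward(x):
        org_forwarded = org(x)
        # with probability module_dropout, drop the LoRA delta for this call
        if module_dropout and torch.rand(1) < module_dropout:
            return org_forwarded
        return org_forwarded + lora_up(lora_down(x)) * multiplier * scale

    # default Linear init makes the delta nonzero, so equality means the branch was skipped
    x = torch.randn(2, 8)
    skips = sum(torch.equal(forward(x), org(x)) for _ in range(1000))
    print(f"LoRA branch skipped on ~{skips / 10:.0f}% of 1000 calls")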