add rank dropout/module dropout

This commit is contained in:
Kohya S
2023-06-01 22:21:36 +09:00
parent f8e8df5a04
commit dde7807b00

View File

@@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module):
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
@@ -61,6 +71,8 @@ class LoRAModule(torch.nn.Module):
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
@@ -68,18 +80,45 @@ class LoRAModule(torch.nn.Module):
del self.org_module
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout:
if torch.rand(1) < self.module_dropout:
return org_forwarded
lx = self.lora_down(x)
# normal dropout
if self.dropout:
return (
self.org_forward(x)
+ self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
)
else:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module] # 後から参照できるように
self.enabled = True
@@ -395,6 +434,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_block_dims = None
conv_block_alphas = None
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoder,
@@ -403,6 +450,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
lora_dim=network_dim,
alpha=network_alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
block_dims=block_dims,
@@ -679,6 +728,8 @@ class LoRANetwork(torch.nn.Module):
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
@@ -706,18 +757,22 @@ class LoRANetwork(torch.nn.Module):
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
@@ -764,7 +819,16 @@ class LoRANetwork(torch.nn.Module):
skipped.append(lora_name)
continue
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
loras.append(lora)
return loras, skipped