add rank dropout/module dropout

Kohya S
2023-06-01 22:21:36 +09:00
parent f8e8df5a04
commit dde7807b00


@@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module):
     replaces forward method of the original Linear, instead of replacing the original Linear module.
     """

-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
         self.lora_name = lora_name
@@ -61,6 +71,8 @@ class LoRAModule(torch.nn.Module):
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
         self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout

     def apply_to(self):
         self.org_forward = self.org_module.forward
@@ -68,18 +80,45 @@ class LoRAModule(torch.nn.Module):
         del self.org_module

     def forward(self, x):
-        if self.dropout:
-            return (
-                self.org_forward(x)
-                + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
-            )
-        else:
-            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
+        if self.dropout:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * self.scale


 class LoRAInfModule(LoRAModule):
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
-        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
         self.org_module_ref = [org_module]  # so it can be referenced later
         self.enabled = True
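A minimal standalone sketch (not part of the diff above) of what the two new options do in forward(): module dropout skips the entire LoRA branch for a call with probability module_dropout, while rank dropout zeroes individual rank channels of the lora_down output before lora_up. All shapes and probabilities below are made-up illustration values.

import torch

# toy illustration of the two dropout modes above; all values are assumed examples
batch, seq, rank = 2, 77, 4
lx = torch.randn(batch, seq, rank)           # stands in for lora_down(x) on a Text Encoder block
module_dropout, rank_dropout = 0.1, 0.25

# module dropout: with probability module_dropout the LoRA branch is skipped entirely,
# which is equivalent to returning org_forwarded alone
if torch.rand(1) < module_dropout:
    lx = torch.zeros_like(lx)

# rank dropout: drop whole rank channels per sample, broadcast over the sequence axis
mask = torch.rand((batch, rank)) > rank_dropout   # True keeps a channel
mask = mask.unsqueeze(1)                          # (batch, 1, rank), the 3D case above
lx = lx * mask                                    # surviving channels continue to lora_up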
@@ -395,6 +434,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
        conv_block_dims = None
        conv_block_alphas = None

+    # rank/module dropout
+    rank_dropout = kwargs.get("rank_dropout", None)
+    if rank_dropout is not None:
+        rank_dropout = float(rank_dropout)
+    module_dropout = kwargs.get("module_dropout", None)
+    if module_dropout is not None:
+        module_dropout = float(module_dropout)
+
     # so many arguments ( ^ω^)・・・
     network = LoRANetwork(
         text_encoder,
@@ -403,6 +450,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         lora_dim=network_dim,
         alpha=network_alpha,
         dropout=dropout,
+        rank_dropout=rank_dropout,
+        module_dropout=module_dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         block_dims=block_dims,
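Because rank_dropout and module_dropout reach create_network through **kwargs, they typically arrive as strings and are cast with float() before being handed to LoRANetwork. A self-contained sketch of that parsing path, with assumed example values:

# standalone sketch of the kwargs handling above; the dict contents are assumed examples
kwargs = {"rank_dropout": "0.2", "module_dropout": "0.1"}

rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
    rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
    module_dropout = float(module_dropout)

print(rank_dropout, module_dropout)  # 0.2 0.1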
@@ -679,6 +728,8 @@ class LoRANetwork(torch.nn.Module):
         lora_dim=4,
         alpha=1,
         dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
         conv_lora_dim=None,
         conv_alpha=None,
         block_dims=None,
@@ -706,18 +757,22 @@ class LoRANetwork(torch.nn.Module):
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
         self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout

         if modules_dim is not None:
             print(f"create LoRA network from weights")
         elif block_dims is not None:
-            print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
+            print(f"create LoRA network from block_dims")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             print(f"block_dims: {block_dims}")
             print(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 print(f"conv_block_dims: {conv_block_dims}")
                 print(f"conv_block_alphas: {conv_block_alphas}")
         else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             if self.conv_lora_dim is not None:
                 print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
@@ -764,7 +819,16 @@ class LoRANetwork(torch.nn.Module):
                     skipped.append(lora_name)
                     continue

-                lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
+                lora = module_class(
+                    lora_name,
+                    child_module,
+                    self.multiplier,
+                    dim,
+                    alpha,
+                    dropout=dropout,
+                    rank_dropout=rank_dropout,
+                    module_dropout=module_dropout,
+                )
                 loras.append(lora)

         return loras, skipped
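A minimal end-to-end sketch of the new constructor arguments, assuming the full LoRAModule definition from this file; the module name, layer sizes, and dropout probabilities below are illustrative only, not values from the commit.

import torch

# hypothetical wiring sketch; "lora_unet_example" and all numbers are assumed values
linear = torch.nn.Linear(320, 320)
lora = LoRAModule(
    "lora_unet_example",
    linear,
    multiplier=1.0,
    lora_dim=4,
    alpha=1,
    dropout=0.1,
    rank_dropout=0.2,
    module_dropout=0.05,
)
lora.apply_to()                      # swaps in the LoRA-aware forward for this Linear
y = linear(torch.randn(2, 77, 320))  # stochastic: module/rank dropout may fire per call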