From dde7807b000b304018423802bb3d8e774620c489 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 1 Jun 2023 22:21:36 +0900
Subject: [PATCH] add rank dropout/module dropout

---
 networks/lora.py | 88 +++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 76 insertions(+), 12 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index 19fbbbdb..1a665fc4 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module):
     replaces forward method of the original Linear, instead of replacing the original Linear module.
     """

-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
         self.lora_name = lora_name
@@ -61,6 +71,8 @@ class LoRAModule(torch.nn.Module):
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
         self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout

     def apply_to(self):
         self.org_forward = self.org_module.forward
@@ -68,18 +80,45 @@ class LoRAModule(torch.nn.Module):
         del self.org_module

     def forward(self, x):
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
         if self.dropout:
-            return (
-                self.org_forward(x)
-                + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
-            )
-        else:
-            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * self.scale


 class LoRAInfModule(LoRAModule):
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
-        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
         self.org_module_ref = [org_module]  # 後から参照できるように
         self.enabled = True

@@ -395,6 +434,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_block_dims = None
         conv_block_alphas = None

+    # rank/module dropout
+    rank_dropout = kwargs.get("rank_dropout", None)
+    if rank_dropout is not None:
+        rank_dropout = float(rank_dropout)
+    module_dropout = kwargs.get("module_dropout", None)
+    if module_dropout is not None:
+        module_dropout = float(module_dropout)
+
     # すごく引数が多いな ( ^ω^)・・・
     network = LoRANetwork(
         text_encoder,
@@ -403,6 +450,8 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         lora_dim=network_dim,
         alpha=network_alpha,
         dropout=dropout,
+        rank_dropout=rank_dropout,
+        module_dropout=module_dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         block_dims=block_dims,
@@ -679,6 +728,8 @@ class LoRANetwork(torch.nn.Module):
         lora_dim=4,
         alpha=1,
         dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
         conv_lora_dim=None,
         conv_alpha=None,
         block_dims=None,
@@ -706,18 +757,22 @@ class LoRANetwork(torch.nn.Module):
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
         self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout

         if modules_dim is not None:
             print(f"create LoRA network from weights")
         elif block_dims is not None:
-            print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
+            print(f"create LoRA network from block_dims")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             print(f"block_dims: {block_dims}")
             print(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 print(f"conv_block_dims: {conv_block_dims}")
                 print(f"conv_block_alphas: {conv_block_alphas}")
         else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             if self.conv_lora_dim is not None:
                 print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

@@ -764,7 +819,16 @@ class LoRANetwork(torch.nn.Module):
                                     skipped.append(lora_name)
                                 continue

-                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
+                            lora = module_class(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                                dropout=dropout,
+                                rank_dropout=rank_dropout,
+                                module_dropout=module_dropout,
+                            )
                             loras.append(lora)

             return loras, skipped
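
With this change the forward pass applies three independent dropout mechanisms: module dropout skips the whole LoRA branch for a step with the given probability, the existing neuron dropout is plain element-wise dropout on the down-projected activations, and rank dropout zeroes entire rank channels per sample before lora_up. Below is a minimal, self-contained sketch of that logic; the class name LoRADropoutSketch and the stand-in Linear layers are illustrative only and not part of the patch.

    import torch

    class LoRADropoutSketch(torch.nn.Module):
        def __init__(self, in_dim=16, out_dim=16, lora_dim=4, scale=1.0,
                     dropout=None, rank_dropout=None, module_dropout=None):
            super().__init__()
            self.org_layer = torch.nn.Linear(in_dim, out_dim)  # stands in for the frozen original layer
            self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
            self.lora_dim = lora_dim
            self.scale = scale
            self.multiplier = 1.0
            self.dropout = dropout
            self.rank_dropout = rank_dropout
            self.module_dropout = module_dropout

        def forward(self, x):
            org_forwarded = self.org_layer(x)

            # module dropout: with probability p, return only the original output for this step
            if self.module_dropout and torch.rand(1) < self.module_dropout:
                return org_forwarded

            lx = self.lora_down(x)

            # neuron (normal) dropout on the low-rank activations
            if self.dropout:
                lx = torch.nn.functional.dropout(lx, p=self.dropout)

            # rank dropout: zero whole rank channels, one mask per sample in the batch
            if self.rank_dropout:
                mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
                if lx.dim() == 3:
                    mask = mask.unsqueeze(1)                 # (B, 1, rank) for sequence inputs
                elif lx.dim() == 4:
                    mask = mask.unsqueeze(-1).unsqueeze(-1)  # (B, rank, 1, 1) for conv features
                lx = lx * mask

            return org_forwarded + self.lora_up(lx) * self.multiplier * self.scale

    m = LoRADropoutSketch(dropout=0.1, rank_dropout=0.25, module_dropout=0.1)
    print(m(torch.randn(2, 16)).shape)  # torch.Size([2, 16])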
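
On the configuration side, create_network reads rank_dropout and module_dropout from its **kwargs and converts them with float(), so both are given as probabilities in [0, 1); LoRAInfModule now swallows these arguments via **kwargs and calls super().__init__ without them, so inference modules are built with no dropout at all. Assuming the repository's training scripts forward --network_args "key=value" pairs into those kwargs (which is what the kwargs.get lookups rely on), the new behaviour would be enabled with something like --network_args "rank_dropout=0.25" "module_dropout=0.1" alongside the existing network options.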