From c0f0b5dae37efbd25e598561d5787b0537ef464e Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Mon, 20 Nov 2023 01:15:21 -0500
Subject: [PATCH] Add drop_keys to drop certain keys from the LoRA network

---
 networks/lora.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/networks/lora.py b/networks/lora.py
index 0c75cd42..0962198b 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -456,6 +456,10 @@ def create_network(
     conv_block_dims = None
     conv_block_alphas = None
 
+    drop_keys = kwargs.get("drop_keys", None)
+    if drop_keys is not None:
+        drop_keys = drop_keys.split(',')
+
     # rank/module dropout
     rank_dropout = kwargs.get("rank_dropout", None)
     if rank_dropout is not None:
@@ -480,6 +484,7 @@ def create_network(
         block_alphas=block_alphas,
         conv_block_dims=conv_block_dims,
         conv_block_alphas=conv_block_alphas,
+        drop_keys=drop_keys,
         varbose=True,
     )
 
@@ -764,6 +769,7 @@ class LoRANetwork(torch.nn.Module):
         modules_dim: Optional[Dict[str, int]] = None,
         modules_alpha: Optional[Dict[str, int]] = None,
         module_class: Type[object] = LoRAModule,
+        drop_keys: Optional[List[str]] = None,
         varbose: Optional[bool] = False,
     ) -> None:
         """
@@ -784,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.drop_keys = drop_keys
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
@@ -801,6 +808,9 @@ class LoRANetwork(torch.nn.Module):
             if self.conv_lora_dim is not None:
                 print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
+        if self.drop_keys:
+            print(f"Drop keys: {self.drop_keys}")
+
         # create module instances
         def create_modules(
             is_unet: bool,
@@ -830,6 +840,12 @@ class LoRANetwork(torch.nn.Module):
                             lora_name = prefix + "." + name + "." + child_name
                             lora_name = lora_name.replace(".", "_")
 
+                            if self.drop_keys:
+                                # skip this module entirely if its name contains any drop key
+                                if any(key in lora_name for key in self.drop_keys):
+                                    skipped.append(lora_name)
+                                    continue
+
                             dim = None
                             alpha = None
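
Usage note (illustration, not part of the patch): drop_keys is read from the create_network kwargs, so it is expected to arrive the same way as the other options parsed there (e.g. rank_dropout), i.e. as a comma-separated string. Each key is a plain substring tested against the generated LoRA module names, so a single key can drop a whole family of modules. A minimal stand-alone sketch of that matching, using made-up module names in the "lora_unet_..." / "lora_te_..." naming scheme:

    # Hypothetical sketch of the substring matching done in create_modules() above.
    # The names below are examples only, not taken from a real model dump.
    drop_keys = "attn2_to_k,attn2_to_v".split(",")

    lora_names = [
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q",
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k",
        "lora_unet_mid_block_attentions_0_transformer_blocks_0_attn2_to_v",
        "lora_te_text_model_encoder_layers_0_self_attn_q_proj",
    ]

    for lora_name in lora_names:
        if any(key in lora_name for key in drop_keys):
            print("skipped:", lora_name)  # would be recorded in `skipped`, no LoRA module created
        else:
            print("kept:   ", lora_name)  # a LoRA module would be created as before

Because the match is substring-based, a short key such as "attn" would drop every attention projection, while a more specific key like "attn2_to_k" narrows the selection. Assuming the kwargs are supplied through --network_args like the existing options, an invocation would look roughly like --network_args "drop_keys=attn2_to_k,attn2_to_v".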