diff --git a/networks/lora.py b/networks/lora.py index 1699a60f..bbad2978 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -465,6 +465,10 @@ def create_network( conv_block_dims = None conv_block_alphas = None + drop_keys = kwargs.get("drop_keys", None) + if drop_keys is not None: + drop_keys = drop_keys.split(',') + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -489,6 +493,7 @@ def create_network( block_alphas=block_alphas, conv_block_dims=conv_block_dims, conv_block_alphas=conv_block_alphas, + drop_keys=drop_keys, varbose=True, is_sdxl=is_sdxl, ) @@ -893,6 +898,7 @@ class LoRANetwork(torch.nn.Module): modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, module_class: Type[object] = LoRAModule, + drop_keys: Optional[List[str]] = None, varbose: Optional[bool] = False, is_sdxl: Optional[bool] = False, ) -> None: @@ -914,6 +920,7 @@ class LoRANetwork(torch.nn.Module): self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.drop_keys = drop_keys self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -941,6 +948,9 @@ class LoRANetwork(torch.nn.Module): f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" ) + if self.drop_keys: + print(f"Drop keys: {self.drop_keys}") + # create module instances def create_modules( is_unet: bool, @@ -970,6 +980,12 @@ class LoRANetwork(torch.nn.Module): lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") + if self.drop_keys: + for key in self.drop_keys: + if key in lora_name: + skipped.append(lora_name) + continue + dim = None alpha = None