Add drop_keys to drop certain keys from the LoRA network

Author: rockerBOO
Date: 2023-11-20 01:15:21 -05:00
Parent: 95ae56bd22
Commit: c0f0b5dae3

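How the new option reaches this code (a sketch, not part of the commit): sd-scripts forwards extra --network_args entries to create_network as keyword arguments, so drop_keys arrives as a single comma-separated string. The call below is illustrative only; the argument values and the vae/text_encoder/unet objects are assumed to come from the trainer.

    # illustrative only: drop_keys is received through **kwargs as one string
    network = create_network(
        1.0,                       # multiplier
        16,                        # network_dim
        8,                         # network_alpha
        vae, text_encoder, unet,   # models supplied by the trainer
        drop_keys="attn2,ff_net",  # parsed by the split(',') added in the first hunk below
    )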

@@ -456,6 +456,10 @@ def create_network(
     conv_block_dims = None
     conv_block_alphas = None
 
+    drop_keys = kwargs.get("drop_keys", None)
+    if drop_keys is not None:
+        drop_keys = drop_keys.split(',')
+
     # rank/module dropout
     rank_dropout = kwargs.get("rank_dropout", None)
     if rank_dropout is not None:
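One caveat in the parsing above: str.split(',') keeps surrounding whitespace, so keys written with spaces after the comma will rarely match a lora_name. A minimal sketch of the behavior (the strip variant is a possible hardening, not something this commit does):

    "attn2, ff_net".split(',')                        # -> ['attn2', ' ff_net']  (note the space)
    [k.strip() for k in "attn2, ff_net".split(',')]   # -> ['attn2', 'ff_net']   (not in this commit)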
@@ -480,6 +484,7 @@ def create_network(
         block_alphas=block_alphas,
         conv_block_dims=conv_block_dims,
         conv_block_alphas=conv_block_alphas,
+        drop_keys=drop_keys,
         varbose=True,
     )
@@ -764,6 +769,7 @@ class LoRANetwork(torch.nn.Module):
         modules_dim: Optional[Dict[str, int]] = None,
         modules_alpha: Optional[Dict[str, int]] = None,
         module_class: Type[object] = LoRAModule,
+        drop_keys: Optional[List[str]] = None,
         varbose: Optional[bool] = False,
     ) -> None:
         """
@@ -784,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.drop_keys = drop_keys
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
@@ -801,6 +808,9 @@ class LoRANetwork(torch.nn.Module):
             if self.conv_lora_dim is not None:
                 print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
+        if self.drop_keys:
+            print(f"Drop keys: {self.drop_keys}")
+
         # create module instances
         def create_modules(
             is_unet: bool,
@@ -830,6 +840,12 @@ class LoRANetwork(torch.nn.Module):
                         lora_name = prefix + "." + name + "." + child_name
                         lora_name = lora_name.replace(".", "_")
 
+                        # skip modules whose lora_name contains any drop key; checking all
+                        # keys in one expression lets `continue` advance the module loop
+                        if self.drop_keys and any(key in lora_name for key in self.drop_keys):
+                            skipped.append(lora_name)
+                            continue
+
                         dim = None
                         alpha = None
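The skip logic is plain substring matching against the underscore-joined lora_name. A standalone sketch of the rule; the module names below only imitate sd-scripts' naming scheme and are not taken from this commit:

    drop_keys = ["attn2", "ff_net"]
    names = [
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q",
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k",
        "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_ff_net_0_proj",
    ]
    skipped = [n for n in names if any(k in n for k in drop_keys)]
    kept = [n for n in names if not any(k in n for k in drop_keys)]
    # skipped -> the attn2 and ff_net modules; kept -> only the attn1 module

Because the match is substring-based, a broad key like "attn" would drop both attn1 and attn2 modules, so keys should be as specific as the names allow.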