mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Add drop_keys to drop certain keys from the LoRA network
This commit is contained in:
@@ -456,6 +456,10 @@ def create_network(
|
||||
conv_block_dims = None
|
||||
conv_block_alphas = None
|
||||
|
||||
drop_keys = kwargs.get("drop_keys", None)
|
||||
if drop_keys is not None:
|
||||
drop_keys = drop_keys.split(',')
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
@@ -480,6 +484,7 @@ def create_network(
|
||||
block_alphas=block_alphas,
|
||||
conv_block_dims=conv_block_dims,
|
||||
conv_block_alphas=conv_block_alphas,
|
||||
drop_keys=drop_keys,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
@@ -764,6 +769,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
drop_keys: Optional[List[str]] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -784,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.drop_keys = drop_keys
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
@@ -801,6 +808,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
if self.conv_lora_dim is not None:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
if self.drop_keys:
|
||||
print(f"Drop keys: {self.drop_keys}")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
is_unet: bool,
|
||||
@@ -830,6 +840,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if self.drop_keys:
|
||||
for key in self.drop_keys:
|
||||
if key in lora_name:
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user