Commit by Dave Lage, 2026-04-01 13:54:34 +00:00 (committed by GitHub)

@@ -465,6 +465,10 @@ def create_network(
    conv_block_dims = None
    conv_block_alphas = None

    drop_keys = kwargs.get("drop_keys", None)
    if drop_keys is not None:
        drop_keys = drop_keys.split(',')

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
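
Since `drop_keys` is read from `**kwargs`, it can be supplied as a comma-separated string through sd-scripts' `--network_args` option. A hedged example (the key substrings here are illustrative; use whatever fragments match the module names you want to exclude):

    --network_args "drop_keys=attn2_to_k,attn2_to_v"

The split above turns that string into `['attn2_to_k', 'attn2_to_v']` before it is forwarded to `LoRANetwork`.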
@@ -489,6 +493,7 @@ def create_network(
        block_alphas=block_alphas,
        conv_block_dims=conv_block_dims,
        conv_block_alphas=conv_block_alphas,
        drop_keys=drop_keys,
        varbose=True,
        is_sdxl=is_sdxl,
    )
@@ -893,6 +898,7 @@ class LoRANetwork(torch.nn.Module):
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        module_class: Type[object] = LoRAModule,
        drop_keys: Optional[List[str]] = None,
        varbose: Optional[bool] = False,
        is_sdxl: Optional[bool] = False,
    ) -> None:
@@ -914,6 +920,7 @@ class LoRANetwork(torch.nn.Module):
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.drop_keys = drop_keys

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
@@ -941,6 +948,9 @@ class LoRANetwork(torch.nn.Module):
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
if self.drop_keys:
print(f"Drop keys: {self.drop_keys}")
# create module instances
def create_modules(
is_unet: bool,
@@ -970,6 +980,12 @@ class LoRANetwork(torch.nn.Module):
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            if self.drop_keys:
                                # skip this module entirely if any drop key matches its name;
                                # a bare `continue` inside a `for key` loop would only advance
                                # the key loop and the module would still be created
                                if any(key in lora_name for key in self.drop_keys):
                                    skipped.append(lora_name)
                                    continue

                            dim = None
                            alpha = None
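
A minimal standalone sketch (not part of the commit) of the skip rule the last hunk implements: a module is dropped when any drop key is a substring of its underscore-joined LoRA name. The example names are hypothetical but follow the `prefix + "." + name + "." + child_name` then dots-to-underscores convention shown above.

# Illustrative only: mirrors the substring match in create_modules.
drop_keys = "attn2_to_k,attn2_to_v".split(",")

candidate_names = [
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k",
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2",
]

for lora_name in candidate_names:
    if any(key in lora_name for key in drop_keys):
        print("skipped:", lora_name)  # would be appended to `skipped`; no LoRA module created
    else:
        print("kept:   ", lora_name)  # would get a LoRA module as usual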