Fix emb_dim to work.

This commit is contained in:
kohya-ss
2024-10-29 23:29:50 +09:00
parent c9a1417157
commit b502f58488

View File

@@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
target_replace_modules: List[str], target_replace_modules: List[str],
filter: Optional[str] = None, filter: Optional[str] = None,
default_dim: Optional[int] = None, default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> List[LoRAModule]: ) -> List[LoRAModule]:
prefix = ( prefix = (
self.LORA_PREFIX_SD3 self.LORA_PREFIX_SD3
@@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module):
lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_") lora_name = lora_name.replace(".", "_")
if filter is not None and not filter in lora_name: force_incl_conv2d = False
if filter is not None:
if not filter in lora_name:
continue continue
force_incl_conv2d = include_conv2d_if_filter
dim = None dim = None
alpha = None alpha = None
@@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module):
elif self.conv_lora_dim is not None: elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim dim = self.conv_lora_dim
alpha = self.conv_alpha alpha = self.conv_alpha
elif force_incl_conv2d:
# x_embedder
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if dim is None or dim == 0: if dim is None or dim == 0:
# skipした情報を出力 # skipした情報を出力
@@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module):
], ],
self.emb_dims, self.emb_dims,
): ):
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) # x_embedder is conv2d, so we need to include it
loras, _ = create_modules(
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
)
# if len(loras) > 0:
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
self.unet_loras.extend(loras) self.unet_loras.extend(loras)
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")