mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix emb_dim to work.
This commit is contained in:
@@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
target_replace_modules: List[str],
|
target_replace_modules: List[str],
|
||||||
filter: Optional[str] = None,
|
filter: Optional[str] = None,
|
||||||
default_dim: Optional[int] = None,
|
default_dim: Optional[int] = None,
|
||||||
|
include_conv2d_if_filter: bool = False,
|
||||||
) -> List[LoRAModule]:
|
) -> List[LoRAModule]:
|
||||||
prefix = (
|
prefix = (
|
||||||
self.LORA_PREFIX_SD3
|
self.LORA_PREFIX_SD3
|
||||||
@@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
||||||
lora_name = lora_name.replace(".", "_")
|
lora_name = lora_name.replace(".", "_")
|
||||||
|
|
||||||
if filter is not None and not filter in lora_name:
|
force_incl_conv2d = False
|
||||||
|
if filter is not None:
|
||||||
|
if not filter in lora_name:
|
||||||
continue
|
continue
|
||||||
|
force_incl_conv2d = include_conv2d_if_filter
|
||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
@@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
elif self.conv_lora_dim is not None:
|
elif self.conv_lora_dim is not None:
|
||||||
dim = self.conv_lora_dim
|
dim = self.conv_lora_dim
|
||||||
alpha = self.conv_alpha
|
alpha = self.conv_alpha
|
||||||
|
elif force_incl_conv2d:
|
||||||
|
# x_embedder
|
||||||
|
dim = default_dim if default_dim is not None else self.lora_dim
|
||||||
|
alpha = self.alpha
|
||||||
|
|
||||||
if dim is None or dim == 0:
|
if dim is None or dim == 0:
|
||||||
# skipした情報を出力
|
# skipした情報を出力
|
||||||
@@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
],
|
],
|
||||||
self.emb_dims,
|
self.emb_dims,
|
||||||
):
|
):
|
||||||
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
|
# x_embedder is conv2d, so we need to include it
|
||||||
|
loras, _ = create_modules(
|
||||||
|
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
|
||||||
|
)
|
||||||
|
# if len(loras) > 0:
|
||||||
|
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
|
||||||
self.unet_loras.extend(loras)
|
self.unet_loras.extend(loras)
|
||||||
|
|
||||||
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
|
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
|
||||||
|
|||||||
Reference in New Issue
Block a user