mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
update versions of accelerate and diffusers
This commit is contained in:
@@ -117,7 +117,7 @@ class LoRAModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
@@ -126,7 +126,7 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
@@ -166,7 +166,8 @@ class LoRAModule(torch.nn.Module):
|
||||
self.org_module[0].forward = self.org_forward
|
||||
|
||||
# forward with lora
|
||||
def forward(self, x):
|
||||
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
|
||||
def forward(self, x, scale=1.0):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
@@ -318,8 +319,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_linear = (
|
||||
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
|
||||
)
|
||||
is_conv2d = (
|
||||
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
|
||||
)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
@@ -359,7 +364,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te += skipped
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
if len(skipped_te) > 0:
|
||||
print(f"skipped {len(skipped_te)} modules because of missing weight.")
|
||||
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
|
||||
# extend U-Net target modules to include Conv2d 3x3
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
@@ -368,7 +373,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
if len(skipped_un) > 0:
|
||||
print(f"skipped {len(skipped_un)} modules because of missing weight.")
|
||||
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
|
||||
Reference in New Issue
Block a user