mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
fix dylora create_modules error
This commit is contained in:
@@ -12,7 +12,9 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import List, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from transformers import CLIPTextModel
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -165,7 +167,15 @@ class DyLoRAModule(torch.nn.Module):
|
|||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
|
|
||||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
def create_network(
|
||||||
|
multiplier: float,
|
||||||
|
network_dim: Optional[int],
|
||||||
|
network_alpha: Optional[float],
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
||||||
|
unet,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if network_dim is None:
|
if network_dim is None:
|
||||||
network_dim = 4 # default
|
network_dim = 4 # default
|
||||||
if network_alpha is None:
|
if network_alpha is None:
|
||||||
@@ -182,6 +192,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
|||||||
conv_alpha = 1.0
|
conv_alpha = 1.0
|
||||||
else:
|
else:
|
||||||
conv_alpha = float(conv_alpha)
|
conv_alpha = float(conv_alpha)
|
||||||
|
|
||||||
if unit is not None:
|
if unit is not None:
|
||||||
unit = int(unit)
|
unit = int(unit)
|
||||||
else:
|
else:
|
||||||
@@ -307,7 +318,21 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||||
|
|
||||||
|
self.text_encoder_loras = []
|
||||||
|
for i, text_encoder in enumerate(text_encoders):
|
||||||
|
if len(text_encoders) > 1:
|
||||||
|
index = i + 1
|
||||||
|
print(f"create LoRA for Text Encoder {index}")
|
||||||
|
else:
|
||||||
|
index = None
|
||||||
|
print(f"create LoRA for Text Encoder")
|
||||||
|
|
||||||
|
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
|
|
||||||
|
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
|
||||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
|
|||||||
Reference in New Issue
Block a user