Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
fix alpha is ignored
@@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
     module_class = LoRAInfModule if for_inference else LoRAModule
 
-    network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class)
+    network = LoRANetwork(
+        text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+    )
     return network, weights_sd
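
The hunk above is the core of the fix: the per-module dim and alpha recovered from the loaded weights file are now passed into LoRANetwork, instead of every module being constructed with the network-wide defaults. The diff does not show how modules_dim and modules_alpha are built; below is a minimal sketch of the key-scanning step that sd-scripts-style loaders typically use (scan_weights is an illustrative helper name, not code from this commit):

    from typing import Dict, Tuple

    import torch

    def scan_weights(weights_sd: Dict[str, torch.Tensor]) -> Tuple[Dict[str, int], Dict[str, torch.Tensor]]:
        """Recover per-module rank and alpha from a saved LoRA state dict (sketch)."""
        modules_dim: Dict[str, int] = {}
        modules_alpha: Dict[str, torch.Tensor] = {}
        for key, value in weights_sd.items():
            if "." not in key:
                continue
            lora_name = key.split(".")[0]
            if "alpha" in key:
                modules_alpha[lora_name] = value          # saved as a 0-dim tensor
            elif "lora_down" in key:
                modules_dim[lora_name] = value.size()[0]  # rank = rows of lora_down
        return modules_dim, modules_alpha
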
@@ -331,6 +333,8 @@ class LoRANetwork(torch.nn.Module):
         conv_lora_dim: Optional[int] = None,
         conv_alpha: Optional[float] = None,
         module_class: Type[object] = LoRAModule,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
         varbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
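
Why the missing alpha matters: in sd-scripts LoRA modules the low-rank update is scaled by alpha / dim, and that scale is computed when the module is constructed, so an alpha that only arrives later via load_state_dict has no effect. Passing modules_alpha at construction time (the new parameters above) is what lets the stored alpha take effect. A minimal sketch of the scaling rule, assuming the usual sd-scripts fallback of alpha = dim when alpha is absent (lora_scale is an illustrative helper, not code from this commit):

    from typing import Optional

    def lora_scale(dim: int, alpha: Optional[float]) -> float:
        # The LoRA delta is scaled by alpha / dim; a missing alpha falls back to dim,
        # i.e. a scale of 1.0, which silently overrides the trained setting.
        if alpha is None or alpha == 0:
            alpha = dim
        return alpha / dim

    print(lora_scale(16, 8.0))   # 0.5 -> scale intended by the saved weights
    print(lora_scale(16, None))  # 1.0 -> effective scale when alpha is ignored
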
@@ -348,12 +352,15 @@ class LoRANetwork(torch.nn.Module):
         self.loraplus_unet_lr_ratio = None
         self.loraplus_text_encoder_lr_ratio = None
 
-        logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-        logger.info(
-            f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-        )
-        if self.conv_lora_dim is not None:
-            logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+        if modules_dim is not None:
+            logger.info(f"create LoRA network from weights")
+        else:
+            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
+            if self.conv_lora_dim is not None:
+                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
         def create_modules(
@@ -381,13 +388,19 @@ class LoRANetwork(torch.nn.Module):
                             dim = None
                             alpha = None
 
-                            # normally, target all modules
-                            if is_linear or is_conv2d_1x1:
-                                dim = self.lora_dim
-                                alpha = self.alpha
-                            elif self.conv_lora_dim is not None:
-                                dim = self.conv_lora_dim
-                                alpha = self.conv_alpha
+                            if modules_dim is not None:
+                                # per-module dims are specified
+                                if lora_name in modules_dim:
+                                    dim = modules_dim[lora_name]
+                                    alpha = modules_alpha[lora_name]
+                            else:
+                                # normally, target all modules
+                                if is_linear or is_conv2d_1x1:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+                                elif self.conv_lora_dim is not None:
+                                    dim = self.conv_lora_dim
+                                    alpha = self.conv_alpha
 
                             if dim is None or dim == 0:
                                 # log the skipped modules
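
Restating the selection logic added in this hunk as a standalone function may make the precedence clearer: per-module values from the weights file win, otherwise the network-wide defaults apply, and a module with no dim is skipped. This is only an illustrative restatement (resolve_dim_alpha and the module name below are not part of the patch):

    from typing import Dict, Optional, Tuple

    def resolve_dim_alpha(
        lora_name: str,
        is_linear: bool,
        is_conv2d_1x1: bool,
        lora_dim: int,
        alpha: float,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, float]] = None,
    ) -> Tuple[Optional[int], Optional[float]]:
        """Pick dim/alpha for one candidate module, mirroring the hunk above."""
        if modules_dim is not None:
            # loading from weights: only modules present in the file get a dim/alpha
            if lora_name in modules_dim:
                return modules_dim[lora_name], modules_alpha[lora_name]
            return None, None  # dim stays None -> the caller skips this module
        # fresh network: fall back to the network-wide defaults
        if is_linear or is_conv2d_1x1:
            return lora_dim, alpha
        if conv_lora_dim is not None:
            return conv_lora_dim, conv_alpha
        return None, None

    # Per-module values from the file override the defaults (dim 4, alpha 1.0):
    print(resolve_dim_alpha(
        "lora_unet_example_block", True, False, 4, 1.0,
        modules_dim={"lora_unet_example_block": 16},
        modules_alpha={"lora_unet_example_block": 8.0},
    ))  # -> (16, 8.0)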