fix alpha is ignored

This commit is contained in:
Kohya S
2024-08-10 14:03:59 +09:00
parent 808d2d1f48
commit 358f13f2c9

View File

@@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class)
network = LoRANetwork(
text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
)
return network, weights_sd
@@ -331,6 +333,8 @@ class LoRANetwork(torch.nn.Module):
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -348,6 +352,9 @@ class LoRANetwork(torch.nn.Module):
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
@@ -381,6 +388,12 @@ class LoRANetwork(torch.nn.Module):
dim = None
alpha = None
if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
dim = self.lora_dim