mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix alpha is ignored
This commit is contained in:
@@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
|||||||
|
|
||||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||||
|
|
||||||
network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class)
|
network = LoRANetwork(
|
||||||
|
text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||||
|
)
|
||||||
return network, weights_sd
|
return network, weights_sd
|
||||||
|
|
||||||
|
|
||||||
@@ -331,6 +333,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
conv_lora_dim: Optional[int] = None,
|
conv_lora_dim: Optional[int] = None,
|
||||||
conv_alpha: Optional[float] = None,
|
conv_alpha: Optional[float] = None,
|
||||||
module_class: Type[object] = LoRAModule,
|
module_class: Type[object] = LoRAModule,
|
||||||
|
modules_dim: Optional[Dict[str, int]] = None,
|
||||||
|
modules_alpha: Optional[Dict[str, int]] = None,
|
||||||
varbose: Optional[bool] = False,
|
varbose: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -348,6 +352,9 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.loraplus_unet_lr_ratio = None
|
self.loraplus_unet_lr_ratio = None
|
||||||
self.loraplus_text_encoder_lr_ratio = None
|
self.loraplus_text_encoder_lr_ratio = None
|
||||||
|
|
||||||
|
if modules_dim is not None:
|
||||||
|
logger.info(f"create LoRA network from weights")
|
||||||
|
else:
|
||||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||||
@@ -381,6 +388,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
|
|
||||||
|
if modules_dim is not None:
|
||||||
|
# モジュール指定あり
|
||||||
|
if lora_name in modules_dim:
|
||||||
|
dim = modules_dim[lora_name]
|
||||||
|
alpha = modules_alpha[lora_name]
|
||||||
|
else:
|
||||||
# 通常、すべて対象とする
|
# 通常、すべて対象とする
|
||||||
if is_linear or is_conv2d_1x1:
|
if is_linear or is_conv2d_1x1:
|
||||||
dim = self.lora_dim
|
dim = self.lora_dim
|
||||||
|
|||||||
Reference in New Issue
Block a user