From 358f13f2c92a04fb524006f124fc029a9edb0eaf Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 10 Aug 2024 14:03:59 +0900
Subject: [PATCH] fix alpha is ignored

---
 networks/lora_flux.py | 41 +++++++++++++++++++++++++++--------------
 1 file changed, 27 insertions(+), 14 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 141137b4..332a73d9 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
 
     module_class = LoRAInfModule if for_inference else LoRAModule
 
-    network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class)
+    network = LoRANetwork(
+        text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+    )
     return network, weights_sd
 
 
@@ -331,6 +333,8 @@ class LoRANetwork(torch.nn.Module):
         conv_lora_dim: Optional[int] = None,
         conv_alpha: Optional[float] = None,
         module_class: Type[object] = LoRAModule,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
         varbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -348,12 +352,15 @@ class LoRANetwork(torch.nn.Module):
         self.loraplus_unet_lr_ratio = None
         self.loraplus_text_encoder_lr_ratio = None
 
-        logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-        logger.info(
-            f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-        )
-        if self.conv_lora_dim is not None:
-            logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+        if modules_dim is not None:
+            logger.info(f"create LoRA network from weights")
+        else:
+            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
+            if self.conv_lora_dim is not None:
+                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
     # create module instances
     def create_modules(
@@ -381,13 +388,19 @@ class LoRANetwork(torch.nn.Module):
                             dim = None
                             alpha = None
 
-                            # normally, target all modules
-                            if is_linear or is_conv2d_1x1:
-                                dim = self.lora_dim
-                                alpha = self.alpha
-                            elif self.conv_lora_dim is not None:
-                                dim = self.conv_lora_dim
-                                alpha = self.conv_alpha
+                            if modules_dim is not None:
+                                # modules are specified from weights
+                                if lora_name in modules_dim:
+                                    dim = modules_dim[lora_name]
+                                    alpha = modules_alpha[lora_name]
+                            else:
+                                # normally, target all modules
+                                if is_linear or is_conv2d_1x1:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+                                elif self.conv_lora_dim is not None:
+                                    dim = self.conv_lora_dim
+                                    alpha = self.conv_alpha
 
                             if dim is None or dim == 0:
                                 # output information about skipped modules
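
The modules_dim / modules_alpha dictionaries that this patch threads into LoRANetwork are recovered from the checkpoint earlier in create_network_from_weights. Below is a minimal sketch of that extraction, assuming the usual sd-scripts key layout where each module stores <lora_name>.alpha and <lora_name>.lora_down.weight; the helper name parse_modules_dim_alpha is illustrative and not part of the patch.

from typing import Dict, Tuple

import torch


def parse_modules_dim_alpha(weights_sd: Dict[str, torch.Tensor]) -> Tuple[Dict[str, int], Dict[str, float]]:
    """Recover per-module rank (dim) and alpha from a LoRA state dict (sketch)."""
    modules_dim: Dict[str, int] = {}
    modules_alpha: Dict[str, float] = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue
        lora_name = key.split(".")[0]  # lora_name itself contains no dots
        if "alpha" in key:
            # alpha is stored as a scalar tensor per module
            modules_alpha[lora_name] = float(value.detach().item())
        elif "lora_down" in key:
            # rank (dim) is the output dimension of the down-projection weight
            modules_dim[lora_name] = value.shape[0]
    return modules_dim, modules_alpha

The effective LoRA scale is alpha / dim, so constructing the network without the stored per-module values, as happened before this patch when LoRANetwork fell back to its command-line lora_dim and alpha defaults, silently rescales every module's contribution.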