Fix default_lr not being applied when text_encoder_lr / unet_lr are unset

This commit is contained in:
rockerBOO
2024-04-03 12:46:34 -04:00
parent c7691607ea
commit 1933ab4b48
3 changed files with 64 additions and 17 deletions

View File

@@ -407,7 +407,14 @@ class DyLoRANetwork(torch.nn.Module):
"""
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
unet_lora_plus_ratio=None,
text_encoder_lora_plus_ratio=None
):
self.requires_grad_(True)
all_params = []
@@ -442,11 +449,19 @@ class DyLoRANetwork(torch.nn.Module):
return params
if self.text_encoder_loras:
params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_lora_plus_ratio
)
all_params.extend(params)
if self.unet_loras:
params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_lora_plus_ratio
)
all_params.extend(params)
return all_params