diff --git a/networks/lora_flux.py b/networks/lora_flux.py index d540c221..dd267de0 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -788,8 +788,11 @@ class LoRANetwork(torch.nn.Module): def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): # make sure text_encoder_lr as list of two elements - if text_encoder_lr is None or len(text_encoder_lr) == 0: + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float): + text_encoder_lr = [text_encoder_lr, text_encoder_lr] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]