fix to work old notation for TE LR in .toml

This commit is contained in:
Kohya S
2024-09-12 12:36:07 +09:00
parent 237317fffd
commit cefe52629e

View File

@@ -788,8 +788,11 @@ class LoRANetwork(torch.nn.Module):
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
# make sure text_encoder_lr as list of two elements # make sure text_encoder_lr as list of two elements
if text_encoder_lr is None or len(text_encoder_lr) == 0: # if float, use the same value for both text encoders
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr, default_lr] text_encoder_lr = [default_lr, default_lr]
elif isinstance(text_encoder_lr, float):
text_encoder_lr = [text_encoder_lr, text_encoder_lr]
elif len(text_encoder_lr) == 1: elif len(text_encoder_lr) == 1:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]