Fix unset or invalid LR from making a param_group

This commit is contained in:
rockerBOO
2024-04-11 17:33:19 -04:00
parent 75833e84a1
commit 68467bdf4d
3 changed files with 7 additions and 6 deletions

View File

@@ -412,8 +412,8 @@ class DyLoRANetwork(torch.nn.Module):
text_encoder_lr, text_encoder_lr,
unet_lr, unet_lr,
default_lr, default_lr,
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None, text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None loraplus_ratio=None
): ):
self.requires_grad_(True) self.requires_grad_(True)
@@ -441,7 +441,7 @@ class DyLoRANetwork(torch.nn.Module):
else: else:
param_data["lr"] = lr param_data["lr"] = lr
if ("lr" in param_data) and (param_data["lr"] == 0): if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue continue
params.append(param_data) params.append(param_data)

View File

@@ -1040,8 +1040,8 @@ class LoRANetwork(torch.nn.Module):
text_encoder_lr, text_encoder_lr,
unet_lr, unet_lr,
default_lr, default_lr,
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None, text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None loraplus_ratio=None
): ):
self.requires_grad_(True) self.requires_grad_(True)
@@ -1069,7 +1069,8 @@ class LoRANetwork(torch.nn.Module):
else: else:
param_data["lr"] = lr param_data["lr"] = lr
if ("lr" in param_data) and (param_data["lr"] == 0): if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
print("NO LR skipping!")
continue continue
params.append(param_data) params.append(param_data)

View File

@@ -1038,8 +1038,8 @@ class LoRANetwork(torch.nn.Module):
text_encoder_lr, text_encoder_lr,
unet_lr, unet_lr,
default_lr, default_lr,
unet_loraplus_ratio=None,
text_encoder_loraplus_ratio=None, text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None loraplus_ratio=None
): ):
self.requires_grad_(True) self.requires_grad_(True)
@@ -1067,7 +1067,7 @@ class LoRANetwork(torch.nn.Module):
else: else:
param_data["lr"] = lr param_data["lr"] = lr
if ("lr" in param_data) and (param_data["lr"] == 0): if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue continue
params.append(param_data) params.append(param_data)