mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_ratio` added for both TE and UNet Add log for lora+
This commit is contained in:
@@ -412,32 +412,32 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
default_lr,
|
||||
unet_lora_plus_ratio=None,
|
||||
text_encoder_lora_plus_ratio=None
|
||||
unet_loraplus_ratio=None,
|
||||
text_encoder_loraplus_ratio=None,
|
||||
loraplus_ratio=None
|
||||
):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def assemble_params(loras, lr, lora_plus_ratio):
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
for name, param in lora.named_parameters():
|
||||
if lora_plus_ratio is not None and "lora_up" in name:
|
||||
if ratio is not None and "lora_B" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
# assigned_param_groups = ""
|
||||
# for group in param_groups:
|
||||
# assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
|
||||
# logger.info(assigned_param_groups)
|
||||
|
||||
params = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * lora_plus_ratio
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
@@ -452,7 +452,7 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
params = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
text_encoder_lora_plus_ratio
|
||||
text_encoder_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
@@ -460,7 +460,7 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
params = assemble_params(
|
||||
self.unet_loras,
|
||||
default_lr if unet_lr is None else unet_lr,
|
||||
unet_lora_plus_ratio
|
||||
unet_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user