Fix default LR, Add overall LoRA+ ratio, Add log

`--loraplus_ratio` added for both TE and UNet
Add log output for LoRA+
Author: rockerBOO
Date: 2024-04-08 19:23:02 -04:00
parent 1933ab4b48
commit 75833e84a1
5 changed files with 101 additions and 60 deletions
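
For orientation before the file diffs: below is a minimal, self-contained sketch (not the sd-scripts code itself; the module and function names are made up for illustration) of what the changed parameter grouping does after this commit. The LoRA up-projection weights go into a separate "plus" group whose learning rate is the base LR times the LoRA+ ratio, and a per-part ratio falls back to the new overall `loraplus_ratio`.

```python
# Illustrative sketch only: mirrors the grouping logic in the diffs below,
# using dummy modules instead of the repository's LoRA classes.
import torch
import torch.nn as nn


class DummyLoRAModule(nn.Module):
    def __init__(self, lora_name, dim=4, rank=2):
        super().__init__()
        self.lora_name = lora_name
        self.lora_down = nn.Linear(dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, dim, bias=False)  # the "plus" side


def make_param_groups(loras, lr, ratio):
    # Split parameters: up-projection weights go into the "plus" group,
    # everything else stays in the "lora" group.
    groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            key = "plus" if ratio is not None and "lora_up" in name else "lora"
            groups[key][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in groups.items():
        if not group:
            continue
        entry = {"params": list(group.values())}
        if lr is not None:
            # The "plus" group trains at lr * ratio, per LoRA+.
            entry["lr"] = lr * ratio if key == "plus" else lr
        params.append(entry)
    return params


if __name__ == "__main__":
    loras = [DummyLoRAModule("lora_unet_block0"), DummyLoRAModule("lora_unet_block1")]
    # Per-part ratio falls back to the overall ratio, mirroring
    # `unet_loraplus_ratio or loraplus_ratio` in the diffs.
    unet_loraplus_ratio, loraplus_ratio = None, 16
    groups = make_param_groups(loras, lr=1e-4, ratio=unet_loraplus_ratio or loraplus_ratio)
    optimizer = torch.optim.AdamW(groups)
    for g in optimizer.param_groups:
        print(g["lr"], len(g["params"]))
```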

View File

@@ -412,32 +412,32 @@ class DyLoRANetwork(torch.nn.Module):
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras, lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
                 for name, param in lora.named_parameters():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                    if ratio is not None and "lora_B" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

             # assigned_param_groups = ""
             # for group in param_groups:
             #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
             # logger.info(assigned_param_groups)

             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
                 if len(param_data["params"]) == 0:
                     continue
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr
@@ -452,7 +452,7 @@ class DyLoRANetwork(torch.nn.Module):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
@@ -460,7 +460,7 @@ class DyLoRANetwork(torch.nn.Module):
             params = assemble_params(
                 self.unet_loras,
                 default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)

View File

@@ -1040,32 +1040,32 @@ class LoRANetwork(torch.nn.Module):
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras, lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
                 for name, param in lora.named_parameters():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                    if ratio is not None and "lora_up" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

             # assigned_param_groups = ""
             # for group in param_groups:
             #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
             # logger.info(assigned_param_groups)

             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
                 if len(param_data["params"]) == 0:
                     continue
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr
@@ -1080,7 +1080,7 @@ class LoRANetwork(torch.nn.Module):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
@@ -1099,15 +1099,15 @@ class LoRANetwork(torch.nn.Module):
                 params = assemble_params(
                     block_loras,
                     (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    unet_lora_plus_ratio
+                    unet_loraplus_ratio or loraplus_ratio
                 )
                 all_params.extend(params)
         else:
             params = assemble_params(
                 self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_lr if unet_lr is not None else default_lr,
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
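
As a quick worked example of the per-block path in the last hunk above (illustrative numbers only, not from the repository): the block's base LR is the UNet LR, or the default LR when none is given, scaled by `get_lr_weight()`; the "plus" group then multiplies that block LR by the resolved LoRA+ ratio.

```python
# Illustrative numbers only; get_lr_weight() is replaced by a fixed value here.
unet_lr = 1e-4
lr_weight = 0.5                                # stand-in for self.get_lr_weight(block_loras[0])
unet_loraplus_ratio = None
loraplus_ratio = 16

block_lr = unet_lr * lr_weight                 # 5e-05 for the "lora" (down) group
ratio = unet_loraplus_ratio or loraplus_ratio  # falls back to the overall ratio
plus_lr = block_lr * ratio                     # 8e-04 for the "plus" (up) group
print(block_lr, ratio, plus_lr)
```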

View File

@@ -1038,32 +1038,32 @@ class LoRANetwork(torch.nn.Module):
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                for name, param in lora.get_trainable_named_params():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

             # assigned_param_groups = ""
             # for group in param_groups:
             #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
             # logger.info(assigned_param_groups)

             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
                 if len(param_data["params"]) == 0:
                     continue
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr
@@ -1078,7 +1078,7 @@ class LoRANetwork(torch.nn.Module):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
@@ -1097,15 +1097,15 @@ class LoRANetwork(torch.nn.Module):
                 params = assemble_params(
                     block_loras,
                     (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    unet_lora_plus_ratio
+                    unet_loraplus_ratio or loraplus_ratio
                 )
                 all_params.extend(params)
         else:
             params = assemble_params(
                 self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_lr if unet_lr is not None else default_lr,
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)
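
One behavioral note on the `per_part or overall` fallback used throughout these hunks (plain-Python illustration, not repository code): because it relies on Python's `or`, a per-part ratio of 0 or 0.0 is treated like "unset" and falls back to the overall ratio, not only `None`.

```python
def resolve_ratio(per_part, overall):
    # Mirrors `unet_loraplus_ratio or loraplus_ratio` from the diffs above.
    return per_part or overall

print(resolve_ratio(None, 16))  # 16 -> overall ratio applies
print(resolve_ratio(8, 16))     # 8  -> per-part ratio wins
print(resolve_ratio(0, 16))     # 16 -> falsy 0 also falls back to the overall ratio
```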