From f99fe281cbb6519b7b5f1199c570d496ad4df474 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:38:26 -0400 Subject: [PATCH 1/5] Add LoRA+ support --- library/train_util.py | 2 ++ networks/dylora.py | 45 ++++++++++++++++++++++++++---------- networks/lora.py | 54 ++++++++++++++++++++++++++++--------------- train_network.py | 2 +- 4 files changed, 71 insertions(+), 32 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d2b69edb..4e5ab737 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): diff --git a/networks/dylora.py b/networks/dylora.py index 637f3345..a73ade8b 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -406,27 +406,48 @@ class DyLoRANetwork(torch.nn.Module): logger.info(f"weights are merged") """ - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 948b30b0..8d761977 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,21 +1035,43 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1063,21 +1085,15 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/train_network.py b/train_network.py index e0fa6945..ba0c124d 100644 --- a/train_network.py +++ b/train_network.py @@ -339,7 +339,7 @@ class NetworkTrainer: # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" From c7691607ea1647864b5149c98434a27f23386c65 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:43:04 -0400 Subject: [PATCH 2/5] Add LoRA-FA for LoRA+ --- networks/lora_fa.py | 58 +++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 919222ce..fcc503e8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1033,22 +1033,43 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras: List[LoRAModule]): - params = [] + def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - # params.extend(lora.parameters()) - params.extend(lora.get_trainable_params()) + for name, param in lora.get_trainable_named_params(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1062,21 +1083,15 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params @@ -1093,6 +1108,9 @@ class LoRANetwork(torch.nn.Module): def get_trainable_params(self): return self.parameters() + def get_trainable_named_params(self): + return self.named_parameters() + def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None From 1933ab4b4848b1f8b578c10f25bd050f5e246ac0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 3 Apr 2024 12:46:34 -0400 Subject: [PATCH 3/5] Fix default_lr being applied --- networks/dylora.py | 21 ++++++++++++++++++--- networks/lora.py | 30 +++++++++++++++++++++++------- networks/lora_fa.py | 30 +++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index a73ade8b..edc3e222 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -407,7 +407,14 @@ class DyLoRANetwork(torch.nn.Module): """ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -442,11 +449,19 @@ class DyLoRANetwork(torch.nn.Module): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 8d761977..e082941e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,7 +1035,14 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1070,7 +1077,11 @@ class LoRANetwork(torch.nn.Module): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1085,14 +1096,19 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora_fa.py b/networks/lora_fa.py index fcc503e8..3f6774dd 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1033,7 +1033,14 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1068,7 +1075,11 @@ class LoRANetwork(torch.nn.Module): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1083,14 +1094,19 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params From 75833e84a1c7e3c2fb0a9e3ce0fe3d8c1758a012 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 8 Apr 2024 19:23:02 -0400 Subject: [PATCH 4/5] Fix default LR, Add overall LoRA+ ratio, Add log `--loraplus_ratio` added for both TE and UNet Add log for lora+ --- library/train_util.py | 1 + networks/dylora.py | 24 ++++++------- networks/lora.py | 28 ++++++++-------- networks/lora_fa.py | 30 ++++++++--------- train_network.py | 78 ++++++++++++++++++++++++++++++++----------- 5 files changed, 101 insertions(+), 60 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e5ab737..7c2bf693 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") diff --git a/networks/dylora.py b/networks/dylora.py index edc3e222..dc5c7cb3 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,32 +412,32 @@ class DyLoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_B" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -452,7 +452,7 @@ class DyLoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -460,7 +460,7 @@ class DyLoRANetwork(torch.nn.Module): params = assemble_params( self.unet_loras, default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora.py b/networks/lora.py index e082941e..6cb05bcb 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,32 +1040,32 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1080,7 +1080,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1099,15 +1099,15 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 3f6774dd..2eff86d6 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,32 +1038,32 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: - for name, param in lora.get_trainable_named_params(): - if lora_plus_ratio is not None and "lora_up" in name: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1078,7 +1078,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1097,15 +1097,15 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/train_network.py b/train_network.py index ba0c124d..43226fc4 100644 --- a/train_network.py +++ b/train_network.py @@ -66,14 +66,61 @@ class NetworkTrainer: lrs = lr_scheduler.get_last_lr() - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: + if len(lrs) > 4: + idx = 0 + if not args.network_train_unet_only: logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + lora_plus = "" + group_id = i + + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + lora_plus = '_lora+' if i % 2 == 1 else '' + group_id = int((i / 2) + (i % 2 + 0.5)) + + logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{group_id}{lora_plus}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + + else: + if args.network_train_text_encoder_only: + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + + elif args.network_train_unet_only: + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + else: + logs["lr/unet"] = float(lrs[0]) else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder + if len(lrs) == 2: + if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None: + logs["lr/all"] = float(lrs[0]) + logs["lr/all_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) + elif len(lrs) == 4: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + logs["lr/unet"] = float(lrs[2]) + logs["lr/unet_lora+"] = float(lrs[3]) + else: + logs["lr/all"] = float(lrs[0]) if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -81,18 +128,6 @@ class NetworkTrainer: logs["lr/d*lr"] = ( lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] ) - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) return logs @@ -339,7 +374,7 @@ class NetworkTrainer: # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" @@ -348,6 +383,11 @@ class NetworkTrainer: optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + assert ( + (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name) + ), "LoRA+ and Prodigy/DAdaptation is not supported" + # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers From 68467bdf4d76ba2c57289209b0ffd6ba599e2080 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 11 Apr 2024 17:33:19 -0400 Subject: [PATCH 5/5] Fix unset or invalid LR from making a param_group --- networks/dylora.py | 4 ++-- networks/lora.py | 5 +++-- networks/lora_fa.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index dc5c7cb3..0546fc7a 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,8 +412,8 @@ class DyLoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -441,7 +441,7 @@ class DyLoRANetwork(torch.nn.Module): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: continue params.append(param_data) diff --git a/networks/lora.py b/networks/lora.py index 6cb05bcb..d74608fe 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,8 +1040,8 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1069,7 +1069,8 @@ class LoRANetwork(torch.nn.Module): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + print("NO LR skipping!") continue params.append(param_data) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 2eff86d6..9a608118 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,8 +1038,8 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1067,7 +1067,7 @@ class LoRANetwork(torch.nn.Module): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: continue params.append(param_data)